Source code for haive.core.models.metadata
"""Model metadata utilities for LLM configurations.
This module provides utilities for downloading, caching, and accessing
model metadata from LiteLLM's model_prices_and_context_window.json.
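
Example (illustrative sketch; actual values depend on the data published by LiteLLM,
and the first call may download the metadata file):

    >>> from haive.core.models.metadata import ModelMetadata
    >>> meta = ModelMetadata(name="gpt-4", provider="openai")
    >>> window = meta.context_window  # 2048 if the model is not in the metadata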
"""
import json
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
import requests
@dataclass
class ModelMetadata:
"""A class to store and provide model metadata.
This class encapsulates metadata about a language model, including
its pricing, context window limits, and provider information.
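
    Example (illustrative; metadata is looked up lazily when not supplied):

        >>> meta = ModelMetadata(name="claude-3-opus", provider="anthropic")
        >>> window = meta.context_window  # falls back to 2048 when unknown
        >>> costs = meta.pricing          # per-token input/output costs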
"""
name: str
provider: str | None = None
    metadata: dict[str, Any] | None = None
def __post_init__(self):
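        # Lazily resolve metadata from the shared LiteLLM cache when none is supplied.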
if self.metadata is None:
self.metadata = get_model_metadata(self.name, self.provider)
@property
def context_window(self) -> int:
"""Get the context window size for this model."""
return self.metadata.get("context_window", 2048)
@property
def pricing(self) -> dict[str, float]:
"""Get the pricing information for this model."""
return {
"input": self.metadata.get("input_cost_per_token", 0.0),
"output": self.metadata.get("output_cost_per_token", 0.0),
}
logger = logging.getLogger(__name__)
# Singleton metadata cache
_MODEL_METADATA_CACHE = {}
_METADATA_LAST_UPDATED = None
_METADATA_CACHE_TTL = timedelta(hours=24)
_METADATA_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
_METADATA_CACHE_FILE = Path.home() / ".haive" / "cache" / "model_metadata.json"
def _ensure_cache_dir():
"""Ensure the cache directory exists."""
cache_dir = _METADATA_CACHE_FILE.parent
cache_dir.mkdir(parents=True, exist_ok=True)
def _load_metadata_from_cache() -> dict[str, Any]:
"""Load metadata from cache file."""
try:
if _METADATA_CACHE_FILE.exists():
with open(_METADATA_CACHE_FILE) as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load metadata from cache: {e}")
return {}
def _save_metadata_to_cache(metadata: dict[str, Any]) -> None:
"""Save metadata to cache file."""
try:
_ensure_cache_dir()
with open(_METADATA_CACHE_FILE, "w") as f:
json.dump(metadata, f, indent=2)
except Exception as e:
logger.warning(f"Failed to save metadata to cache: {e}")
def _download_metadata() -> dict[str, Any]:
"""Download metadata from LiteLLM's repository."""
try:
response = requests.get(_METADATA_URL, timeout=10)
response.raise_for_status()
data = response.json()
return data
except Exception as e:
logger.warning(f"Failed to download metadata: {e}")
return {}
def get_model_metadata(
model_name: str, provider: str | None = None, force_refresh: bool = False
) -> dict[str, Any]:
"""Get metadata for a specific model with improved matching.
This function tries to find the most relevant model metadata based on the
model name and provider, with multiple fallback strategies.
Args:
model_name: Name of the model (e.g., "gpt-4", "claude-3-opus")
provider: Optional provider prefix (e.g., "azure", "anthropic")
force_refresh: Force download fresh metadata
Returns:
Model metadata dictionary or empty dict if not found
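
    Example (illustrative; the available keys follow LiteLLM's schema and may vary):

        >>> info = get_model_metadata("gpt-4", provider="openai")
        >>> cost = info.get("input_cost_per_token", 0.0)  # 0.0 when the model is unknown
        >>> fresh = get_model_metadata("gpt-4", force_refresh=True)  # bypass the 24-hour cache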
"""
global _MODEL_METADATA_CACHE, _METADATA_LAST_UPDATED
# Check if we need to refresh metadata
if (
force_refresh
or not _MODEL_METADATA_CACHE
or (
_METADATA_LAST_UPDATED is None
or datetime.now() - _METADATA_LAST_UPDATED > _METADATA_CACHE_TTL
)
):
# Try to load from cache first
if not force_refresh:
cache_data = _load_metadata_from_cache()
if cache_data:
_MODEL_METADATA_CACHE = cache_data
_METADATA_LAST_UPDATED = datetime.now()
# Download if still needed
if force_refresh or not _MODEL_METADATA_CACHE:
fresh_data = _download_metadata()
if fresh_data:
_MODEL_METADATA_CACHE = fresh_data
_METADATA_LAST_UPDATED = datetime.now()
_save_metadata_to_cache(fresh_data)
# No metadata available
if not _MODEL_METADATA_CACHE:
return {}
# Normalize inputs
normalized_model = model_name.lower().strip()
normalized_provider = provider.lower().strip() if provider else None
# Try different lookup strategies
# 1. Exact match with provider prefix
if provider:
provider_model_name = f"{normalized_provider}/{normalized_model}"
if provider_model_name in _MODEL_METADATA_CACHE:
return _MODEL_METADATA_CACHE[provider_model_name]
# Try with different provider format
alternative_provider_name = f"{normalized_provider}-{normalized_model}"
if alternative_provider_name in _MODEL_METADATA_CACHE:
return _MODEL_METADATA_CACHE[alternative_provider_name]
# 2. Exact match without provider
if normalized_model in _MODEL_METADATA_CACHE:
return _MODEL_METADATA_CACHE[normalized_model]
# 3. Best match search - check for model name contained within keys
# Priority: provider+model > model > model base version
candidates = []
for key in _MODEL_METADATA_CACHE:
if key == "sample_spec":
continue
        # Keys prefixed with the requested provider get the highest priority when a provider is given
if provider and key.startswith(f"{normalized_provider}/"):
# Provider prefix match - higher priority
if normalized_model in key.lower():
# Add with high score (exact provider match)
candidates.append((key, 100 + len(normalized_model)))
elif provider and "litellm_provider" in _MODEL_METADATA_CACHE[key]:
# Check internal provider field
if (
_MODEL_METADATA_CACHE[key]["litellm_provider"].lower()
== normalized_provider
) and normalized_model in key.lower():
# Add with high score (internal provider match)
candidates.append((key, 90 + len(normalized_model)))
elif normalized_model in key.lower():
# Add with medium score (model name match)
candidates.append((key, 50 + len(normalized_model)))
# Base model check (without version)
elif any(segment in key.lower() for segment in normalized_model.split("-")):
# Add with lower score (partial match)
candidates.append((key, 20))
# Return best match if we have candidates
if candidates:
# Sort by score descending
candidates.sort(key=lambda x: x[1], reverse=True)
return _MODEL_METADATA_CACHE[candidates[0][0]]
# 4. Fallback - return empty dictionary
logger.warning(f"No metadata found for model {model_name} with provider {provider}")
return {}