Source code for haive.core.models.metadata_mixin
"""Model metadata mixin for LLM configurations.
This module provides a mixin class that adds comprehensive model metadata
access to LLM configuration classes, including context windows, pricing,
and capability information.
"""
import logging
from typing import Any
from haive.core.models.metadata import get_model_metadata
logger = logging.getLogger(__name__)
[docs]
class ModelMetadataMixin:
"""Mixin to add comprehensive model metadata methods to LLMConfig classes.
This mixin provides access to model capabilities, context window sizes,
pricing information, and other metadata from the model catalog.
"""
[docs]
def get_context_window(self) -> int:
"""Get the maximum context window size for this model.
Returns:
int: Total context window size (input + output tokens)
"""
metadata = self._get_model_metadata()
# Return max_tokens if specified
if "max_tokens" in metadata:
return metadata.get("max_tokens", 0)
# Otherwise calculate from input + output tokens
input_tokens = metadata.get("max_input_tokens", 0)
output_tokens = metadata.get("max_output_tokens", 0)
# If both are specified, return their sum
if input_tokens > 0 and output_tokens > 0:
return input_tokens + output_tokens
# If only one is specified, return that
if input_tokens > 0:
return input_tokens
if output_tokens > 0:
return output_tokens
# Default fallback based on common models
getattr(self, "model", "").lower()
return 0
[docs]
def get_max_input_tokens(self) -> int:
"""Get the maximum input tokens for this model.
Returns:
int: Maximum input tokens the model can accept
"""
metadata = self._get_model_metadata()
return metadata.get("max_input_tokens", self.get_context_window())
[docs]
def get_max_output_tokens(self) -> int:
"""Get the maximum output tokens for this model.
Returns:
int: Maximum output tokens the model can generate
"""
metadata = self._get_model_metadata()
return metadata.get("max_output_tokens", metadata.get("max_tokens", 0))
[docs]
def get_token_pricing(self) -> tuple[float, float]:
"""Get the token pricing for this model.
Returns:
Tuple[float, float]: (input_cost_per_token, output_cost_per_token)
"""
metadata = self._get_model_metadata()
input_cost = metadata.get("input_cost_per_token", 0.0)
output_cost = metadata.get("output_cost_per_token", 0.0)
return (input_cost, output_cost)
[docs]
def get_batch_token_pricing(self) -> tuple[float, float]:
"""Get the batch token pricing for this model.
Returns:
Tuple[float, float]: (input_batch_cost, output_batch_cost)
"""
metadata = self._get_model_metadata()
input_cost = metadata.get("input_cost_per_token_batches", 0.0)
output_cost = metadata.get("output_cost_per_token_batches", 0.0)
return (input_cost, output_cost)
[docs]
def supports_feature(self, feature: str) -> bool:
"""Check if this model supports a specific feature.
Args:
feature: Feature name (e.g., "vision", "function_calling")
Returns:
bool: True if the model supports the feature, False otherwise
"""
metadata = self._get_model_metadata()
# Check for "supports_X" format
feature_key = f"supports_{feature}"
if feature_key in metadata:
return bool(metadata[feature_key])
# Check for specific features in other formats
if feature == "web_search" and "search_context_cost_per_query" in metadata:
return True
# Check supported modalities
if (
feature in ["text", "image", "video", "audio"]
and "supported_modalities" in metadata
):
return feature in metadata["supported_modalities"]
return False
[docs]
def get_search_context_costs(self) -> dict[str, float]:
"""Get the search context costs for this model.
Returns:
Dict[str, float]: Dictionary mapping context sizes to costs
"""
metadata = self._get_model_metadata()
search_costs = metadata.get("search_context_cost_per_query", {})
return search_costs
[docs]
def get_supported_endpoints(self) -> list[str]:
"""Get the supported API endpoints for this model.
Returns:
List[str]: List of supported endpoints
"""
metadata = self._get_model_metadata()
return metadata.get("supported_endpoints", [])
[docs]
def get_supported_modalities(self) -> list[str]:
"""Get the supported input modalities for this model.
Returns:
List[str]: List of supported modalities (e.g., "text", "image")
"""
metadata = self._get_model_metadata()
return metadata.get("supported_modalities", ["text"])
[docs]
def get_supported_output_modalities(self) -> list[str]:
"""Get the supported output modalities for this model.
Returns:
List[str]: List of supported output modalities
"""
metadata = self._get_model_metadata()
return metadata.get("supported_output_modalities", ["text"])
[docs]
def get_deprecation_date(self) -> str | None:
"""Get the deprecation date for this model, if available.
Returns:
Optional[str]: Deprecation date in YYYY-MM-DD format, or None if not deprecated
"""
metadata = self._get_model_metadata()
return metadata.get("deprecation_date")
[docs]
def get_model_mode(self) -> str:
"""Get the mode for this model.
Returns:
str: Model mode (e.g., "chat", "embedding", "completion")
"""
metadata = self._get_model_metadata()
return metadata.get("mode", "chat")
def _get_model_metadata(self) -> dict[str, Any]:
"""Get metadata for the current model, using alias if available.
Returns:
Dictionary of model metadata
"""
# Use model_alias if available, otherwise use model name
model_name_for_lookup = getattr(self, "model_alias", None) or self.model
# Get provider value, handling different formats
if hasattr(self.provider, "value"):
provider_value = self.provider.value
else:
provider_value = str(self.provider)
# Use the imported function for actual lookup
return get_model_metadata(model_name_for_lookup, provider_value)
# Property getters for common capabilities
@property
def supports_vision(self) -> bool:
"""Check if model supports vision/image inputs."""
return self.supports_feature("vision")
@property
def supports_function_calling(self) -> bool:
"""Check if model supports function calling."""
return self.supports_feature("function_calling")
@property
def supports_parallel_function_calling(self) -> bool:
"""Check if model supports parallel function calling."""
return self.supports_feature("parallel_function_calling")
@property
def supports_system_messages(self) -> bool:
"""Check if model supports system messages."""
return self.supports_feature("system_messages")
@property
def supports_tool_choice(self) -> bool:
"""Check if model supports tool choice."""
return self.supports_feature("tool_choice")
@property
def supports_response_schema(self) -> bool:
"""Check if model supports response schema."""
return self.supports_feature("response_schema")
@property
def supports_reasoning(self) -> bool:
"""Check if model supports reasoning."""
return self.supports_feature("reasoning")
@property
def supports_web_search(self) -> bool:
"""Check if model supports web search."""
return self.supports_feature("web_search")
@property
def supports_audio_input(self) -> bool:
"""Check if model supports audio input."""
return self.supports_feature("audio_input")
@property
def supports_audio_output(self) -> bool:
"""Check if model supports audio output."""
return self.supports_feature("audio_output")
@property
def supports_pdf_input(self) -> bool:
"""Check if model supports PDF input."""
return self.supports_feature("pdf_input")
@property
def supports_prompt_caching(self) -> bool:
"""Check if model supports prompt caching."""
return self.supports_feature("prompt_caching")
@property
def supports_native_streaming(self) -> bool:
"""Check if model supports native streaming."""
return self.supports_feature("native_streaming")
@property
def max_tokens(self) -> int:
"""Get maximum total tokens for this model."""
return self.get_context_window()
@property
def max_input_tokens(self) -> int:
"""Get maximum input tokens for this model."""
return self.get_max_input_tokens()
@property
def max_output_tokens(self) -> int:
"""Get maximum output tokens for this model."""
return self.get_max_output_tokens()