"""Base provider module for LLM configurations.
This module provides the base classes and utilities for all LLM provider implementations
in the Haive framework. It includes the base configuration class with metadata support,
rate limiting capabilities, and common functionality shared across all providers.
The module structure ensures consistent interfaces, proper error handling for optional
dependencies, and clean separation of concerns between different LLM providers.
Classes:
BaseLLMProvider: Abstract base class for all LLM provider configurations
ProviderImportError: Custom exception for provider import failures
Examples:
Creating a custom provider:
.. code-block:: python
from haive.core.models.llm.providers.base import BaseLLMProvider
from haive.core.models.llm.provider_types import LLMProvider
class CustomLLMProvider(BaseLLMProvider):
provider = LLMProvider.CUSTOM
def _get_chat_class(self):
from langchain_custom import ChatCustom
return ChatCustom
def _get_default_model(self):
return "custom-model-v1"
.. autosummary::
:toctree: generated/
BaseLLMProvider
ProviderImportError
"""
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Self

from pydantic import (
    BaseModel,
    Field,
    SecretStr,
    ValidationInfo,
    field_validator,
    model_validator,
)

from haive.core.common.mixins.secure_config import SecureConfigMixin
from haive.core.models.llm.provider_types import LLMProvider
from haive.core.models.llm.rate_limiting_mixin import RateLimitingMixin
from haive.core.models.metadata_mixin import ModelMetadataMixin

logger = logging.getLogger(__name__)


class ProviderImportError(ImportError):
    """Custom exception for provider-specific import failures.

    This exception provides clearer error messages when LLM provider
    dependencies are not installed, including the package name needed
    for installation.

    Attributes:
        provider: The provider that failed to import.
        package: The package name to install.
    """

    def __init__(self, provider: str, package: str, message: str | None = None):
        """Initialize the provider import error.

        Args:
            provider: Name of the provider.
            package: Package name for pip install.
            message: Optional custom message; a default is built if omitted.
        """
        self.provider = provider
        self.package = package
        if message is None:
            message = (
                f"{provider} provider is not available. "
                f"Please install it with: pip install {package}"
            )
        super().__init__(message)
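
# A minimal sketch of how a concrete provider might surface this error.
# ``langchain_custom`` and ``ChatCustom`` are hypothetical names used purely
# for illustration, not a real package:
#
#     try:
#         from langchain_custom import ChatCustom
#     except ImportError as e:
#         raise ProviderImportError("custom", "langchain-custom") from e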


class BaseLLMProvider(
    SecureConfigMixin, ModelMetadataMixin, RateLimitingMixin, BaseModel, ABC
):
    """Abstract base class for all LLM provider configurations.

    This class provides the common functionality and interface that all
    LLM provider implementations must follow. It includes:

    - Secure API key management with environment variable fallbacks
    - Model metadata access (context windows, capabilities, pricing)
    - Rate limiting configuration
    - Common configuration parameters
    - Safe import handling for optional dependencies

    Subclasses must implement:

    - ``_get_chat_class()``: Return the LangChain chat class
    - ``_get_default_model()``: Return the default model name
    - ``_get_import_package()``: Return the pip package name

    Attributes:
        provider: The LLM provider enum value.
        model: The specific model identifier.
        name: Optional friendly name for the model.
        api_key: Secure storage of the API key with environment fallback.
        cache_enabled: Whether to enable response caching.
        cache_ttl: Time-to-live for cached responses, in seconds.
        extra_params: Additional provider-specific parameters.
        debug: Enable detailed debug output.

    Examples:
        Creating a provider configuration:

        .. code-block:: python

            from haive.core.models.llm.providers.openai import OpenAIProvider

            provider = OpenAIProvider(
                model="gpt-4",
                temperature=0.7,
                max_tokens=1000,
            )
            llm = provider.instantiate()
    """
    provider: LLMProvider = Field(..., description="The LLM provider identifier")
    model: str | None = Field(None, description="The model to use")
    name: str | None = Field(None, description="Friendly display name")
    api_key: SecretStr = Field(
        default_factory=lambda: SecretStr(""), description="API key for the provider"
    )
    cache_enabled: bool = Field(default=True, description="Enable response caching")
    cache_ttl: int | None = Field(
        default=300, description="Cache time-to-live in seconds"
    )
    extra_params: dict[str, Any] | None = Field(
        default_factory=dict, description="Additional provider-specific parameters"
    )
    debug: bool = Field(default=False, description="Enable debug output")
    requests_per_second: float | None = Field(
        default=None,
        description="Maximum number of requests per second. None means no limit.",
        ge=0,
    )
    tokens_per_second: int | None = Field(
        default=None,
        description="Maximum number of tokens per second. None means no limit.",
        ge=0,
    )
    tokens_per_minute: int | None = Field(
        default=None,
        description="Maximum number of tokens per minute. None means no limit.",
        ge=0,
    )
    max_retries: int = Field(
        default=3,
        description="Maximum number of retries for rate-limited requests.",
        ge=0,
    )
    retry_delay: float = Field(
        default=1.0, description="Base delay between retries in seconds.", ge=0
    )
    check_every_n_seconds: float | None = Field(
        default=None,
        description="How often to check rate limits. None uses default.",
        ge=0,
    )
    burst_size: int | None = Field(
        default=None,
        description="Maximum burst size for rate limiting. None uses default.",
        ge=1,
    )

    model_config = {"arbitrary_types_allowed": True}
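
    # A sketch of a rate-limited configuration. ``OpenAIProvider`` is assumed
    # to be a concrete subclass shipped elsewhere in this package:
    #
    #     provider = OpenAIProvider(
    #         model="gpt-4",
    #         requests_per_second=2.0,
    #         tokens_per_minute=90_000,
    #         max_retries=5,
    #     )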

    @model_validator(mode="after")
    def set_defaults(self) -> Self:
        """Set default values after initialization.

        This validator ensures that ``model`` and ``name`` have appropriate
        default values if they were not provided during initialization.

        Returns:
            The validated instance.
        """
        if self.model is None:
            self.model = self._get_default_model()
        if self.name is None:
            self.name = self.model
        return self
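
    # For example, with a hypothetical concrete subclass whose default model
    # is "custom-model-v1", omitting ``model`` fills both fields:
    #
    #     cfg = CustomLLMProvider()
    #     assert cfg.model == "custom-model-v1"
    #     assert cfg.name == "custom-model-v1"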

    @abstractmethod
    def _get_chat_class(self) -> type[Any]:
        """Get the LangChain chat class for this provider.

        This method must be implemented by each provider to return the
        appropriate LangChain chat class. It should handle imports and
        raise ProviderImportError if dependencies are missing.

        Returns:
            The LangChain chat class.

        Raises:
            ProviderImportError: If required dependencies are not installed.
        """

    @abstractmethod
    def _get_default_model(self) -> str:
        """Get the default model name for this provider.

        Returns:
            The default model identifier.
        """

    @abstractmethod
    def _get_import_package(self) -> str:
        """Get the pip package name for this provider.

        Returns:
            The package name for pip install.
        """

    def _get_env_key_name(self) -> str:
        """Get the environment variable name for the API key.

        Returns:
            The environment variable name (e.g., ``OPENAI_API_KEY``).
        """
        provider_upper = self.provider.value.upper()
        if provider_upper == "TOGETHER_AI":
            return "TOGETHER_AI_API_KEY"
        if provider_upper == "FIREWORKS_AI":
            return "FIREWORKS_AI_API_KEY"
        return f"{provider_upper}_API_KEY"

    @field_validator("api_key")
    @classmethod
    def load_api_key(cls, v: SecretStr, info: ValidationInfo) -> SecretStr:
        """Load the API key from the environment if not provided.

        Args:
            v: The provided API key value.
            info: Validation info containing previously validated field data.

        Returns:
            The API key (from the input or the environment).
        """
        if v.get_secret_value() == "" and hasattr(info, "data"):
            provider = info.data.get("provider")
            if provider:
                env_key = f"{provider.value.upper()}_API_KEY"
                env_value = os.getenv(env_key, "")
                return SecretStr(env_value)
        return v
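
    # Sketch of the environment fallback (assumes a concrete
    # ``OpenAIProvider`` subclass and that no explicit key is passed):
    #
    #     os.environ["OPENAI_API_KEY"] = "sk-..."
    #     cfg = OpenAIProvider()
    #     cfg.api_key.get_secret_value()  # -> value read from OPENAI_API_KEY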

    def _get_initialization_params(self, **kwargs) -> dict[str, Any]:
        """Get parameters for initializing the LLM.

        This method prepares all parameters needed to instantiate the
        LangChain chat model, including the model name, API key, and any
        provider-specific parameters. Keys from ``extra_params`` and
        ``kwargs`` override the base parameters, with ``kwargs`` taking
        the highest precedence.

        Args:
            **kwargs: Additional parameters to include.

        Returns:
            Dictionary of initialization parameters.
        """
        params = {
            "model": self.model,
            "cache": self.cache_enabled,
            **(self.extra_params or {}),
            **kwargs,
        }
        api_key = self.get_api_key()
        if api_key:
            api_key_param = self._get_api_key_param_name()
            if api_key_param:
                params[api_key_param] = api_key
        return params
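
    # Merge-order sketch: a key supplied in both ``extra_params`` and
    # ``kwargs`` resolves to the ``kwargs`` value, since it is unpacked last:
    #
    #     cfg = CustomLLMProvider(extra_params={"temperature": 0.1})
    #     cfg._get_initialization_params(temperature=0.9)
    #     # -> {..., "temperature": 0.9}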

    def _get_api_key_param_name(self) -> str | None:
        """Get the parameter name for the API key.

        Different providers use different parameter names for API keys.
        This method returns the appropriate parameter name.

        Returns:
            The parameter name for the API key, or None if no key is needed.
        """
        return "api_key"

    def instantiate(self, **kwargs) -> Any:
        """Instantiate the LLM with rate limiting if configured.

        This method creates an instance of the LLM using the provider's
        chat class and configuration. It also applies rate limiting if
        any rate-limit parameters are configured.

        Args:
            **kwargs: Additional parameters to pass to the LLM.

        Returns:
            The instantiated LLM, potentially wrapped with rate limiting.

        Raises:
            ProviderImportError: If provider dependencies are not installed.
            ValueError: If required configuration is missing.
            RuntimeError: If instantiation fails.
        """
        try:
            chat_class = self._get_chat_class()
        except ImportError as e:
            raise ProviderImportError(
                provider=self.provider.value, package=self._get_import_package()
            ) from e

        self._validate_config()
        params = self._get_initialization_params(**kwargs)

        try:
            llm = chat_class(**params)
        except Exception as e:
            logger.exception(
                f"Failed to instantiate {self.provider.value} model: {e!s}"
            )
            raise RuntimeError(
                f"Failed to instantiate {self.provider.value} model: {e!s}"
            ) from e

        return self.apply_rate_limiting(llm)
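
    # Typical call path, assuming a concrete provider subclass:
    #
    #     provider = OpenAIProvider(model="gpt-4", requests_per_second=2.0)
    #     llm = provider.instantiate(temperature=0.2)
    #     llm.invoke("Hello")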

    def _validate_config(self) -> None:
        """Validate the configuration before instantiation.

        This method checks that all required configuration is present and
        valid. Subclasses can override it to add provider-specific
        validation.

        Raises:
            ValueError: If the configuration is invalid.
        """
        if self._requires_api_key() and not self.get_api_key():
            env_key = self._get_env_key_name()
            raise ValueError(
                f"{self.provider.value} API key is required. Please set the "
                f"{env_key} environment variable or provide an API key."
            )

    def _requires_api_key(self) -> bool:
        """Check if this provider requires an API key.

        Returns:
            True if an API key is required (the default), False otherwise.
        """
        return True
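
    # Providers backed by local runtimes can opt out of key validation by
    # overriding this hook (``OllamaProvider`` is an illustrative name):
    #
    #     class OllamaProvider(BaseLLMProvider):
    #         def _requires_api_key(self) -> bool:
    #             return False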

    @classmethod
    def get_models(cls) -> list[str]:
        """Get available models for this provider.

        Subclasses that support model listing should override this method
        to retrieve the list of available models from the provider's API.
        The base implementation always raises ``NotImplementedError``.

        Returns:
            List of available model names.

        Raises:
            NotImplementedError: If the provider doesn't support listing models.
        """
        raise NotImplementedError(
            f"Provider {cls.__name__} does not support listing models"
        )
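
    # Sketch of a subclass override that queries a hypothetical client for
    # its model catalog; ``custom_client`` and ``list_models`` are assumed
    # names, not a real API:
    #
    #     @classmethod
    #     def get_models(cls) -> list[str]:
    #         from custom_client import list_models
    #         return [m.id for m in list_models()]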