Source code for haive.core.models.llm.providers.ai21

"""AI21 Labs Provider Module.

This module implements the AI21 Labs language model provider for the Haive framework,
supporting Jurassic models known for their strong performance on various NLP tasks.

The provider handles API key management (the key is read from the AI21_API_KEY
environment variable), model configuration, and safe import of the optional
langchain-ai21 package dependency.

Examples:
    Basic usage:

    .. code-block:: python

        from haive.core.models.llm.providers.ai21 import AI21Provider

        provider = AI21Provider(
            model="j2-ultra",
            temperature=0.7,
            max_tokens=1000
        )
        llm = provider.instantiate()

    With custom parameters:

    .. code-block:: python

        provider = AI21Provider(
            model="j2-grande-instruct",
            temperature=0.1,
            top_p=0.9,
            frequency_penalty=0.2
        )
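
    Relying on the AI21_API_KEY environment variable for authentication
    (a minimal sketch; the key value below is a placeholder):

    .. code-block:: python

        import os

        os.environ["AI21_API_KEY"] = "your-api-key"  # placeholder key
        provider = AI21Provider(model="j2-mid")
        llm = provider.instantiate()

    Listing the models exposed by this provider:

    .. code-block:: python

        models = AI21Provider.get_models()
        # ["j2-ultra", "j2-mid", "j2-light", ...]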

.. autosummary::
   :toctree: generated/

   AI21Provider
"""

from typing import Any

from pydantic import Field

from haive.core.models.llm.provider_types import LLMProvider
from haive.core.models.llm.providers.base import BaseLLMProvider, ProviderImportError


class AI21Provider(BaseLLMProvider):
    """AI21 Labs language model provider configuration.

    This provider supports AI21's Jurassic family of models including J2-Ultra,
    J2-Mid, and instruction-tuned variants optimized for various tasks.

    Attributes:
        provider (LLMProvider): Always LLMProvider.AI21
        model (str): The AI21 model to use
        temperature (float): Sampling temperature (0.0-2.0)
        max_tokens (int): Maximum tokens in response
        top_p (float): Nucleus sampling parameter
        top_k_return (int): Number of top tokens to consider
        frequency_penalty (dict): Frequency penalty settings
        presence_penalty (dict): Presence penalty settings
        count_penalty (dict): Count penalty settings

    Examples:
        Ultra model for complex tasks:

        .. code-block:: python

            provider = AI21Provider(
                model="j2-ultra",
                temperature=0.7,
                max_tokens=2000
            )

        Instruct model with penalties:

        .. code-block:: python

            provider = AI21Provider(
                model="j2-grande-instruct",
                temperature=0.1,
                frequency_penalty={"scale": 0.2, "apply_to_whitespaces": False}
            )
    """

    provider: LLMProvider = Field(
        default=LLMProvider.AI21, description="Provider identifier"
    )

    # AI21 model parameters
    temperature: float | None = Field(
        default=None, ge=0, le=2, description="Sampling temperature"
    )
    max_tokens: int | None = Field(
        default=None, ge=1, description="Maximum tokens in response"
    )
    top_p: float | None = Field(
        default=None, ge=0, le=1, description="Nucleus sampling parameter"
    )
    top_k_return: int | None = Field(
        default=None,
        ge=1,
        description="Number of top tokens to return probabilities for",
    )
    frequency_penalty: dict[str, Any] | None = Field(
        default=None, description="Frequency penalty settings"
    )
    presence_penalty: dict[str, Any] | None = Field(
        default=None, description="Presence penalty settings"
    )
    count_penalty: dict[str, Any] | None = Field(
        default=None, description="Count penalty settings"
    )

    def _get_chat_class(self) -> type[Any]:
        """Get the AI21 chat class."""
        try:
            from langchain_ai21 import ChatAI21

            return ChatAI21
        except ImportError as e:
            raise ProviderImportError(
                provider=self.provider.value,
                package=self._get_import_package(),
                message="AI21 requires langchain-ai21. Install with: pip install langchain-ai21",
            ) from e

    def _get_default_model(self) -> str:
        """Get the default AI21 model."""
        return "j2-ultra"

    def _get_import_package(self) -> str:
        """Get the required package name."""
        return "langchain-ai21"

    def _get_initialization_params(self, **kwargs) -> dict[str, Any]:
        """Get AI21-specific initialization parameters."""
        params = {
            "model": self.model,
            **kwargs,
        }

        # Add model parameters if specified
        if self.temperature is not None:
            params["temperature"] = self.temperature
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        if self.top_p is not None:
            params["top_p"] = self.top_p
        if self.top_k_return is not None:
            params["top_k_return"] = self.top_k_return
        if self.frequency_penalty is not None:
            params["frequency_penalty"] = self.frequency_penalty
        if self.presence_penalty is not None:
            params["presence_penalty"] = self.presence_penalty
        if self.count_penalty is not None:
            params["count_penalty"] = self.count_penalty

        # Add API key
        api_key = self.get_api_key()
        if api_key:
            params["api_key"] = api_key

        # Add extra params
        params.update(self.extra_params or {})

        return params

    def _get_env_key_name(self) -> str:
        """Get the environment variable name for API key."""
        return "AI21_API_KEY"

    @classmethod
    def get_models(cls) -> list[str]:
        """Get available AI21 models."""
        return [
            "j2-ultra",
            "j2-mid",
            "j2-light",
            "j2-grande",
            "j2-grande-instruct",
            "j2-jumbo-instruct",
        ]