# Source code for haive.core.models.llm.providers.xai
"""xAI Provider Module.

This module implements the xAI language model provider for the Haive framework,
supporting Grok models developed by Elon Musk's xAI company.

The provider handles API key management, model configuration, and safe imports of
the langchain-xai package dependencies. The ``langchain_xai`` package is only
imported when the chat class is requested, so importing this module does not
require it to be installed. The provider reports ``XAI_API_KEY`` as the name of
its API-key environment variable.

Examples:
    Basic usage:

    .. code-block:: python

        from haive.core.models.llm.providers.xai import XAIProvider

        provider = XAIProvider(
            model="grok-beta",
            temperature=0.7,
            max_tokens=1000
        )
        llm = provider.instantiate()

    With custom parameters::

        provider = XAIProvider(
            model="grok-1",
            temperature=0.1,
            top_p=0.9,
            stream=True
        )

.. autosummary::
   :toctree: generated/

   XAIProvider
"""

from typing import Any

from pydantic import Field

from haive.core.models.llm.provider_types import LLMProvider
from haive.core.models.llm.providers.base import BaseLLMProvider, ProviderImportError
class XAIProvider(BaseLLMProvider):
    """xAI language model provider configuration.

    This provider supports xAI's Grok family of models known for their
    real-time information access and conversational capabilities.

    Attributes:
        provider (LLMProvider): Always LLMProvider.XAI
        model (str): The xAI model to use
        temperature (float): Sampling temperature (0.0-2.0)
        max_tokens (int): Maximum tokens in response
        top_p (float): Nucleus sampling parameter
        stream (bool): Enable streaming responses
        stop (list): Stop sequences for generation

    Examples:
        Grok Beta for general conversation:

        .. code-block:: python

            provider = XAIProvider(
                model="grok-beta",
                temperature=0.7,
                max_tokens=2000
            )

        Grok with streaming::

            provider = XAIProvider(
                model="grok-1",
                temperature=0.1,
                stream=True,
                top_p=0.9
            )
    """

    provider: LLMProvider = Field(
        default=LLMProvider.XAI, description="Provider identifier"
    )

    # xAI model parameters
    temperature: float | None = Field(
        default=None, ge=0, le=2, description="Sampling temperature"
    )
    max_tokens: int | None = Field(
        default=None, ge=1, description="Maximum tokens in response"
    )
    top_p: float | None = Field(
        default=None, ge=0, le=1, description="Nucleus sampling parameter"
    )
    stream: bool = Field(default=False, description="Enable streaming responses")
    stop: list[str] | None = Field(
        default=None, description="Stop sequences for generation"
    )

    def _get_chat_class(self) -> type[Any]:
        """Get the xAI chat class.

        The import happens here rather than at module level so that this
        module can be imported without langchain-xai installed.

        Returns:
            The ``langchain_xai.ChatXAI`` class.

        Raises:
            ProviderImportError: If ``langchain_xai`` cannot be imported.
        """
        try:
            from langchain_xai import ChatXAI
        except ImportError as e:
            raise ProviderImportError(
                provider=self.provider.value,
                package=self._get_import_package(),
                message=(
                    "xAI requires langchain-xai. "
                    "Install with: pip install langchain-xai"
                ),
            ) from e
        return ChatXAI

    def _get_default_model(self) -> str:
        """Get the default xAI model."""
        return "grok-beta"

    def _get_import_package(self) -> str:
        """Get the required package name."""
        return "langchain-xai"

    def _get_initialization_params(self, **kwargs) -> dict[str, Any]:
        """Get xAI-specific initialization parameters.

        Precedence, lowest to highest: ``model`` from this config, then
        ``kwargs``, then the configured sampling fields (only those that are
        set), then the API key (when available), then ``extra_params``.

        Args:
            **kwargs: Additional keyword arguments for the chat class.

        Returns:
            Keyword arguments to pass to the chat class constructor.
        """
        params: dict[str, Any] = {"model": self.model, **kwargs}

        # Forward optional sampling fields only when set; omitted keys fall
        # back to the chat class's own defaults.
        optional_fields = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
        }
        params.update(
            {name: value for name, value in optional_fields.items() if value is not None}
        )

        # ``stream`` is a non-optional bool (default False), so it is always
        # forwarded; the chat class names this option ``streaming``.
        params["streaming"] = self.stream

        if self.stop is not None:
            params["stop"] = self.stop

        # Only pass an explicit key when one is available.
        api_key = self.get_api_key()
        if api_key:
            params["xai_api_key"] = api_key

        # User-supplied extra params are applied last and override everything above.
        params.update(self.extra_params or {})

        return params

    def _get_env_key_name(self) -> str:
        """Get the environment variable name for API key."""
        return "XAI_API_KEY"

    @classmethod
    def get_models(cls) -> list[str]:
        """Get available xAI models.

        Returns:
            A hard-coded list of xAI model identifiers. It is not fetched from
            the xAI API, so it may not match the models xAI currently serves.
        """
        return ["grok-beta", "grok-1", "grok-vision-beta"]