Source code for haive.core.models.llm.providers.google

"""Google AI Providers Module.

This module implements both Google Generative AI (Gemini) and Vertex AI providers
for the Haive framework: Gemini models are served through the public Generative AI
API, while Vertex AI covers enterprise, project-scoped deployments.

The providers handle API key management, model configuration, and safe, lazy
imports of the langchain-google packages, so the module imports cleanly even when
those optional dependencies are missing (a short sketch after the imports below
shows how a missing package surfaces as ProviderImportError).

Examples:
    Using Gemini:

    .. code-block:: python

        from haive.core.models.llm.providers.google import GeminiProvider

        provider = GeminiProvider(
            model="gemini-1.5-pro",
            temperature=0.7
        )
        llm = provider.instantiate()

    Using Vertex AI::

        from haive.core.models.llm.providers.google import VertexAIProvider

        provider = VertexAIProvider(
            model="gemini-1.5-pro",
            project="my-project",
            location="us-central1"
        )
        llm = provider.instantiate()

.. autosummary::
   :toctree: generated/

   GeminiProvider
   VertexAIProvider
"""

import os
from typing import Any

from pydantic import Field

from haive.core.models.llm.provider_types import LLMProvider
from haive.core.models.llm.providers.base import BaseLLMProvider, ProviderImportError
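
# Illustrative sketch (hypothetical helper, kept for documentation only): the
# provider chat classes are imported lazily inside _get_chat_class(), so this
# module imports cleanly without the optional langchain-google-* packages; a
# missing dependency only surfaces as ProviderImportError at instantiation time.
def _example_safe_instantiate(provider: BaseLLMProvider) -> Any | None:
    """Return an LLM instance, or None if the provider package is missing."""
    try:
        return provider.instantiate()
    except ProviderImportError:
        return None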


class GeminiProvider(BaseLLMProvider):
    """Google Gemini language model provider configuration.

    This provider supports Google's Gemini models through the Generative AI API.
    It's suitable for general use with API key authentication.

    Attributes:
        provider: Always LLMProvider.GEMINI
        model: Model name (default: "gemini-1.5-pro")
        temperature: Sampling temperature (0-1)
        max_output_tokens: Maximum tokens to generate
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        n: Number of responses to generate

    Environment Variables:
        GOOGLE_API_KEY: API key for authentication
        GEMINI_API_KEY: Alternative API key environment variable

    Model Variants:
        - gemini-1.5-pro: Most capable, 1M token context
        - gemini-1.5-flash: Faster, more efficient
        - gemini-pro: Previous generation
        - gemini-pro-vision: Multimodal support

    Examples:
        Basic usage:

        .. code-block:: python

            provider = GeminiProvider(
                model="gemini-1.5-pro",
                temperature=0.7,
                max_output_tokens=2048
            )
            llm = provider.instantiate()

        With advanced sampling::

            provider = GeminiProvider(
                model="gemini-1.5-flash",
                temperature=0.9,
                top_p=0.95,
                top_k=40
            )
    """

    provider: LLMProvider = Field(
        default=LLMProvider.GEMINI, description="Provider identifier"
    )

    # Gemini specific parameters
    temperature: float | None = Field(
        default=None, ge=0, le=1, description="Sampling temperature"
    )
    max_output_tokens: int | None = Field(
        default=None, ge=1, description="Maximum tokens to generate"
    )
    top_p: float | None = Field(
        default=None, ge=0, le=1, description="Nucleus sampling parameter"
    )
    top_k: int | None = Field(
        default=None, ge=1, description="Top-k sampling parameter"
    )
    n: int | None = Field(default=None, ge=1, description="Number of responses")

    def _get_chat_class(self) -> type[Any]:
        """Get the Gemini chat class.

        Returns:
            ChatGoogleGenerativeAI class

        Raises:
            ProviderImportError: If langchain-google-genai is not installed
        """
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI

            return ChatGoogleGenerativeAI
        except ImportError as e:
            raise ProviderImportError(
                provider="Google Gemini", package="langchain-google-genai"
            ) from e

    def _get_default_model(self) -> str:
        """Get the default model for Gemini.

        Returns:
            Default model name
        """
        return "gemini-1.5-pro"

    def _get_import_package(self) -> str:
        """Get the pip package name.

        Returns:
            Package name for installation
        """
        return "langchain-google-genai"

    def _get_env_key_name(self) -> str:
        """Get environment variable name for API key.

        Returns:
            Environment variable name
        """
        # Check both GOOGLE_API_KEY and GEMINI_API_KEY
        if os.getenv("GEMINI_API_KEY"):
            return "GEMINI_API_KEY"
        return "GOOGLE_API_KEY"

    def _get_initialization_params(self, **kwargs) -> dict:
        """Get initialization parameters for the LLM.

        Returns:
            Dictionary of initialization parameters
        """
        params = super()._get_initialization_params(**kwargs)

        # Add Gemini-specific parameters if set
        optional_params = ["temperature", "max_output_tokens", "top_p", "top_k", "n"]
        for param in optional_params:
            value = getattr(self, param)
            if value is not None:
                params[param] = value

        return params

    def _get_api_key_param_name(self) -> str | None:
        """Get the parameter name for API key.

        Returns:
            The parameter name for Google API key
        """
        return "google_api_key"

    @classmethod
    def get_models(cls) -> list[str]:
        """Get available Gemini models.

        Returns:
            List of available model names
        """
        return [
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-pro",
            "gemini-pro-vision",
        ]
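
# Illustrative sketch (hypothetical helper, documentation only): shows the API
# key resolution order implemented by GeminiProvider._get_env_key_name().
# GEMINI_API_KEY takes precedence when it is set; GOOGLE_API_KEY is the
# fallback. The environment is inspected at call time, not at construction.
def _example_gemini_env_key() -> str:
    """Return the env var a GeminiProvider would read for its API key."""
    provider = GeminiProvider(model="gemini-1.5-flash", temperature=0.9)
    return provider._get_env_key_name()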
class VertexAIProvider(BaseLLMProvider):
    """Google Vertex AI language model provider configuration.

    This provider supports Google's models through Vertex AI, suitable for
    enterprise deployments with project-based authentication and regional
    control.

    Attributes:
        provider: Always LLMProvider.VERTEX_AI
        model: Model name (default: "gemini-1.5-pro")
        project: Google Cloud project ID
        location: Google Cloud region (default: "us-central1")
        temperature: Sampling temperature (0-1)
        max_output_tokens: Maximum tokens to generate
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter

    Environment Variables:
        GOOGLE_CLOUD_PROJECT: Default project ID
        GOOGLE_APPLICATION_CREDENTIALS: Path to service account JSON

    Authentication:
        Vertex AI uses Google Cloud authentication. You can authenticate by:

        1. Setting GOOGLE_APPLICATION_CREDENTIALS to a service account key path
        2. Using ``gcloud auth application-default login``
        3. Running on Google Cloud with appropriate IAM roles

    Examples:
        Basic usage:

        .. code-block:: python

            provider = VertexAIProvider(
                model="gemini-1.5-pro",
                project="my-project",
                location="us-central1"
            )
            llm = provider.instantiate()

        With custom parameters::

            provider = VertexAIProvider(
                model="gemini-1.5-flash",
                project="my-project",
                location="europe-west1",
                temperature=0.5,
                max_output_tokens=1024
            )
    """

    provider: LLMProvider = Field(
        default=LLMProvider.VERTEX_AI, description="Provider identifier"
    )

    # Vertex AI specific parameters
    project: str | None = Field(default=None, description="Google Cloud project ID")
    location: str = Field(default="us-central1", description="Google Cloud region")
    temperature: float | None = Field(
        default=None, ge=0, le=1, description="Sampling temperature"
    )
    max_output_tokens: int | None = Field(
        default=None, ge=1, description="Maximum tokens to generate"
    )
    top_p: float | None = Field(
        default=None, ge=0, le=1, description="Nucleus sampling parameter"
    )
    top_k: int | None = Field(
        default=None, ge=1, description="Top-k sampling parameter"
    )

    def _get_chat_class(self) -> type[Any]:
        """Get the Vertex AI chat class.

        Returns:
            ChatVertexAI class

        Raises:
            ProviderImportError: If langchain-google-vertexai is not installed
        """
        try:
            from langchain_google_vertexai import ChatVertexAI

            return ChatVertexAI
        except ImportError as e:
            raise ProviderImportError(
                provider="Google Vertex AI", package="langchain-google-vertexai"
            ) from e

    def _get_default_model(self) -> str:
        """Get the default model for Vertex AI.

        Returns:
            Default model name
        """
        return "gemini-1.5-pro"

    def _get_import_package(self) -> str:
        """Get the pip package name.

        Returns:
            Package name for installation
        """
        return "langchain-google-vertexai"

    def _requires_api_key(self) -> bool:
        """Check if this provider requires an API key.

        Returns:
            False - Vertex AI uses Google Cloud auth, not API keys
        """
        return False

    def _validate_config(self) -> None:
        """Validate Vertex AI configuration.

        Raises:
            ValueError: If project ID is not set
        """
        if not self.project:
            # Try to get from environment
            self.project = os.getenv("GOOGLE_CLOUD_PROJECT")

        if not self.project:
            raise ValueError(
                "Google Cloud Project ID is required. "
                "Set the 'project' parameter or GOOGLE_CLOUD_PROJECT "
                "environment variable."
            )

    def _get_initialization_params(self, **kwargs) -> dict:
        """Get initialization parameters for the LLM.

        Returns:
            Dictionary of initialization parameters
        """
        params = super()._get_initialization_params(**kwargs)

        # Remove API key related params
        params.pop("api_key", None)
        params.pop("cache", None)  # Vertex AI doesn't use cache param

        # Add required Vertex AI parameters
        params["project"] = self.project
        params["location"] = self.location

        # Use model_name instead of model
        params["model_name"] = params.pop("model")

        # Add optional parameters if set
        optional_params = ["temperature", "max_output_tokens", "top_p", "top_k"]
        for param in optional_params:
            value = getattr(self, param)
            if value is not None:
                params[param] = value

        return params

    @classmethod
    def get_models(cls) -> list[str]:
        """Get available Vertex AI models.

        Returns:
            List of available model names
        """
        return [
            "gemini-1.5-pro",
            "gemini-1.5-flash",
            "gemini-pro",
            "gemini-pro-vision",
            "text-bison",
            "text-bison-32k",
            "code-bison",
            "code-bison-32k",
        ]
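
if __name__ == "__main__":
    # Minimal smoke-test sketch, not a definitive usage pattern: building the
    # provider configs is plain Pydantic validation and needs neither Google
    # credentials nor the optional langchain-google-* packages; only
    # instantiate() requires those. The project ID below is a placeholder.
    gemini = GeminiProvider(model="gemini-1.5-pro", temperature=0.7)
    print(gemini.provider, GeminiProvider.get_models())

    vertex = VertexAIProvider(
        model="gemini-1.5-pro", project="my-project", location="us-central1"
    )
    print(vertex.provider, vertex.project, vertex.location)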