# Source code for haive.core.engine.embedding.providers.GoogleVertexAIEmbeddingConfig

"""Google Vertex AI embedding configuration."""

import logging
from typing import Any

from pydantic import Field, field_validator

from haive.core.engine.embedding.base import BaseEmbeddingConfig
from haive.core.engine.embedding.types import EmbeddingType


logger = logging.getLogger(__name__)

# Single source of truth for the known model names and task types, shared by
# the validators and the ``get_*`` helpers so the two cannot drift apart.
_SUPPORTED_MODELS: tuple[str, ...] = (
    "text-embedding-004",
    "text-multilingual-embedding-002",
    "text-embedding-preview-0409",
    "text-multilingual-embedding-preview-0409",
    "textembedding-gecko@001",
    "textembedding-gecko@003",
    "textembedding-gecko-multilingual@001",
)

_TASK_TYPES: tuple[str, ...] = (
    "RETRIEVAL_QUERY",
    "RETRIEVAL_DOCUMENT",
    "SEMANTIC_SIMILARITY",
    "CLASSIFICATION",
    "CLUSTERING",
)


@BaseEmbeddingConfig.register(EmbeddingType.GOOGLE_VERTEX_AI)
class GoogleVertexAIEmbeddingConfig(BaseEmbeddingConfig):
    """Configuration for Google Vertex AI embeddings.

    This configuration provides access to Google's Vertex AI embedding models
    including text-embedding-004, text-multilingual-embedding-002, and others.

    Examples:
        Basic usage:

        .. code-block:: python

            config = GoogleVertexAIEmbeddingConfig(
                name="vertex_embeddings",
                model="text-embedding-004",
                project="your-project-id",
                location="us-central1"
            )

            embeddings = config.instantiate()

        With custom task type::

            config = GoogleVertexAIEmbeddingConfig(
                name="vertex_embeddings",
                model="text-embedding-004",
                project="your-project-id",
                location="us-central1",
                task_type="SEMANTIC_SIMILARITY"
            )

        Using service account::

            config = GoogleVertexAIEmbeddingConfig(
                name="vertex_embeddings",
                model="text-embedding-004",
                project="your-project-id",
                location="us-central1",
                credentials_path="/path/to/service-account.json"
            )

    Attributes:
        embedding_type: Always EmbeddingType.GOOGLE_VERTEX_AI
        model: Vertex AI model name (e.g., "text-embedding-004"); the field
            itself is not declared in this class (presumably inherited from
            the base config).
        project: Google Cloud project ID (required)
        location: Google Cloud location/region (defaults to "us-central1")
        task_type: Optional task type for embeddings
        title: Optional title for the embedding task
        credentials_path: Optional path to service account credentials
        max_retries: Maximum number of retries for API calls (defaults to 3)
        request_timeout: Optional timeout for API requests in seconds
        provider: Provider name used for API key resolution
    """

    embedding_type: EmbeddingType = Field(
        default=EmbeddingType.GOOGLE_VERTEX_AI,
        description="The embedding provider type",
    )

    # Google Vertex AI specific fields
    project: str = Field(..., description="Google Cloud project ID")
    location: str = Field(
        default="us-central1", description="Google Cloud location/region"
    )
    task_type: str | None = Field(
        default=None,
        description="Task type for embeddings (RETRIEVAL_QUERY, RETRIEVAL_DOCUMENT, etc.)",
    )
    title: str | None = Field(default=None, description="Title for the embedding task")
    credentials_path: str | None = Field(
        default=None, description="Path to service account credentials JSON file"
    )
    max_retries: int = Field(
        default=3, description="Maximum number of retries for API calls"
    )
    request_timeout: float | None = Field(
        default=None, description="Timeout for API requests in seconds"
    )

    # SecureConfigMixin configuration
    provider: str = Field(
        default="google", description="Provider name for API key resolution"
    )

    @field_validator("model")
    @classmethod
    def validate_model(cls, v: Any) -> Any:
        """Warn about, but accept, model names outside the known catalogue.

        Args:
            v: The model name being validated.

        Returns:
            The value unchanged.
        """
        if v not in _SUPPORTED_MODELS:
            # Log warning but don't fail - new models may be added
            logger.warning(
                "Unknown Vertex AI embedding model: %s. Valid models: %s",
                v,
                set(_SUPPORTED_MODELS),
            )
        return v

    @field_validator("task_type")
    @classmethod
    def validate_task_type(cls, v: str | None) -> str | None:
        """Reject task types that are not in the supported set.

        Args:
            v: The task type, or ``None`` when no task type was given.

        Returns:
            The value unchanged when it is ``None`` or a supported task type.

        Raises:
            ValueError: If ``v`` is not ``None`` and not a supported task type.
        """
        if v is None:
            return v
        valid_types = set(_TASK_TYPES)
        if v not in valid_types:
            # ValueError (not TypeError) so pydantic reports this as a normal
            # validation failure, matching the other validators in this class.
            raise ValueError(f"Invalid task_type: {v}. Valid types: {valid_types}")
        return v

    @field_validator("location")
    @classmethod
    def validate_location(cls, v: str) -> str:
        """Require a non-blank location and strip surrounding whitespace.

        Args:
            v: The location string.

        Returns:
            The location with leading/trailing whitespace removed.

        Raises:
            ValueError: If the location is empty or only whitespace.
        """
        if not v or not v.strip():
            raise ValueError("Location is required")
        return v.strip()

    @field_validator("project")
    @classmethod
    def validate_project(cls, v: str) -> str:
        """Require a non-blank project ID and strip surrounding whitespace.

        Args:
            v: The project ID string.

        Returns:
            The project ID with leading/trailing whitespace removed.

        Raises:
            ValueError: If the project ID is empty or only whitespace.
        """
        if not v or not v.strip():
            raise ValueError("Project ID is required")
        return v.strip()

    def instantiate(self) -> Any:
        """Create a Google Vertex AI embeddings instance.

        Returns:
            VertexAIEmbeddings instance configured with the provided parameters

        Raises:
            ImportError: If langchain-google-vertexai is not installed
            ValueError: If configuration is invalid
        """
        # Imported lazily so the package is only required when this provider
        # is actually used.
        try:
            from langchain_google_vertexai import VertexAIEmbeddings
        except ImportError as err:
            raise ImportError(
                "Google Vertex AI embeddings require the langchain-google-vertexai package. "
                "Install with: pip install langchain-google-vertexai"
            ) from err

        # Validate configuration
        self.validate_configuration()

        # Build kwargs
        kwargs: dict[str, Any] = {
            "model_name": self.model,
            "project": self.project,
            "location": self.location,
            "max_retries": self.max_retries,
        }

        # Add optional parameters only when set (falsy values are omitted)
        if self.task_type:
            kwargs["task_type"] = self.task_type

        if self.title:
            kwargs["title"] = self.title

        if self.request_timeout:
            kwargs["request_timeout"] = self.request_timeout

        if self.credentials_path:
            # NOTE(review): the path string is forwarded as-is; the
            # "credentials" parameter may expect a loaded google-auth
            # credentials object rather than a file path - confirm against the
            # VertexAIEmbeddings API before relying on this.
            kwargs["credentials"] = self.credentials_path

        return VertexAIEmbeddings(**kwargs)

    def validate_configuration(self) -> None:
        """Validate the configuration before instantiation.

        Runs the parent class's checks first, then requires ``project`` and
        ``location`` to be set.

        Raises:
            ValueError: If ``project`` or ``location`` is empty.
        """
        super().validate_configuration()

        if not self.project:
            raise ValueError("Project ID is required")

        if not self.location:
            raise ValueError("Location is required")

    def get_default_model(self) -> str:
        """Get the default model for Vertex AI embeddings."""
        return "text-embedding-004"

    def get_supported_models(self) -> list[str]:
        """Get list of supported Vertex AI embedding models.

        Returns:
            A new list on every call, so callers may mutate it freely.
        """
        return list(_SUPPORTED_MODELS)

    def get_model_info(self) -> dict:
        """Get information about the configured model.

        Returns:
            A dict with ``dimensions``, ``max_input_tokens``, ``languages`` and
            ``description`` for ``self.model``. Models without an entry in the
            table below get a placeholder dict whose first three values are
            the string ``"unknown"``.
        """
        model_info = {
            "text-embedding-004": {
                "dimensions": 768,
                "max_input_tokens": 3072,
                "languages": ["English", "100+ languages"],
                "description": "Latest text embedding model with high performance",
            },
            "text-multilingual-embedding-002": {
                "dimensions": 768,
                "max_input_tokens": 2048,
                "languages": ["100+ languages"],
                "description": "Multilingual text embedding model",
            },
            "textembedding-gecko@001": {
                "dimensions": 768,
                "max_input_tokens": 2048,
                "languages": ["English", "100+ languages"],
                "description": "Gecko text embedding model",
            },
            "textembedding-gecko@003": {
                "dimensions": 768,
                "max_input_tokens": 2048,
                "languages": ["English", "100+ languages"],
                "description": "Gecko text embedding model v3",
            },
            "textembedding-gecko-multilingual@001": {
                "dimensions": 768,
                "max_input_tokens": 2048,
                "languages": ["100+ languages"],
                "description": "Multilingual Gecko text embedding model",
            },
        }

        return model_info.get(
            self.model,
            {
                "dimensions": "unknown",
                "max_input_tokens": "unknown",
                "languages": "unknown",
                "description": "Vertex AI embedding model",
            },
        )

    def get_task_types(self) -> list[str]:
        """Get list of supported task types.

        Returns:
            A new list on every call, so callers may mutate it freely.
        """
        return list(_TASK_TYPES)

    def get_supported_locations(self) -> list[str]:
        """Get list of supported Google Cloud locations."""
        return [
            "us-central1",
            "us-east1",
            "us-west1",
            "us-west4",
            "europe-west1",
            "europe-west4",
            "asia-east1",
            "asia-northeast1",
            "asia-southeast1",
        ]