Source code for haive.core.engine.embedding.providers.OpenAIEmbeddingConfig

"""OpenAI embedding configuration."""

from typing import Any

from pydantic import Field, ValidationInfo, field_validator

from haive.core.engine.embedding.base import BaseEmbeddingConfig
from haive.core.engine.embedding.types import EmbeddingType


@BaseEmbeddingConfig.register(EmbeddingType.OPENAI)
class OpenAIEmbeddingConfig(BaseEmbeddingConfig):
    """Configuration for OpenAI embeddings.

    This configuration provides access to OpenAI's embedding models, including
    the latest text-embedding-3-large and text-embedding-3-small models.

    Examples:
        Basic usage:

        .. code-block:: python

            config = OpenAIEmbeddingConfig(
                name="openai_embeddings",
                model="text-embedding-3-large",
                api_key="sk-..."
            )
            embeddings = config.instantiate()

        With custom dimensions::

            config = OpenAIEmbeddingConfig(
                name="openai_embeddings",
                model="text-embedding-3-large",
                dimensions=1536,
                api_key="sk-..."
            )

        Using environment variables::

            # Set OPENAI_API_KEY environment variable
            config = OpenAIEmbeddingConfig(
                name="openai_embeddings",
                model="text-embedding-3-large"
            )

    Attributes:
        embedding_type: Always EmbeddingType.OPENAI.
        model: OpenAI model name (e.g., "text-embedding-3-large").
        api_key: OpenAI API key (auto-resolved from the OPENAI_API_KEY env var).
        dimensions: Output dimensions (optional, model-dependent).
        max_retries: Maximum number of retries for API calls.
        request_timeout: Timeout for API requests in seconds.
    """

    embedding_type: EmbeddingType = Field(
        default=EmbeddingType.OPENAI, description="The embedding provider type"
    )

    # OpenAI-specific fields
    max_retries: int = Field(
        default=3, description="Maximum number of retries for API calls"
    )
    request_timeout: float | None = Field(
        default=None, description="Timeout for API requests in seconds"
    )
    base_url: str | None = Field(
        default=None, description="Base URL for the OpenAI API (for custom endpoints)"
    )
    organization: str | None = Field(default=None, description="OpenAI organization ID")

    # SecureConfigMixin configuration
    provider: str = Field(
        default="openai", description="Provider name for API key resolution"
    )
[docs] @field_validator("model") @classmethod def validate_model(cls, v) -> Any: """Validate the OpenAI model name.""" valid_models = { "text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002", "text-embedding-ada-001", # legacy } if v not in valid_models: # Log warning but don't fail - new models may be added import logging logger = logging.getLogger(__name__) logger.warning( f"Unknown OpenAI embedding model: {v}. Valid models: {valid_models}" ) return v
[docs] @field_validator("dimensions") @classmethod def validate_dimensions(cls, v, values) -> Any: """Validate dimensions based on model.""" if v is None: return v model = values.get("model") if model == "text-embedding-3-large" and v > 3072: raise ValueError("text-embedding-3-large maximum dimensions: 3072") if model == "text-embedding-3-small" and v > 1536: raise ValueError("text-embedding-3-small maximum dimensions: 1536") if model == "text-embedding-ada-002" and v != 1536: raise ValueError("text-embedding-ada-002 dimensions must be 1536") return v
    def instantiate(self) -> Any:
        """Create an OpenAI embeddings instance.

        Returns:
            OpenAIEmbeddings instance configured with the provided parameters.

        Raises:
            ImportError: If langchain-openai is not installed.
            ValueError: If the configuration is invalid.
        """
        try:
            from langchain_openai import OpenAIEmbeddings
        except ImportError as e:
            raise ImportError(
                "OpenAI embeddings require the langchain-openai package. "
                "Install with: pip install langchain-openai"
            ) from e

        # Validate configuration
        self.validate_configuration()

        # Get the API key
        api_key = self.get_api_key()
        if not api_key:
            raise ValueError(
                "OpenAI API key is required. Set the OPENAI_API_KEY environment "
                "variable or provide the api_key parameter."
            )

        # Build kwargs
        kwargs = {
            "model": self.model,
            "api_key": api_key,
            "max_retries": self.max_retries,
        }

        # Add optional parameters
        if self.request_timeout is not None:
            kwargs["timeout"] = self.request_timeout
        if self.base_url:
            kwargs["base_url"] = self.base_url
        if self.organization:
            kwargs["organization"] = self.organization
        if self.dimensions:
            kwargs["dimensions"] = self.dimensions

        return OpenAIEmbeddings(**kwargs)
    def get_default_model(self) -> str:
        """Get the default model for OpenAI embeddings."""
        return "text-embedding-3-large"
    def get_supported_models(self) -> list[str]:
        """Get the list of supported OpenAI embedding models."""
        return [
            "text-embedding-3-large",
            "text-embedding-3-small",
            "text-embedding-ada-002",
            "text-embedding-ada-001",
        ]
    def get_model_info(self) -> dict:
        """Get information about the configured model."""
        model_info = {
            "text-embedding-3-large": {
                "dimensions": 3072,
                "max_dimensions": 3072,
                "context_length": 8192,
                "description": "Most capable OpenAI embedding model",
            },
            "text-embedding-3-small": {
                "dimensions": 1536,
                "max_dimensions": 1536,
                "context_length": 8192,
                "description": "Smaller, faster OpenAI embedding model",
            },
            "text-embedding-ada-002": {
                "dimensions": 1536,
                "max_dimensions": 1536,
                "context_length": 8192,
                "description": "Legacy OpenAI embedding model",
            },
            "text-embedding-ada-001": {
                "dimensions": 1024,
                "max_dimensions": 1024,
                "context_length": 2048,
                "description": "Legacy OpenAI embedding model",
            },
        }
        return model_info.get(
            self.model,
            {
                "dimensions": "unknown",
                "context_length": "unknown",
                "description": "OpenAI embedding model",
            },
        )
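
# Usage sketch (not part of the original module): a minimal, hedged example of how
# this configuration might be exercised end to end. It assumes the OPENAI_API_KEY
# environment variable is set, that langchain-openai is installed, and that `name`
# is a field inherited from BaseEmbeddingConfig as shown in the class docstring;
# embed_query() is the standard LangChain embeddings call.
if __name__ == "__main__":
    config = OpenAIEmbeddingConfig(
        name="openai_embeddings",
        model="text-embedding-3-small",
        dimensions=512,  # accepted: within the 1536-dimension limit enforced above
    )
    print(config.get_model_info()["description"])  # "Smaller, faster OpenAI embedding model"

    embeddings = config.instantiate()
    vector = embeddings.embed_query("hello world")
    print(len(vector))  # 512, because dimensions was set explicitly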