Source code for haive.core.engine.embedding.providers.AzureOpenAIEmbeddingConfig

"""Azure OpenAI embedding configuration."""

import os
from typing import Any

from pydantic import Field, field_validator

from haive.core.engine.embedding.base import BaseEmbeddingConfig
from haive.core.engine.embedding.types import EmbeddingType


@BaseEmbeddingConfig.register(EmbeddingType.AZURE_OPENAI)
class AzureOpenAIEmbeddingConfig(BaseEmbeddingConfig):
    """Configuration for Azure OpenAI embeddings.

    This configuration provides access to OpenAI embedding models deployed on
    Azure OpenAI Service. It supports both standard and data zone deployments.

    Examples:
        Basic usage:

        .. code-block:: python

            config = AzureOpenAIEmbeddingConfig(
                name="azure_embeddings",
                model="text-embedding-3-large",
                deployment_name="text-embedding-3-large",
                azure_endpoint="https://your-resource.openai.azure.com/",
                api_key="your-api-key"
            )

            embeddings = config.instantiate()

        Using environment variables::

            # Set AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, etc.
            config = AzureOpenAIEmbeddingConfig(
                name="azure_embeddings",
                model="text-embedding-3-large",
                deployment_name="text-embedding-3-large"
            )

        With custom API version::

            config = AzureOpenAIEmbeddingConfig(
                name="azure_embeddings",
                model="text-embedding-3-large",
                deployment_name="text-embedding-3-large",
                api_version="2024-02-15-preview"
            )

    Attributes:
        embedding_type: Always EmbeddingType.AZURE_OPENAI
        deployment_name: Azure deployment name for the model
        azure_endpoint: Azure OpenAI service endpoint URL
        api_version: Azure OpenAI API version
        max_retries: Maximum number of retries for API calls
        request_timeout: Timeout for API requests in seconds (optional)
        provider: Provider name used for API key resolution
        api_key: Azure OpenAI API key (not declared in this class; read via
            ``get_api_key()``, presumably provided by the base config)
        dimensions: Output dimensions (optional, model-dependent; not declared
            in this class, presumably provided by the base config)
    """

    embedding_type: EmbeddingType = Field(
        default=EmbeddingType.AZURE_OPENAI, description="The embedding provider type"
    )

    # Azure-specific required fields
    deployment_name: str = Field(
        ..., description="Azure deployment name for the embedding model"
    )

    # NOTE(review): pydantic does not run field validators on default values
    # unless ``validate_default`` is enabled, so a value taken from the
    # environment here is presumably only checked for emptiness later, in
    # ``validate_configuration()`` -- confirm against the base model config.
    azure_endpoint: str = Field(
        default_factory=lambda: os.getenv("AZURE_OPENAI_ENDPOINT", ""),
        description="Azure OpenAI service endpoint URL",
    )

    # Azure-specific optional fields
    api_version: str = Field(
        default_factory=lambda: os.getenv(
            "AZURE_OPENAI_API_VERSION", "2024-02-15-preview"
        ),
        description="Azure OpenAI API version",
    )

    max_retries: int = Field(
        default=3, description="Maximum number of retries for API calls"
    )

    request_timeout: float | None = Field(
        default=None, description="Timeout for API requests in seconds"
    )

    # SecureConfigMixin configuration
    provider: str = Field(
        default="azure", description="Provider name for API key resolution"
    )

    @field_validator("azure_endpoint")
    @classmethod
    def validate_azure_endpoint(cls, v: str) -> str:
        """Validate Azure OpenAI endpoint format.

        Args:
            v: Endpoint URL supplied for ``azure_endpoint``.

        Returns:
            The endpoint, with a trailing ``/`` appended when it was missing.

        Raises:
            ValueError: If the endpoint is empty or does not start with
                ``https://``.
        """
        if not v:
            raise ValueError(
                "Azure endpoint is required. Set AZURE_OPENAI_ENDPOINT environment variable "
                "or provide azure_endpoint parameter."
            )

        if not v.startswith("https://"):
            raise ValueError("Azure endpoint must start with 'https://'")

        # Normalise so the stored endpoint always ends with exactly the
        # separator the caller omitted.
        if not v.endswith("/"):
            v = v + "/"

        return v

    @field_validator("deployment_name")
    @classmethod
    def validate_deployment_name(cls, v: str) -> str:
        """Validate deployment name.

        Args:
            v: Deployment name supplied for ``deployment_name``.

        Returns:
            The deployment name with surrounding whitespace removed.

        Raises:
            ValueError: If the name is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("Deployment name is required and cannot be empty")
        return v.strip()

    @field_validator("api_version")
    @classmethod
    def validate_api_version(cls, v: str) -> str:
        """Validate API version format.

        Args:
            v: API version string supplied for ``api_version``.

        Returns:
            The unchanged version string.

        Raises:
            ValueError: If the version is empty or is not exactly
                ``YYYY-MM-DD`` or ``YYYY-MM-DD-preview``.
        """
        if not v:
            raise ValueError("API version is required")

        # Check format (YYYY-MM-DD or YYYY-MM-DD-preview)
        import re

        # fullmatch instead of match(...$): "$" also matches just before a
        # trailing newline, which would let "2024-02-15\n" slip through.
        if not re.fullmatch(r"\d{4}-\d{2}-\d{2}(-preview)?", v):
            raise ValueError(
                "API version must be in format YYYY-MM-DD or YYYY-MM-DD-preview"
            )

        return v

    def instantiate(self) -> Any:
        """Create an Azure OpenAI embeddings instance.

        Returns:
            AzureOpenAIEmbeddings instance configured with the provided parameters

        Raises:
            ImportError: If langchain-openai is not installed
            ValueError: If configuration is invalid or no API key is available
        """
        try:
            from langchain_openai import AzureOpenAIEmbeddings
        except ImportError as exc:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ImportError(
                "Azure OpenAI embeddings require the langchain-openai package. "
                "Install with: pip install langchain-openai"
            ) from exc

        # Validate configuration
        self.validate_configuration()

        # Get API key
        api_key = self.get_api_key()
        if not api_key:
            raise ValueError(
                "Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable "
                "or provide api_key parameter."
            )

        # Build kwargs
        kwargs: dict[str, Any] = {
            "model": self.deployment_name,  # Azure uses deployment name as model
            "azure_deployment": self.deployment_name,
            "azure_endpoint": self.azure_endpoint,
            "api_key": api_key,
            "api_version": self.api_version,
            "max_retries": self.max_retries,
        }

        # Add optional parameters only when set, so the library's own
        # defaults apply otherwise.
        if self.request_timeout is not None:
            kwargs["timeout"] = self.request_timeout

        if self.dimensions:
            kwargs["dimensions"] = self.dimensions

        return AzureOpenAIEmbeddings(**kwargs)

    def validate_configuration(self) -> None:
        """Validate the configuration before instantiation.

        Runs the base-class validation first, then checks that the
        Azure-specific fields are non-empty.

        Raises:
            ValueError: If ``deployment_name``, ``azure_endpoint`` or
                ``api_version`` is empty.
        """
        super().validate_configuration()

        if not self.deployment_name:
            raise ValueError("Deployment name is required")

        if not self.azure_endpoint:
            raise ValueError("Azure endpoint is required")

        if not self.api_version:
            raise ValueError("API version is required")

    def get_default_model(self) -> str:
        """Get the default model for Azure OpenAI embeddings."""
        return "text-embedding-3-large"

    def get_supported_models(self) -> list[str]:
        """Get list of supported Azure OpenAI embedding models."""
        return [
            "text-embedding-3-large",
            "text-embedding-3-small",
            "text-embedding-ada-002",
        ]

    def get_model_info(self) -> dict[str, Any]:
        """Get information about the configured model.

        Returns:
            A dict with ``dimensions``, ``max_dimensions``, ``context_length``
            and ``description`` for ``self.model``. For a model that is not in
            the built-in table the numeric entries are the string
            ``"unknown"``.
        """
        model_info: dict[str, dict[str, Any]] = {
            "text-embedding-3-large": {
                "dimensions": 3072,
                "max_dimensions": 3072,
                "context_length": 8192,
                "description": "Most capable Azure OpenAI embedding model",
            },
            "text-embedding-3-small": {
                "dimensions": 1536,
                "max_dimensions": 1536,
                "context_length": 8192,
                "description": "Smaller, faster Azure OpenAI embedding model",
            },
            "text-embedding-ada-002": {
                "dimensions": 1536,
                "max_dimensions": 1536,
                "context_length": 8192,
                "description": "Legacy Azure OpenAI embedding model",
            },
        }

        # The fallback carries the same keys as the known entries so callers
        # can index the result uniformly.
        return model_info.get(
            self.model,
            {
                "dimensions": "unknown",
                "max_dimensions": "unknown",
                "context_length": "unknown",
                "description": "Azure OpenAI embedding model",
            },
        )