Source code for haive.core.engine.embedding.providers.OllamaEmbeddingConfig

"""Ollama embedding configuration."""

from typing import Any

from pydantic import Field, field_validator

from haive.core.engine.embedding.base import BaseEmbeddingConfig
from haive.core.engine.embedding.types import EmbeddingType


@BaseEmbeddingConfig.register(EmbeddingType.OLLAMA)
class OllamaEmbeddingConfig(BaseEmbeddingConfig):
    """Configuration for Ollama embeddings.

    This configuration provides access to locally hosted Ollama embedding
    models, including nomic-embed-text, mxbai-embed-large, and other
    supported models.

    Examples:
        Basic usage:

        .. code-block:: python

            config = OllamaEmbeddingConfig(
                name="ollama_embeddings",
                model="nomic-embed-text",
                base_url="http://localhost:11434",
            )
            embeddings = config.instantiate()

        With custom headers::

            config = OllamaEmbeddingConfig(
                name="ollama_embeddings",
                model="mxbai-embed-large",
                base_url="http://localhost:11434",
                headers={"Authorization": "Bearer token"},
            )

        With custom options::

            config = OllamaEmbeddingConfig(
                name="ollama_embeddings",
                model="nomic-embed-text",
                base_url="http://localhost:11434",
                model_options={"temperature": 0.1},
            )

    Attributes:
        embedding_type: Always EmbeddingType.OLLAMA.
        model: Ollama model name (e.g., "nomic-embed-text").
        base_url: Ollama server URL.
        headers: Optional HTTP headers for requests.
        model_options: Optional model-specific options.
    """

    embedding_type: EmbeddingType = Field(
        default=EmbeddingType.OLLAMA, description="The embedding provider type"
    )

    # Ollama-specific fields
    base_url: str = Field(
        default="http://localhost:11434", description="Ollama server URL"
    )
    headers: dict[str, str] | None = Field(
        default=None, description="Optional HTTP headers for requests"
    )
    model_options: dict[str, Any] | None = Field(
        default=None, description="Optional model-specific options"
    )
    request_timeout: float | None = Field(
        default=None, description="Timeout for API requests in seconds"
    )

    # SecureConfigMixin configuration (not typically needed for local Ollama)
    provider: str = Field(
        default="ollama", description="Provider name for API key resolution"
    )

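    # Usage sketch (illustrative, not part of the module): constructing the
    # config and inspecting resolved fields with pydantic v2's model_dump().
    # The "name" and "model" fields are assumed to be inherited from
    # BaseEmbeddingConfig.
    #
    #     config = OllamaEmbeddingConfig(name="local_emb", model="nomic-embed-text")
    #     assert config.base_url == "http://localhost:11434"
    #     print(config.model_dump(exclude_none=True))
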
[docs] @field_validator("model") @classmethod def validate_model(cls, v) -> Any: """Validate the Ollama model name.""" popular_models = { "nomic-embed-text", "mxbai-embed-large", "snowflake-arctic-embed", "all-minilm", "llama2:7b", "mistral:7b", "codellama:7b", } if v not in popular_models: # Log info but don't fail - Ollama supports many models import logging logger = logging.getLogger(__name__) logger.info(f"Using Ollama model: {v}. Popular models: {popular_models}") return v
[docs] @field_validator("base_url") @classmethod def validate_base_url(cls, v) -> Any: """Validate Ollama server URL.""" if not v or not v.strip(): raise ValueError("Base URL is required") v = v.strip() if not v.startswith(("http://", "https://")): raise ValueError("Base URL must start with 'http://' or 'https://'") # Remove trailing slash if v.endswith("/"): v = v[:-1] return v
    def instantiate(self) -> Any:
        """Create an Ollama embeddings instance.

        Returns:
            OllamaEmbeddings instance configured with the provided parameters.

        Raises:
            ImportError: If langchain-ollama is not installed.
            ValueError: If configuration is invalid.
        """
        try:
            from langchain_ollama import OllamaEmbeddings
        except ImportError:
            raise ImportError(
                "Ollama embeddings require the langchain-ollama package. "
                "Install with: pip install langchain-ollama"
            )

        # Validate configuration
        self.validate_configuration()

        # Build kwargs
        kwargs = {
            "model": self.model,
            "base_url": self.base_url,
        }

        # Add optional parameters
        if self.headers:
            kwargs["headers"] = self.headers
        if self.model_options:
            kwargs["model_kwargs"] = self.model_options
        if self.request_timeout:
            kwargs["timeout"] = self.request_timeout

        return OllamaEmbeddings(**kwargs)

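    # Usage sketch (assumes a running Ollama server with the model already
    # pulled): the returned OllamaEmbeddings implements LangChain's standard
    # Embeddings interface, so embed_query and embed_documents are available.
    #
    #     embeddings = config.instantiate()
    #     vector = embeddings.embed_query("hello world")
    #     vectors = embeddings.embed_documents(["doc one", "doc two"])
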
    def validate_configuration(self) -> None:
        """Validate the configuration before instantiation."""
        super().validate_configuration()

        if not self.base_url:
            raise ValueError("Base URL is required")

    def get_default_model(self) -> str:
        """Get the default model for Ollama embeddings."""
        return "nomic-embed-text"

    def get_supported_models(self) -> list[str]:
        """Get list of popular Ollama embedding models."""
        return [
            "nomic-embed-text",
            "mxbai-embed-large",
            "snowflake-arctic-embed",
            "all-minilm",
            "bge-large",
            "bge-base",
            "e5-large",
            "e5-base",
            "gte-large",
            "gte-base",
        ]

    def get_model_info(self) -> dict:
        """Get information about the configured model."""
        model_info = {
            "nomic-embed-text": {
                "dimensions": 768,
                "description": "Nomic's text embedding model, good for general use",
            },
            "mxbai-embed-large": {
                "dimensions": 1024,
                "description": "Mixedbread AI's large embedding model",
            },
            "snowflake-arctic-embed": {
                "dimensions": 1024,
                "description": "Snowflake's Arctic embedding model",
            },
            "all-minilm": {
                "dimensions": 384,
                "description": "Lightweight sentence transformer model",
            },
            "bge-large": {
                "dimensions": 1024,
                "description": "BAAI's large embedding model",
            },
            "bge-base": {
                "dimensions": 768,
                "description": "BAAI's base embedding model",
            },
            "e5-large": {
                "dimensions": 1024,
                "description": "Microsoft's E5 large embedding model",
            },
            "e5-base": {
                "dimensions": 768,
                "description": "Microsoft's E5 base embedding model",
            },
        }

        return model_info.get(
            self.model,
            {"dimensions": "unknown", "description": "Ollama embedding model"},
        )

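    # Usage sketch: the reported dimensions can size downstream storage. Note
    # that the fallback entry uses the string "unknown", so check the type
    # before treating it as an int (create_index below is hypothetical).
    #
    #     dims = config.get_model_info()["dimensions"]
    #     if isinstance(dims, int):
    #         index = create_index(dim=dims)
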
    def test_connection(self) -> bool:
        """Test connection to Ollama server.

        Returns:
            True if connection is successful, False otherwise.
        """
        try:
            import requests

            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except Exception:
            return False

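    # Usage sketch: gate instantiation on the health check above to fail fast
    # when the server is unreachable.
    #
    #     if not config.test_connection():
    #         raise RuntimeError(f"Ollama not reachable at {config.base_url}")
    #     embeddings = config.instantiate()
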
    def list_available_models(self) -> list[str]:
        """List models available on the Ollama server.

        Returns:
            List of model names available on the server.
        """
        try:
            import requests

            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                data = response.json()
                return [model["name"] for model in data.get("models", [])]
            return []
        except Exception:
            return []

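    # Usage sketch: verify the configured model has been pulled before
    # instantiating. The server reports tagged names (e.g.
    # "nomic-embed-text:latest"), so matching on the name prefix is assumed
    # to be the right comparison here.
    #
    #     available = config.list_available_models()
    #     if not any(n.split(":")[0] == config.model for n in available):
    #         raise ValueError(f"Model {config.model!r} not found on server")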