# Source code for haive.core.engine.retriever.providers.EnsembleRetrieverConfig
"""Ensemble Retriever implementation for the Haive framework.

This module provides a configuration class for the Ensemble retriever,
which combines multiple retrieval strategies using weighted combination
to improve overall retrieval performance and coverage.

The EnsembleRetriever works by:
1. Running multiple retrievers in parallel on the same query
2. Combining results using configurable weights for each retriever
3. Re-ranking and deduplicating the combined results
4. Returning the most relevant documents from the ensemble

This retriever is particularly useful when:
- You want to combine different retrieval strategies (sparse + dense)
- Need to balance precision and recall across different approaches
- Building robust systems that work across diverse query types
- Implementing hybrid search with customizable weights

The implementation integrates with LangChain's EnsembleRetriever while
providing a consistent Haive configuration interface.
"""

from typing import Any

from pydantic import Field, field_validator

from haive.core.engine.retriever.retriever import BaseRetrieverConfig
from haive.core.engine.retriever.types import RetrieverType
@BaseRetrieverConfig.register(RetrieverType.ENSEMBLE)
class EnsembleRetrieverConfig(BaseRetrieverConfig):
    """Configuration for Ensemble retriever in the Haive framework.

    This retriever combines multiple retrieval strategies using weighted
    combination to improve overall performance and coverage across different
    query types.

    Attributes:
        retriever_type (RetrieverType): The type of retriever (always ENSEMBLE).
        retrievers (List[BaseRetrieverConfig]): List of retriever configurations
            to ensemble.
        weights (List[float]): Weights for each retriever (must sum to 1.0).
        k (int): Number of documents to return from the ensemble.
        normalize_scores (bool): Whether to normalize scores before combining.
        c (int): Constant forwarded to the underlying ``EnsembleRetriever``.

    Examples:
        >>> from haive.core.engine.retriever import EnsembleRetrieverConfig
        >>> from haive.core.engine.retriever.providers.BM25RetrieverConfig import BM25RetrieverConfig
        >>> from haive.core.engine.retriever.providers.VectorStoreRetrieverConfig import VectorStoreRetrieverConfig
        >>>
        >>> # Create individual retrievers
        >>> bm25_config = BM25RetrieverConfig(name="bm25", documents=docs, k=10)
        >>> vector_config = VectorStoreRetrieverConfig(name="vector", vectorstore_config=vs_config, k=10)
        >>>
        >>> # Create ensemble retriever
        >>> config = EnsembleRetrieverConfig(
        ...     name="hybrid_ensemble",
        ...     retrievers=[bm25_config, vector_config],
        ...     weights=[0.3, 0.7],  # 30% BM25, 70% vector
        ...     k=5
        ... )
        >>>
        >>> # Instantiate and use the retriever
        >>> retriever = config.instantiate()
        >>> docs = retriever.get_relevant_documents("machine learning algorithms")
    """

    retriever_type: RetrieverType = Field(
        default=RetrieverType.ENSEMBLE, description="The type of retriever"
    )

    # Core ensemble configuration
    # ``min_length`` is the pydantic v2 spelling of the deprecated ``min_items``.
    retrievers: list[BaseRetrieverConfig] = Field(
        ...,
        min_length=2,
        description="List of retriever configurations to combine in the ensemble",
    )
    weights: list[float] = Field(
        ...,
        description="Weights for each retriever (must sum to 1.0 and match number of retrievers)",
    )

    # Retrieval parameters
    # NOTE(review): ``k`` is not forwarded to EnsembleRetriever by
    # ``instantiate()`` below; confirm whether something else consumes it.
    k: int = Field(
        default=4,
        ge=1,
        le=100,
        description="Number of documents to return from the ensemble",
    )

    # Processing options
    # NOTE(review): ``normalize_scores`` is likewise not passed on by
    # ``instantiate()``; confirm it has an effect elsewhere.
    normalize_scores: bool = Field(
        default=True,
        description="Whether to normalize scores before combining results",
    )
    # Passed straight through as EnsembleRetriever's ``c`` argument.
    # NOTE(review): LangChain documents ``c`` as the rank-fusion constant;
    # confirm the description below matches that meaning.
    c: int = Field(
        default=60,
        ge=1,
        le=1000,
        description="Parameter for score normalization (higher values reduce score variance)",
    )

    @field_validator("weights")
    @classmethod
    def validate_weights(cls, v: list[float]) -> list[float]:
        """Validate that weights sum to 1.0.

        Args:
            v: Candidate list of weights.

        Returns:
            The unchanged list when its sum is within 1e-6 of 1.0.

        Raises:
            ValueError: If the weights do not sum to 1.0 (within 1e-6).
        """
        total = sum(v)
        if abs(total - 1.0) > 1e-6:
            raise ValueError(f"Weights must sum to 1.0, got {total}")
        return v

    @field_validator("weights")
    @classmethod
    def validate_weight_values(cls, v: list[float]) -> list[float]:
        """Validate that each weight is between 0 and 1.

        Args:
            v: Candidate list of weights.

        Returns:
            The unchanged list when every weight lies in [0, 1].

        Raises:
            ValueError: On the first weight outside [0, 1].
        """
        for weight in v:
            if not 0 <= weight <= 1:
                raise ValueError(f"Each weight must be between 0 and 1, got {weight}")
        return v

    def get_input_fields(self) -> dict[str, tuple[type, Any]]:
        """Return input field definitions for Ensemble retriever.

        Returns:
            A mapping with a single ``"query"`` entry of type ``str``.
        """
        return {
            "query": (str, Field(description="Query string for ensemble retrieval")),
        }

    def get_output_fields(self) -> dict[str, tuple[type, Any]]:
        """Return output field definitions for Ensemble retriever.

        Returns:
            A mapping with a single ``"documents"`` entry typed as
            ``list[Any]`` and defaulting to an empty list.
        """
        return {
            "documents": (
                list[Any],  # List[Document] but avoiding import
                Field(
                    default_factory=list,
                    description="Documents retrieved by the ensemble",
                ),
            ),
        }

    def instantiate(self) -> Any:
        """Create an Ensemble retriever from this configuration.

        Returns:
            EnsembleRetriever: Instantiated retriever built from the component
            retrievers, ``weights`` and ``c``.

        Raises:
            ImportError: If ``langchain.retrievers.EnsembleRetriever`` cannot
                be imported.
            ValueError: If the number of retrievers and weights differ, or if
                any component retriever fails to instantiate.
        """
        try:
            from langchain.retrievers import EnsembleRetriever
        except ImportError as err:
            raise ImportError(
                "EnsembleRetriever requires langchain package. Install with: pip install langchain"
            ) from err

        # Check the counts before building anything: every configured
        # retriever yields exactly one instance, so a mismatch is already
        # known here and there is no point instantiating the components first.
        if len(self.retrievers) != len(self.weights):
            raise ValueError(
                f"Number of instantiated retrievers ({len(self.retrievers)}) "
                f"doesn't match number of weights ({len(self.weights)})"
            )

        # Instantiate all component retrievers, wrapping any failure so the
        # offending component is named while the original cause is preserved.
        instantiated_retrievers = []
        for retriever_config in self.retrievers:
            try:
                instantiated_retrievers.append(retriever_config.instantiate())
            except Exception as exc:
                raise ValueError(
                    f"Failed to instantiate retriever {retriever_config.name}: {exc}"
                ) from exc

        return EnsembleRetriever(
            retrievers=instantiated_retrievers, weights=self.weights, c=self.c
        )