Source code for haive.core.engine.retriever.providers.MultiQueryRetrieverConfig
"""Multi-Query Retriever implementation for the Haive framework.

This module provides a configuration class for the Multi-Query retriever,
which generates multiple query variations to improve retrieval coverage
and find more relevant documents for complex or ambiguous queries.

The MultiQueryRetriever works by:
1. Using an LLM to generate multiple query variations from the original query
2. Running each generated query against the base retriever
3. Collecting and deduplicating all retrieved documents
4. Returning the combined set of unique documents

This retriever is particularly useful when:
- Dealing with complex or ambiguous user queries
- Needing to improve recall by finding documents with different phrasings
- User queries might miss relevant documents due to vocabulary mismatch
- Building systems that need comprehensive document coverage

The implementation integrates with LangChain's MultiQueryRetriever while
providing a consistent Haive configuration interface with LLM integration.
"""

from typing import Any

from pydantic import Field, field_validator

from haive.core.engine.aug_llm import AugLLMConfig
from haive.core.engine.retriever.retriever import BaseRetrieverConfig
from haive.core.engine.retriever.types import RetrieverType
@BaseRetrieverConfig.register(RetrieverType.MULTI_QUERY)
class MultiQueryRetrieverConfig(BaseRetrieverConfig):
    """Configuration for Multi-Query retriever in the Haive framework.

    This retriever generates multiple query variations using an LLM to improve
    retrieval coverage and find more relevant documents for complex queries.

    Attributes:
        retriever_type (RetrieverType): The type of retriever (always MULTI_QUERY).
        base_retriever (BaseRetrieverConfig): The underlying retriever to query
            with variations.
        llm_config (AugLLMConfig): LLM configuration for generating query
            variations.
        num_queries (int): Number of query variations to generate (1-10).
        include_original (bool): Whether to include the original query in the set.
        query_prompt_template (str | None): Custom prompt template for query
            generation. It is rendered with a single ``question`` input
            variable; LangChain's default prompt is used when this is unset.

    Examples:
        >>> from haive.core.engine.retriever import MultiQueryRetrieverConfig
        >>> from haive.core.engine.retriever.providers.VectorStoreRetrieverConfig import VectorStoreRetrieverConfig
        >>> from haive.core.engine.aug_llm import AugLLMConfig
        >>>
        >>> # Create base retriever and LLM config
        >>> base_config = VectorStoreRetrieverConfig(name="base", vectorstore_config=vs_config)
        >>> llm_config = AugLLMConfig(model_name="gpt-3.5-turbo", provider="openai")
        >>>
        >>> # Create multi-query retriever
        >>> config = MultiQueryRetrieverConfig(
        ...     name="multi_query_retriever",
        ...     base_retriever=base_config,
        ...     llm_config=llm_config,
        ...     num_queries=3,
        ...     include_original=True
        ... )
        >>>
        >>> # Instantiate and use the retriever
        >>> retriever = config.instantiate()
        >>> docs = retriever.get_relevant_documents("machine learning algorithms")
    """

    retriever_type: RetrieverType = Field(
        default=RetrieverType.MULTI_QUERY, description="The type of retriever"
    )

    # Core configuration
    base_retriever: BaseRetrieverConfig = Field(
        ..., description="Base retriever configuration to query with variations"
    )
    llm_config: AugLLMConfig = Field(
        ..., description="LLM configuration for generating query variations"
    )

    # Query generation parameters
    num_queries: int = Field(
        default=3, ge=1, le=10, description="Number of query variations to generate"
    )
    include_original: bool = Field(
        default=True, description="Whether to include the original query in the set"
    )

    # Query prompt customization
    query_prompt_template: str | None = Field(
        default=None,
        description="Custom prompt template for query generation (uses default if None)",
    )

    @field_validator("num_queries")
    @classmethod
    def validate_num_queries(cls, v: int) -> int:
        """Ensure reasonable number of queries.

        The same 1-10 bounds are also declared through ``ge``/``le`` on the
        field itself; this validator keeps the explicit, human-readable
        error messages.

        Args:
            v: Candidate value for ``num_queries``.

        Returns:
            The value unchanged when it lies within 1-10.

        Raises:
            ValueError: If ``v`` is below 1 or above 10.
        """
        if v < 1:
            raise ValueError("num_queries must be at least 1")
        if v > 10:
            raise ValueError("num_queries should not exceed 10 for performance reasons")
        return v

    def get_input_fields(self) -> dict[str, tuple[type, Any]]:
        """Return input field definitions for Multi-Query retriever.

        Returns:
            A mapping with a single ``"query"`` entry: a ``(str, Field)`` pair
            describing the original query to generate variations from.
        """
        return {
            "query": (
                str,
                Field(description="Original query to generate variations from"),
            ),
        }

    def get_output_fields(self) -> dict[str, tuple[type, Any]]:
        """Return output field definitions for Multi-Query retriever.

        Returns:
            A mapping with a single ``"documents"`` entry: a
            ``(list[Any], Field)`` pair whose default is an empty list.
        """
        return {
            "documents": (
                list[Any],  # List[Document] but avoiding import
                Field(
                    default_factory=list,
                    description="Documents retrieved from all query variations",
                ),
            ),
        }

    def instantiate(self) -> Any:
        """Create a Multi-Query retriever from this configuration.

        Instantiates the base retriever and the LLM from their configs, then
        builds the retriever through ``MultiQueryRetriever.from_llm``.

        Returns:
            MultiQueryRetriever: Instantiated retriever ready for multi-query
            retrieval.

        Raises:
            ImportError: If required packages are not available.
            ValueError: If the base retriever or the LLM cannot be instantiated.
        """
        try:
            from langchain.retrievers.multi_query import MultiQueryRetriever
        except ImportError as e:
            raise ImportError(
                "MultiQueryRetriever requires langchain package. "
                "Install with: pip install langchain"
            ) from e

        # Instantiate the base retriever
        try:
            base_retriever = self.base_retriever.instantiate()
        except Exception as e:
            # Chain the cause so the underlying failure stays in the traceback.
            raise ValueError(f"Failed to instantiate base retriever: {e}") from e

        # Instantiate the LLM
        try:
            llm = self.llm_config.instantiate()
        except Exception as e:
            raise ValueError(f"Failed to instantiate LLM: {e}") from e

        # Create the multi-query retriever.
        # ``from_llm`` takes the model under the keyword ``llm``; ``llm_chain``
        # is a field of the retriever itself, not a ``from_llm`` parameter, so
        # passing it here raised a TypeError.
        # NOTE(review): assumes AugLLMConfig.instantiate() returns an object
        # usable as a LangChain language model / runnable - confirm.
        kwargs = {
            "retriever": base_retriever,
            "llm": llm,
            "include_original": self.include_original,
        }

        # NOTE(review): ``num_queries`` is validated but not forwarded here, so
        # the number of generated variations is decided by the prompt in use.

        # Add custom prompt if provided
        if self.query_prompt_template:
            # Only the import can raise ImportError, so keep the try minimal.
            try:
                from langchain.prompts import PromptTemplate
            except ImportError as e:
                raise ImportError(
                    "Custom prompt templates require langchain package. "
                    "Install with: pip install langchain"
                ) from e

            kwargs["prompt"] = PromptTemplate(
                input_variables=["question"],
                template=self.query_prompt_template,
            )

        return MultiQueryRetriever.from_llm(**kwargs)