Source code for haive.core.engine.retriever.providers.ContextualCompressionRetrieverConfig
"""Contextual Compression Retriever implementation for the Haive framework.This module provides a configuration class for the Contextual Compression retriever,which compresses retrieved documents to extract only the most relevant informationrelative to the query, improving both relevance and efficiency.The ContextualCompressionRetriever works by:1. Using a base retriever to get initial document candidates2. Applying a compressor (LLM or extractive) to compress each document3. Extracting only the parts of documents that are relevant to the query4. Returning compressed, more focused document contentThis retriever is particularly useful when:- Documents are long and contain irrelevant sections- Need to reduce token usage in downstream processing- Want to improve precision by filtering out noise- Building systems with strict context length limitsThe implementation integrates with LangChain's ContextualCompressionRetriever whileproviding a consistent Haive configuration interface with flexible compression options."""fromtypingimportAnyfrompydanticimportField,field_validatorfromhaive.core.engine.aug_llmimportAugLLMConfigfromhaive.core.engine.retriever.retrieverimportBaseRetrieverConfigfromhaive.core.engine.retriever.typesimportRetrieverType
@BaseRetrieverConfig.register(RetrieverType.CONTEXTUAL_COMPRESSION)
class ContextualCompressionRetrieverConfig(BaseRetrieverConfig):
    """Configuration for Contextual Compression retriever in the Haive framework.

    The retriever built by :meth:`instantiate` wraps a base retriever. It passes
    the retrieved documents through an LLM-based document compressor, so the
    results focus on content relevant to the query.

    Attributes:
        retriever_type (RetrieverType): The type of retriever (always
            CONTEXTUAL_COMPRESSION).
        base_retriever (BaseRetrieverConfig): The underlying retriever used to
            get initial document candidates.
        compressor_type (str): Compressor to use. ``'llm_chain_extract'``
            builds a LangChain ``LLMChainExtractor``; ``'llm_chain_filter'``
            builds an ``LLMChainFilter``.
        llm_config (AugLLMConfig | None): LLM configuration for compression.
            It may be omitted at construction time. Both supported compressors
            are LLM-based, however, so :meth:`instantiate` raises
            ``ValueError`` if it is still ``None``.

    Examples:
        >>> from haive.core.engine.retriever import ContextualCompressionRetrieverConfig
        >>> from haive.core.engine.retriever.providers.VectorStoreRetrieverConfig import VectorStoreRetrieverConfig
        >>> from haive.core.engine.aug_llm import AugLLMConfig
        >>>
        >>> # Create base retriever and LLM config (vs_config: an existing vector store config)
        >>> base_config = VectorStoreRetrieverConfig(name="base", vectorstore_config=vs_config)
        >>> llm_config = AugLLMConfig(model_name="gpt-3.5-turbo", provider="openai")
        >>>
        >>> # Create contextual compression retriever
        >>> config = ContextualCompressionRetrieverConfig(
        ...     name="compression_retriever",
        ...     base_retriever=base_config,
        ...     compressor_type="llm_chain_extract",
        ...     llm_config=llm_config,
        ... )
        >>>
        >>> # Instantiate and use the retriever
        >>> retriever = config.instantiate()
        >>> docs = retriever.invoke("machine learning algorithms")
    """

    retriever_type: RetrieverType = Field(
        default=RetrieverType.CONTEXTUAL_COMPRESSION,
        description="The type of retriever",
    )

    # Core configuration
    base_retriever: BaseRetrieverConfig = Field(
        ...,
        description="Base retriever configuration to get initial document candidates",
    )

    # Compressor configuration
    compressor_type: str = Field(
        default="llm_chain_extract",
        description="Type of compressor: 'llm_chain_extract', 'llm_chain_filter'",
    )
    llm_config: AugLLMConfig | None = Field(
        default=None,
        description="LLM configuration for compression (required for LLM compressors)",
    )

    @field_validator("compressor_type")
    @classmethod
    def validate_compressor_type(cls, v):
        """Ensure ``compressor_type`` names a supported compressor.

        Args:
            v: The candidate ``compressor_type`` value.

        Returns:
            The value unchanged when it is supported.

        Raises:
            ValueError: If ``v`` is not a supported compressor type. Pydantic
                reports it to the caller as a ``ValidationError``.
        """
        valid_types = ["llm_chain_extract", "llm_chain_filter"]
        if v not in valid_types:
            # Raise ValueError, not TypeError: Pydantic v2 only converts
            # ValueError/AssertionError raised in validators into a
            # ValidationError. A TypeError would escape un-wrapped.
            raise ValueError(
                f"compressor_type must be one of {valid_types}, got {v}"
            )
        return v

    @field_validator("llm_config")
    @classmethod
    def validate_llm_config_required(cls, v, info):
        """Pass ``llm_config`` through unchanged.

        This validator performs no check. A missing LLM config for an
        LLM-based compressor is detected later, in :meth:`instantiate`, not at
        construction time. The hook is kept so existing references to it
        keep working.

        Args:
            v: The candidate ``llm_config`` value.
            info: Pydantic validation context (unused).

        Returns:
            ``v`` unchanged.
        """
        return v

    def get_input_fields(self) -> dict[str, tuple[type, Any]]:
        """Return input field definitions for Contextual Compression retriever.

        Returns:
            A mapping from field name to ``(type, FieldInfo)``. There is a
            single required ``query`` string field.
        """
        return {
            "query": (
                str,
                Field(description="Query for contextual compression and retrieval"),
            ),
        }

    def get_output_fields(self) -> dict[str, tuple[type, Any]]:
        """Return output field definitions for Contextual Compression retriever.

        Returns:
            A mapping from field name to ``(type, FieldInfo)``. There is a
            single ``documents`` list field that defaults to an empty list.
        """
        return {
            "documents": (
                list[Any],  # List[Document] but avoiding import
                Field(
                    default_factory=list,
                    description="Compressed documents relevant to the query",
                ),
            ),
        }

    def instantiate(self):
        """Create a Contextual Compression retriever from this configuration.

        Steps:

        1. Instantiate the base retriever.
        2. Build the compressor named by ``compressor_type`` from the LLM
           described by ``llm_config``.
        3. Wrap both in LangChain's ``ContextualCompressionRetriever``.

        Returns:
            ContextualCompressionRetriever: Instantiated retriever ready for
            compression retrieval.

        Raises:
            ImportError: If the ``langchain`` package is not available.
            ValueError: If the base retriever or the LLM fails to instantiate,
                or if ``llm_config`` is missing.
            TypeError: If ``compressor_type`` is unsupported. This is only
                reachable if field validation was bypassed.
        """
        try:
            from langchain.retrievers import ContextualCompressionRetriever
            from langchain.retrievers.document_compressors import (
                LLMChainExtractor,
                LLMChainFilter,
            )
        except ImportError as e:
            raise ImportError(
                "ContextualCompressionRetriever requires langchain package. "
                "Install with: pip install langchain"
            ) from e

        # Both supported compressors are constructed identically via from_llm,
        # so the only per-type difference is which class is used.
        compressor_classes = {
            "llm_chain_extract": LLMChainExtractor,
            "llm_chain_filter": LLMChainFilter,
        }

        # Instantiate the base retriever
        try:
            base_retriever = self.base_retriever.instantiate()
        except Exception as e:
            raise ValueError(f"Failed to instantiate base retriever: {e}") from e

        compressor_cls = compressor_classes.get(self.compressor_type)
        if compressor_cls is None:
            raise TypeError(f"Unsupported compressor_type: {self.compressor_type}")

        if self.llm_config is None:
            raise ValueError(
                f"llm_config is required for {self.compressor_type} compressor"
            )

        try:
            llm = self.llm_config.instantiate()
        except Exception as e:
            raise ValueError(f"Failed to instantiate LLM: {e}") from e

        compressor = compressor_cls.from_llm(llm)
        return ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=base_retriever
        )