Source code for haive.core.schema.prebuilt.query_state

"""Query State Schema for Advanced RAG and Document Processing.

This module provides comprehensive query state management for advanced RAG workflows,
document processing, and multi-query scenarios. It builds on top of MessagesState
and DocumentState to provide a unified query processing interface.

The QueryState enables:
- Multi-query processing and refinement
- Query expansion and optimization
- Retrieval strategy management
- Context tracking and memory
- Source citation and provenance
- Time-weighted and filtered queries
- Self-query and adaptive retrieval
- Query result caching and optimization

Examples:
    Basic query processing::

        from haive.core.schema.prebuilt.query_state import QueryState

        state = QueryState(
            original_query="What are the latest trends in AI?",
            query_type="research",
            retrieval_strategy="adaptive"
        )

    Advanced multi-query workflow::

        state = QueryState(
            original_query="Analyze Q4 2024 financial performance",
            refined_queries=[
                "Q4 2024 revenue growth analysis",
                "Fourth quarter 2024 profit margins",
                "2024 Q4 market performance comparison"
            ],
            query_expansion_enabled=True,
            time_weighted_retrieval=True,
            source_filters=["financial_reports", "earnings_calls"]
        )

    Self-query with structured output::

        from haive.core.schema.prebuilt.query_state import QueryType, RetrievalStrategy

        state = QueryState(
            original_query="Find all documents about machine learning published after 2023",
            query_type=QueryType.STRUCTURED,
            retrieval_strategy=RetrievalStrategy.SELF_QUERY,
            structured_query_enabled=True,
            metadata_filters={"year": {"$gt": 2023}, "topic": "machine_learning"}
        )

Author: Claude (Haive AI Agent Framework)
Version: 1.0.0
"""

from datetime import datetime
from enum import Enum
from typing import Any

from langchain_core.documents import Document
from pydantic import BaseModel, Field, field_validator

# Conditionally import DocumentState to avoid auto-registry initialization
try:
    # Only use DocumentState when the document engine has already been
    # imported elsewhere; importing it here unconditionally would trigger the
    # full document system and its auto-registry.
    import sys

    if "haive.core.engine.document" in sys.modules:
        from haive.core.schema.prebuilt.document_state import DocumentState

        _HAS_DOCUMENT_STATE = True
    else:
        DocumentState = None
        _HAS_DOCUMENT_STATE = False
except ImportError:
    DocumentState = None
    _HAS_DOCUMENT_STATE = False

from haive.core.schema.prebuilt.messages_state import MessagesState


class QueryType(str, Enum):
    """Types of queries supported by the query processing system."""

    SIMPLE = "simple"  # Basic question answering
    RESEARCH = "research"  # Research and analysis queries
    STRUCTURED = "structured"  # Structured data extraction
    MULTI_STEP = "multi_step"  # Multi-step reasoning
    COMPARISON = "comparison"  # Comparative analysis
    SUMMARIZATION = "summarization"  # Document summarization
    EXTRACTION = "extraction"  # Information extraction
    CLASSIFICATION = "classification"  # Document classification
    RECOMMENDATION = "recommendation"  # Recommendation queries
    TEMPORAL = "temporal"  # Time-based queries
    SPATIAL = "spatial"  # Location-based queries
    ANALYTICAL = "analytical"  # Data analysis queries

class RetrievalStrategy(str, Enum):
    """Retrieval strategies for query processing."""

    BASIC = "basic"  # Basic vector similarity
    ADAPTIVE = "adaptive"  # Adaptive retrieval based on query
    SELF_QUERY = "self_query"  # Self-querying retrieval
    PARENT_DOCUMENT = "parent_document"  # Parent document retrieval
    MULTI_QUERY = "multi_query"  # Multiple query variations
    ENSEMBLE = "ensemble"  # Ensemble of multiple retrievers
    TIME_WEIGHTED = "time_weighted"  # Time-weighted retrieval
    CONTEXTUAL = "contextual"  # Contextual compression
    HYBRID = "hybrid"  # Hybrid semantic + keyword
    RERANKING = "reranking"  # Retrieval with reranking

class QueryComplexity(str, Enum):
    """Query complexity levels for processing optimization."""

    LOW = "low"  # Simple factual queries
    MEDIUM = "medium"  # Moderate reasoning required
    HIGH = "high"  # Complex multi-step reasoning
    EXPERT = "expert"  # Expert-level analysis required

class QueryIntent(str, Enum):
    """Intent classification for query processing."""

    INFORMATION_SEEKING = "information_seeking"
    PROBLEM_SOLVING = "problem_solving"
    DECISION_MAKING = "decision_making"
    LEARNING = "learning"
    COMPARISON = "comparison"
    PLANNING = "planning"
    TROUBLESHOOTING = "troubleshooting"
    CREATIVE = "creative"
    ANALYTICAL = "analytical"
    EXPLORATORY = "exploratory"

class QueryProcessingConfig(BaseModel):
    """Configuration for query processing behavior."""

    max_query_variations: int = Field(default=5, ge=1, le=20)
    enable_query_expansion: bool = Field(default=True)
    enable_query_refinement: bool = Field(default=True)
    enable_context_compression: bool = Field(default=True)
    enable_result_reranking: bool = Field(default=False)
    enable_citation_tracking: bool = Field(default=True)
    enable_confidence_scoring: bool = Field(default=True)
    max_context_documents: int = Field(default=10, ge=1, le=50)
    context_window_size: int = Field(default=4000, ge=1000, le=16000)
    similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
    time_weight_decay: float = Field(default=0.1, ge=0.0, le=1.0)
    enable_caching: bool = Field(default=True)
    cache_ttl: int = Field(default=3600, ge=60)

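# Illustrative sketch (kept in comments so nothing runs at import time): a
# config tuned for a small, low-latency retrieval pass. The field names come
# from QueryProcessingConfig above; the specific values are arbitrary examples
# that stay within the declared constraints.
#
#     fast_config = QueryProcessingConfig(
#         max_query_variations=3,
#         enable_result_reranking=False,
#         max_context_documents=5,
#         context_window_size=2000,
#         similarity_threshold=0.8,
#         cache_ttl=600,
#     )
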
class QueryMetrics(BaseModel):
    """Metrics and analytics for query processing."""

    processing_time: float = Field(default=0.0, ge=0.0)
    retrieval_time: float = Field(default=0.0, ge=0.0)
    generation_time: float = Field(default=0.0, ge=0.0)
    total_documents_searched: int = Field(default=0, ge=0)
    relevant_documents_found: int = Field(default=0, ge=0)
    confidence_score: float = Field(default=0.0, ge=0.0, le=1.0)
    retrieval_accuracy: float = Field(default=0.0, ge=0.0, le=1.0)
    query_complexity_score: float = Field(default=0.0, ge=0.0, le=1.0)
    context_utilization: float = Field(default=0.0, ge=0.0, le=1.0)
    cache_hit_rate: float = Field(default=0.0, ge=0.0, le=1.0)

class QueryResult(BaseModel):
    """Result container for query processing."""

    query_id: str = Field(description="Unique identifier for the query")
    response: str = Field(description="Generated response to the query")
    confidence: float = Field(default=0.0, ge=0.0, le=1.0)
    source_documents: list[Document] = Field(default_factory=list)
    citations: list[dict[str, Any]] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    processing_metrics: QueryMetrics = Field(default_factory=QueryMetrics)

    class Config:
        arbitrary_types_allowed = True

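# Illustrative sketch (kept in comments so nothing runs at import time):
# populating a QueryResult after a retrieval pass. query_id and response are
# the only required fields; the document content and metric values below are
# placeholders.
#
#     result = QueryResult(
#         query_id="query_1700000000.0",
#         response="Quantum computing uses qubits to ...",
#         confidence=0.82,
#         source_documents=[
#             Document(page_content="...", metadata={"source": "intro.md"})
#         ],
#         citations=[{"source": "intro.md", "page": 1}],
#         processing_metrics=QueryMetrics(
#             processing_time=1.2, relevant_documents_found=3
#         ),
#     )
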
# Define the QueryState base with conditional inheritance, depending on
# whether DocumentState is available.
if _HAS_DOCUMENT_STATE and DocumentState is not None:

    class QueryState(MessagesState, DocumentState):
        pass

else:

    class QueryState(MessagesState):
        pass


# Now redefine the actual QueryState class with its full implementation.
class QueryState(QueryState):
    """Comprehensive query state for advanced RAG and document processing.

    This state schema combines messages, documents, and query-specific
    information to provide a complete context for query processing workflows.
    It supports multi-query scenarios, retrieval strategies, and advanced RAG
    features.

    The state includes:
    - Query processing and refinement
    - Document context and retrieval
    - Multi-query coordination
    - Retrieval strategy management
    - Results and metrics tracking
    - Source citation and provenance
    - Time-weighted and filtered queries
    - Adaptive and self-query capabilities

    Examples:
        Basic query state::

            state = QueryState(
                original_query="What is quantum computing?",
                query_type=QueryType.SIMPLE,
                retrieval_strategy=RetrievalStrategy.BASIC
            )

        Advanced research query::

            state = QueryState(
                original_query="Analyze the impact of AI on healthcare",
                query_type=QueryType.RESEARCH,
                retrieval_strategy=RetrievalStrategy.ADAPTIVE,
                query_expansion_enabled=True,
                time_weighted_retrieval=True,
                source_filters=["medical_journals", "clinical_trials"],
                metadata_filters={"publication_year": {"$gte": 2020}}
            )

        Multi-query workflow::

            state = QueryState(
                original_query="Compare Q3 vs Q4 2024 performance",
                refined_queries=[
                    "Q3 2024 financial results analysis",
                    "Q4 2024 earnings report summary",
                    "Q3 Q4 2024 performance comparison"
                ],
                query_type=QueryType.COMPARISON,
                retrieval_strategy=RetrievalStrategy.MULTI_QUERY
            )
    """

    # Core Query Information
    original_query: str = Field(description="The original user query")
    query_id: str = Field(
        default_factory=lambda: f"query_{datetime.now().timestamp()}"
    )
    query_type: QueryType = Field(default=QueryType.SIMPLE)
    query_intent: QueryIntent = Field(default=QueryIntent.INFORMATION_SEEKING)
    query_complexity: QueryComplexity = Field(default=QueryComplexity.MEDIUM)

    # Query Processing
    refined_queries: list[str] = Field(default_factory=list)
    expanded_queries: list[str] = Field(default_factory=list)
    query_variations: list[str] = Field(default_factory=list)
    processed_query: str = Field(default="")

    # Retrieval Configuration
    retrieval_strategy: RetrievalStrategy = Field(default=RetrievalStrategy.ADAPTIVE)
    retrieval_config: dict[str, Any] = Field(default_factory=dict)

    # Query Enhancement
    query_expansion_enabled: bool = Field(default=True)
    query_refinement_enabled: bool = Field(default=True)
    multi_query_enabled: bool = Field(default=False)
    structured_query_enabled: bool = Field(default=False)
    time_weighted_retrieval: bool = Field(default=False)

    # Filtering and Constraints
    source_filters: list[str] = Field(default_factory=list)
    metadata_filters: dict[str, Any] = Field(default_factory=dict)
    time_range_filter: dict[str, datetime] | None = Field(default=None)
    similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
    max_results: int = Field(default=10, ge=1, le=100)

    # Context and Memory
    context_documents: list[Document] = Field(default_factory=list)
    retrieved_documents: list[Document] = Field(default_factory=list)
    relevant_contexts: list[str] = Field(default_factory=list)
    memory_contexts: list[str] = Field(default_factory=list)

    # Results and Tracking
    query_results: list[QueryResult] = Field(default_factory=list)
    current_result: QueryResult | None = Field(default=None)
    intermediate_results: list[dict[str, Any]] = Field(default_factory=list)

    # Citations and Provenance
    citations: list[dict[str, Any]] = Field(default_factory=list)
    source_provenance: dict[str, Any] = Field(default_factory=dict)
    confidence_scores: dict[str, float] = Field(default_factory=dict)

    # Processing Configuration
    processing_config: QueryProcessingConfig = Field(
        default_factory=QueryProcessingConfig
    )

    # Metrics and Analytics
    processing_metrics: QueryMetrics = Field(default_factory=QueryMetrics)
    query_history: list[dict[str, Any]] = Field(default_factory=list)

    # Execution State
    current_stage: str = Field(default="initialized")
    execution_path: list[str] = Field(default_factory=list)
    error_history: list[dict[str, Any]] = Field(default_factory=list)

    # Caching and Optimization
    cache_key: str | None = Field(default=None)
    cached_results: dict[str, Any] = Field(default_factory=dict)
    optimization_hints: dict[str, Any] = Field(default_factory=dict)
    class Config:
        arbitrary_types_allowed = True
    @field_validator("original_query")
    @classmethod
    def validate_original_query(cls, v: str) -> str:
        """Validate that the original query is not empty."""
        if not v or not v.strip():
            raise ValueError("Original query cannot be empty")
        return v.strip()

    @field_validator("refined_queries")
    @classmethod
    def validate_refined_queries(cls, v: list[str]) -> list[str]:
        """Validate refined queries are not empty."""
        return [q.strip() for q in v if q and q.strip()]

    @field_validator("time_range_filter")
    @classmethod
    def validate_time_range(
        cls, v: dict[str, datetime] | None
    ) -> dict[str, datetime] | None:
        """Validate time range filter has valid start and end dates."""
        if v and "start" in v and "end" in v and v["start"] > v["end"]:
            raise ValueError("Start date must be before end date")
        return v
    def add_refined_query(self, query: str) -> None:
        """Add a refined query to the list."""
        if query and query.strip() and query not in self.refined_queries:
            self.refined_queries.append(query.strip())

    def add_expanded_query(self, query: str) -> None:
        """Add an expanded query to the list."""
        if query and query.strip() and query not in self.expanded_queries:
            self.expanded_queries.append(query.strip())

    def add_query_variation(self, query: str) -> None:
        """Add a query variation to the list."""
        if query and query.strip() and query not in self.query_variations:
            self.query_variations.append(query.strip())

    def add_context_document(self, document: Document) -> None:
        """Add a context document to the state."""
        if document not in self.context_documents:
            self.context_documents.append(document)

    def add_retrieved_document(self, document: Document) -> None:
        """Add a retrieved document to the state."""
        if document not in self.retrieved_documents:
            self.retrieved_documents.append(document)

    def add_citation(self, citation: dict[str, Any]) -> None:
        """Add a citation to the state."""
        if citation not in self.citations:
            self.citations.append(citation)

    def set_confidence_score(self, source: str, score: float) -> None:
        """Set confidence score for a source."""
        if 0.0 <= score <= 1.0:
            self.confidence_scores[source] = score

    def get_confidence_score(self, source: str) -> float:
        """Get confidence score for a source."""
        return self.confidence_scores.get(source, 0.0)

    def add_intermediate_result(self, result: dict[str, Any]) -> None:
        """Add an intermediate result to tracking."""
        result["timestamp"] = datetime.now().isoformat()
        self.intermediate_results.append(result)

    def update_stage(self, stage: str) -> None:
        """Update the current processing stage."""
        self.current_stage = stage
        self.execution_path.append(stage)

    def add_error(self, error: str, context: dict[str, Any] | None = None) -> None:
        """Add an error to the history."""
        error_entry = {
            "error": error,
            "timestamp": datetime.now().isoformat(),
            "stage": self.current_stage,
            "context": context or {},
        }
        self.error_history.append(error_entry)

    def get_all_queries(self) -> list[str]:
        """Get all queries including original, refined, and expanded."""
        all_queries = [self.original_query]
        all_queries.extend(self.refined_queries)
        all_queries.extend(self.expanded_queries)
        all_queries.extend(self.query_variations)
        return list(set(all_queries))  # Remove duplicates
    def get_all_documents(self) -> list[Document]:
        """Get all documents including raw, context, and retrieved."""
        all_docs = []
        # raw_documents comes from DocumentState; fall back to an empty list
        # when that mixin is unavailable so this method still works.
        all_docs.extend(getattr(self, "raw_documents", []))
        all_docs.extend(self.context_documents)
        all_docs.extend(self.retrieved_documents)

        # Remove duplicates based on content
        seen_content = set()
        unique_docs = []
        for doc in all_docs:
            content_hash = hash(doc.page_content)
            if content_hash not in seen_content:
                seen_content.add(content_hash)
                unique_docs.append(doc)

        return unique_docs
    def get_processing_summary(self) -> dict[str, Any]:
        """Get a summary of processing statistics."""
        return {
            "query_id": self.query_id,
            "original_query": self.original_query,
            "query_type": self.query_type.value,
            "retrieval_strategy": self.retrieval_strategy.value,
            "total_queries": len(self.get_all_queries()),
            "total_documents": len(self.get_all_documents()),
            "context_documents": len(self.context_documents),
            "retrieved_documents": len(self.retrieved_documents),
            "citations": len(self.citations),
            "current_stage": self.current_stage,
            "execution_path": self.execution_path,
            "processing_time": self.processing_metrics.processing_time,
            "confidence_score": self.processing_metrics.confidence_score,
            "errors": len(self.error_history),
        }

    def is_multi_query_workflow(self) -> bool:
        """Check if this is a multi-query workflow."""
        return (
            self.multi_query_enabled
            or len(self.refined_queries) > 1
            or len(self.expanded_queries) > 1
            or len(self.query_variations) > 1
        )

    def requires_structured_output(self) -> bool:
        """Check if structured output is required."""
        return self.structured_query_enabled or self.query_type in [
            QueryType.STRUCTURED,
            QueryType.EXTRACTION,
            QueryType.CLASSIFICATION,
        ]

    def get_active_filters(self) -> dict[str, Any]:
        """Get all active filters for the query."""
        filters = {}
        if self.source_filters:
            filters["sources"] = self.source_filters
        if self.metadata_filters:
            filters["metadata"] = self.metadata_filters
        if self.time_range_filter:
            filters["time_range"] = self.time_range_filter
        if self.similarity_threshold < 1.0:
            filters["similarity_threshold"] = self.similarity_threshold
        return filters
    def create_cache_key(self) -> str:
        """Create a cache key for the current query state."""
        import hashlib

        # Create a hash based on the query and its key parameters
        cache_components = [
            self.original_query,
            str(self.query_type.value),
            str(self.retrieval_strategy.value),
            str(sorted(self.source_filters)),
            str(self.metadata_filters),
            str(self.similarity_threshold),
            str(self.max_results),
        ]
        cache_string = "|".join(cache_components)
        cache_key = hashlib.md5(cache_string.encode()).hexdigest()
        self.cache_key = cache_key
        return cache_key

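# Illustrative sketch (kept in comments so nothing runs at import time): the
# helper methods above in a typical retrieval loop. The query text and filter
# values are placeholders.
#
#     state = QueryState(
#         original_query="What changed in the 2024 release?",
#         source_filters=["release_notes"],
#         max_results=5,
#     )
#     state.add_refined_query("2024 release changelog summary")
#     state.update_stage("retrieval")
#     retriever_kwargs = state.get_active_filters()
#     cache_key = state.create_cache_key()
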
# Alias for backward compatibility
QueryProcessingState = QueryState

# Export all classes
__all__ = [
    "QueryComplexity",
    "QueryIntent",
    "QueryMetrics",
    "QueryProcessingConfig",
    "QueryProcessingState",
    "QueryResult",
    "QueryState",
    "QueryType",
    "RetrievalStrategy",
]
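

if __name__ == "__main__":
    # Minimal smoke-test sketch, run only when this module is executed
    # directly. It assumes QueryState can be constructed from just a query
    # (i.e. the inherited MessagesState/DocumentState fields all have
    # defaults) and that langchain_core is installed; the query text and
    # document below are placeholders.
    demo_state = QueryState(
        original_query="Summarize the latest framework release notes",
        query_type=QueryType.SUMMARIZATION,
        retrieval_strategy=RetrievalStrategy.MULTI_QUERY,
        source_filters=["release_notes"],
    )
    demo_state.add_refined_query("framework release notes summary 2024")
    demo_state.add_context_document(
        Document(
            page_content="Release 1.2 adds adaptive retrieval.",
            metadata={"source": "notes.md"},
        )
    )
    demo_state.update_stage("retrieval")

    print("Active filters:", demo_state.get_active_filters())
    print("Cache key:", demo_state.create_cache_key())
    print("Summary:", demo_state.get_processing_summary())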