Source code for haive.core.schema.prebuilt.llm_state

"""LLM-specific state with single engine and token tracking.

This module provides a state schema optimized for LLM-based agents that need
to track token usage against thresholds and metadata.
"""

from __future__ import annotations

from typing import Any, Self

from pydantic import Field, computed_field, model_validator

from haive.core.engine.aug_llm import AugLLMConfig
from haive.core.schema.prebuilt.tool_state import ToolState

MODEL_CONTEXT_LENGTHS = {
    "gpt-4": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-turbo": 128000,
    "gpt-4-turbo-preview": 128000,
    "gpt-4o": 128000,
    "gpt-4o-mini": 128000,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16384,
    "claude-3-opus": 200000,
    "claude-3-sonnet": 200000,
    "claude-3-haiku": 200000,
    "claude-2.1": 200000,
    "claude-2": 100000,
    "claude-instant": 100000,
    "llama-2-70b": 4096,
    "mixtral-8x7b": 32768,
    "gemini-pro": 32768,
    "gemini-pro-1.5": 1000000,
}


[docs] class LLMState(ToolState): """State schema for LLM-based agents with single engine, tool management and token tracking. This schema is designed for agents that: - Use a single primary LLM engine (with optional additional engines) - Need automatic tool management and routing (inherited from ToolState) - Need automatic token usage tracking for all messages (inherited from ToolState → MessagesStateWithTokenUsage) - Want to compare token usage against engine metadata/thresholds - Require cost tracking and capacity monitoring The schema combines: - ToolState for tool management, routing, and token tracking - Single primary engine field for cleaner LLM agent design - Computed properties for threshold comparison - Automatic context length detection from model names Input Schema: Just messages (from MessagesStateWithTokenUsage) Output Schema: Full state including engine, messages, token usage, etc. Examples: from haive.core.schema.prebuilt import LLMState from haive.core.engine.aug_llm import AugLLMConfig # Create state with engine state = LLMState( engine=AugLLMConfig( name="gpt4_engine", model="gpt-4-turbo", # Automatically detects 128k context temperature=0.7 ) ) # Messages automatically track tokens state.add_message(ai_message) # Check against model-specific thresholds if state.is_approaching_token_limit: print(f"Warning: {state.token_usage_percentage:.1f}% of {state.context_length} tokens used") # Set custom thresholds state.warning_threshold = 0.75 # Warn at 75% usage state.critical_threshold = 0.90 # Critical at 90% usage # Get cost analysis analysis = state.get_conversation_cost_analysis() print(f"Total cost: ${analysis['total_cost']:.4f}") """ engine: AugLLMConfig | None = Field( default=None, description="The LLM engine for this agent" ) warning_threshold: float = Field( default=0.75, ge=0.0, le=1.0, description="Threshold for warning about token usage (0.0-1.0)", ) critical_threshold: float = Field( default=0.9, ge=0.0, le=1.0, description="Threshold for critical token usage (0.0-1.0)", ) context_length_override: int | None = Field( default=None, description="Override auto-detected context length" )
[docs] @model_validator(mode="after") def setup_primary_engine_references(self) -> Self: """Ensure the primary LLM engine is available in engines dict with standard keys. This works with ToolState's engine management to provide consistent access. """ super().sync_tools_and_update_routes() if self.engine: self.engines["main"] = self.engine self.engines["llm"] = self.engine self.engines["primary"] = self.engine if hasattr(self.engine, "name") and self.engine.name: self.engines[self.engine.name] = self.engine return self
@computed_field @property def context_length(self) -> int: """Get the context length for the current model.""" if self.context_length_override: return self.context_length_override if not self.engine: return 4096 model_name = None if hasattr(self.engine, "model"): model_name = self.engine.model elif hasattr(self.engine, "model_name"): model_name = self.engine.model_name elif hasattr(self.engine, "llm_config"): llm_config = self.engine.llm_config if hasattr(llm_config, "model"): model_name = llm_config.model elif hasattr(llm_config, "model_name"): model_name = llm_config.model_name if model_name: if model_name in MODEL_CONTEXT_LENGTHS: return MODEL_CONTEXT_LENGTHS[model_name] model_lower = model_name.lower() for key, length in MODEL_CONTEXT_LENGTHS.items(): if key in model_lower or model_lower in key: return length if hasattr(self.engine, "max_tokens") and self.engine.max_tokens: return self.engine.max_tokens return 4096 @computed_field @property def token_usage_percentage(self) -> float: """Get token usage as percentage of context length.""" if not self.token_usage: return 0.0 return self.token_usage.total_tokens / self.context_length * 100 @computed_field @property def is_approaching_token_limit(self) -> bool: """Check if token usage is approaching the warning threshold.""" return self.token_usage_percentage / 100 >= self.warning_threshold @computed_field @property def is_at_critical_limit(self) -> bool: """Check if token usage has reached the critical threshold.""" return self.token_usage_percentage / 100 >= self.critical_threshold @computed_field @property def is_at_token_limit(self) -> bool: """Check if token usage has reached or exceeded the limit.""" return self.token_usage_percentage >= 100 @computed_field @property def remaining_tokens(self) -> int: """Calculate remaining tokens before hitting the limit.""" if not self.engine: return 0 max_tokens = 4000 if hasattr(self.engine, "max_tokens") and self.engine.max_tokens: max_tokens = self.engine.max_tokens elif ( hasattr(self.engine, "llm_config") and hasattr(self.engine.llm_config, "max_tokens") and self.engine.llm_config.max_tokens ): max_tokens = self.engine.llm_config.max_tokens if not self.token_usage: return max_tokens return max(0, max_tokens - self.token_usage.total_tokens)
[docs] def get_engine_metadata(self) -> dict[str, Any]: """Get metadata from the engine for threshold comparisons.""" if not self.engine: return {} metadata = {} if hasattr(self.engine, "metadata"): metadata.update(self.engine.metadata) for field in ["temperature", "max_tokens", "model", "model_name"]: if hasattr(self.engine, field): value = getattr(self.engine, field) if value is not None: metadata[field] = value if hasattr(self.engine, "llm_config") and self.engine.llm_config: llm_config = self.engine.llm_config for field in ["temperature", "max_tokens", "model"]: if hasattr(llm_config, field): value = getattr(llm_config, field) if value is not None: metadata[f"llm_{field}"] = value return metadata
[docs] def should_summarize_context(self, threshold: float = 0.75) -> bool: """Determine if context should be summarized based on token usage. Args: threshold: Percentage threshold (0.0-1.0) for triggering summarization Returns: True if token usage exceeds threshold percentage of max tokens """ return self.token_usage_percentage / 100 > threshold
[docs] @classmethod def from_engine(cls, engine: AugLLMConfig, **kwargs) -> LLMState: """Create LLMState from an AugLLMConfig engine. Args: engine: The AugLLMConfig engine **kwargs: Additional fields for the state Returns: New LLMState instance """ return cls(engine=engine, **kwargs)