"""Engine management mixin for tracking and accessing Haive engines.
This module provides a sophisticated mixin for managing different types of
engines within the Haive framework. It enables tracking engines by name and type,
collecting usage statistics, and provides rich visualization capabilities.
Usage:
from pydantic import BaseModel
from haive.core.common.mixins import EngineStateMixin
from haive.core.engine.llm import LLMEngine
class MyState(EngineStateMixin, BaseModel):
# Other fields
pass
# Create state and add engines
state = MyState()
# Add an LLM engine
llm = LLMEngine(name="my_llm", model="gpt-4", provider="openai")
state.add_engine(llm)
# Later, retrieve the engine
retrieved_llm = state.get_engine("my_llm")
# Get all LLM engines
all_llms = state.get_llms()
# Display engine information
state.display_engines()
"""
import json
import logging
from datetime import datetime
from typing import Any, Self
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from rich.console import Console
from rich.panel import Panel
from rich.syntax import Syntax
from rich.table import Table
from rich.tree import Tree
from haive.core.engine.base import Engine
from haive.core.engine.base.types import EngineType
logger = logging.getLogger(__name__)
[docs]
class EngineStateMixin(BaseModel):
"""Mixin providing comprehensive engine management capabilities with validation.
This mixin allows StateSchema to manage engines by name and type, providing
rich access patterns, performance tracking, and debugging capabilities.
The mixin maintains multiple indexes for efficient access:
- By name: Fast lookup of specific engines
- By type: Group engines by their type (LLM, Retriever, etc.)
It also tracks metadata about engines including access patterns and
performance metrics, and provides rich visualization for debugging.
Attributes:
engines: Dictionary of engines indexed by name.
engines_by_type: Dictionary of engines organized by type.
engine_metadata: Dictionary containing additional metadata for each engine.
"""
# Main storage - using proper field names (no leading underscore)
engines: dict[str, Engine] = Field(
default_factory=dict, description="Engines indexed by name"
)
engines_by_type: dict[EngineType, list[Engine]] = Field(
default_factory=lambda: {engine_type: [] for engine_type in EngineType},
description="Engines organized by type for quick access",
)
# Track engine metadata and relationships
engine_metadata: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Additional metadata for each engine"
)
# Private attributes for internal tracking
_engine_access_log: list[dict[str, Any]] = PrivateAttr(default_factory=list)
_engine_performance_metrics: dict[str, dict[str, float]] = PrivateAttr(
default_factory=dict
)
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
[docs]
@model_validator(mode="after")
def validate_and_organize_engines(self) -> Self:
"""Ensure engines are properly organized by type and validated.
This validator rebuilds the engines_by_type index to ensure consistency
and initializes metadata for any engines that don't have it yet.
Returns:
Self with validated engine organization.
"""
# Rebuild engines_by_type to ensure consistency
self.engines_by_type = {engine_type: [] for engine_type in EngineType}
for name, engine in self.engines.items():
if not isinstance(engine, Engine):
logger.warning(f"Object '{name}' is not an Engine instance")
continue
engine_type = engine.engine_type
if engine not in self.engines_by_type[engine_type]:
self.engines_by_type[engine_type].append(engine)
# Initialize metadata if not present
if name not in self.engine_metadata:
self.engine_metadata[name] = {
"added_at": datetime.now().isoformat(),
"access_count": 0,
"last_accessed": None,
"performance": {
"total_calls": 0,
"total_time": 0.0,
"avg_time": 0.0,
},
}
return self
# ===== Engine Management Methods =====
[docs]
def add_engine(self, engine: Engine, name: str | None = None) -> None:
"""Add an engine to the state with automatic organization by type.
This method adds an engine to both the main engines dictionary and the
type-based index. It also initializes metadata tracking for the engine.
Args:
engine: The engine to add.
name: Optional name override (defaults to engine.name).
"""
engine_name = name or engine.name
self.engines[engine_name] = engine
# Add to type index
if engine not in self.engines_by_type[engine.engine_type]:
self.engines_by_type[engine.engine_type].append(engine)
# Update metadata
self.engine_metadata[engine_name] = {
"added_at": datetime.now().isoformat(),
"access_count": 0,
"last_accessed": None,
"original_name": engine.name,
"type": engine.engine_type.value,
"performance": {"total_calls": 0, "total_time": 0.0, "avg_time": 0.0},
}
logger.debug(f"Added engine '{engine_name}' of type {engine.engine_type}")
[docs]
def get_engine(self, name: str) -> Engine | None:
"""Get an engine by name with access logging.
This method retrieves an engine by name and updates access metrics
including access count, timestamp, and performance data.
Args:
name: Engine name.
Returns:
Engine if found, None otherwise.
"""
import time
start_time = time.time()
engine = self.engines.get(name)
if engine:
# Update access metadata
if name in self.engine_metadata:
self.engine_metadata[name]["access_count"] += 1
self.engine_metadata[name]["last_accessed"] = datetime.now().isoformat()
# Update performance metrics
elapsed = time.time() - start_time
perf = self.engine_metadata[name]["performance"]
perf["total_calls"] += 1
perf["total_time"] += elapsed
perf["avg_time"] = perf["total_time"] / perf["total_calls"]
# Log access
self._engine_access_log.append(
{
"engine": name,
"timestamp": datetime.now().isoformat(),
"operation": "get_engine",
"duration_ms": (time.time() - start_time) * 1000,
}
)
return engine
[docs]
def get_engines_by_type(self, engine_type: EngineType) -> list[Engine]:
"""Get all engines of a specific type.
Args:
engine_type: The engine type to filter by.
Returns:
List of engines of the specified type.
"""
return self.engines_by_type.get(engine_type, [])
[docs]
def get_all_engines(self) -> dict[str, Engine]:
"""Get all engines as a dictionary.
Returns:
Dictionary of all engines by name.
"""
return self.engines.copy()
# ===== Specialized Getters =====
[docs]
def get_llms(self) -> list[Engine]:
"""Get all LLM/AugLLM engines.
Returns:
List of LLM engines.
"""
return self.get_engines_by_type(EngineType.LLM)
[docs]
def get_retrievers(self) -> list[Engine]:
"""Get all retriever engines.
Returns:
List of retriever engines.
"""
return self.get_engines_by_type(EngineType.RETRIEVER)
[docs]
def get_agents(self) -> list[Engine]:
"""Get all agent engines.
Returns:
List of agent engines.
"""
return self.get_engines_by_type(EngineType.AGENT)
[docs]
def get_vector_stores(self) -> list[Engine]:
"""Get all vector store engines.
Returns:
List of vector store engines.
"""
return self.get_engines_by_type(EngineType.VECTOR_STORE)
[docs]
def get_embeddings(self) -> list[Engine]:
"""Get all embeddings engines.
Returns:
List of embeddings engines.
"""
return self.get_engines_by_type(EngineType.EMBEDDINGS)
# ===== Engine Manipulation =====
[docs]
def update_engine(self, name: str, **kwargs) -> None:
"""Update engine attributes dynamically.
This method allows updating specific attributes of an engine
by providing them as keyword arguments.
Args:
name: Engine name.
**kwargs: Attributes to update.
Raises:
ValueError: If the engine is not found.
"""
engine = self.get_engine(name)
if not engine:
raise ValueError(f"Engine '{name}' not found")
for key, value in kwargs.items():
if hasattr(engine, key):
setattr(engine, key, value)
logger.debug(f"Updated engine '{name}' attribute '{key}' to '{value}'")
else:
logger.warning(f"Engine '{name}' has no attribute '{key}'")
[docs]
def change_provider(self, name: str, provider: str) -> None:
"""Change the provider of an LLM engine.
This is a convenience method specifically for changing the
provider of an engine (e.g., switching from 'openai' to 'anthropic').
Args:
name: Engine name.
provider: New provider name.
Raises:
ValueError: If the engine is not found.
"""
engine = self.get_engine(name)
if not engine:
raise ValueError(f"Engine '{name}' not found")
if hasattr(engine, "provider"):
old_provider = getattr(engine, "provider", "unknown")
engine.provider = provider
logger.info(
f"Changed engine '{name}' provider from '{old_provider}' to '{provider}'"
)
else:
logger.warning(f"Engine '{name}' does not have a provider attribute")
[docs]
def get_engine_routes(self, name: str) -> dict[str, str]:
"""Get tool routes from an engine.
This method tries to extract routing information from an engine
by checking various attribute names where routes might be stored.
Args:
name: Engine name.
Returns:
Dictionary of routes if available, empty dict otherwise.
"""
engine = self.get_engine(name)
if not engine:
return {}
# Try different attributes where routes might be stored
for attr in ["tool_routes", "routes", "routing_table", "tool_routing"]:
if hasattr(engine, attr):
routes = getattr(engine, attr)
if isinstance(routes, dict):
return routes
return {}
[docs]
def remove_engine(self, name: str) -> bool:
"""Remove an engine from the state.
This method removes an engine from both the main engines dictionary
and the type-based index, as well as removing its metadata.
Args:
name: Engine name to remove.
Returns:
True if removed, False if not found.
"""
if name not in self.engines:
return False
engine = self.engines[name]
# Remove from main storage
del self.engines[name]
# Remove from type index
if engine in self.engines_by_type[engine.engine_type]:
self.engines_by_type[engine.engine_type].remove(engine)
# Remove metadata
if name in self.engine_metadata:
del self.engine_metadata[name]
logger.info(f"Removed engine '{name}'")
return True
# ===== Rich UI Display Methods =====
[docs]
def display_engines(
self, show_metadata: bool = False, show_performance: bool = False
) -> None:
"""Display all engines in a rich tree view organized by type.
This method creates a rich visual representation of engines
organized by type, with optional metadata and performance information.
Args:
show_metadata: Whether to show additional metadata.
show_performance: Whether to show performance metrics.
"""
console = Console()
# Create tree view by type
tree = Tree("🔧 [bold cyan]Engines[/bold cyan]")
# LLMs
if llms := self.get_llms():
llm_branch = tree.add(f"🤖 [bold yellow]LLMs[/bold yellow] ({len(llms)})")
for llm in llms:
llm_info = f"[green]{llm.name}[/green]"
if hasattr(llm, "model"):
llm_info += f" [dim]({llm.model})[/dim]"
if hasattr(llm, "provider"):
llm_info += f" [blue]{llm.provider}[/blue]"
node = llm_branch.add(llm_info)
if show_metadata and llm.name in self.engine_metadata:
meta = self.engine_metadata[llm.name]
node.add(f"[dim]Access count: {meta.get('access_count', 0)}[/dim]")
node.add(
f"[dim]Last accessed: {meta.get('last_accessed', 'Never')}[/dim]"
)
if show_performance and llm.name in self.engine_metadata:
perf = self.engine_metadata[llm.name].get("performance", {})
node.add(
f"[dim]Avg response time: {perf.get('avg_time', 0):.3f}s[/dim]"
)
# Retrievers
if retrievers := self.get_retrievers():
ret_branch = tree.add(
f"🔍 [bold magenta]Retrievers[/bold magenta] ({len(retrievers)})"
)
for ret in retrievers:
ret_info = f"[green]{ret.name}[/green]"
if hasattr(ret, "retriever_type"):
ret_info += f" [dim]({ret.retriever_type})[/dim]"
ret_branch.add(ret_info)
# Agents
if agents := self.get_agents():
agent_branch = tree.add(f"🎯 [bold red]Agents[/bold red] ({len(agents)})")
for agent in agents:
agent_info = f"[green]{agent.name}[/green]"
if hasattr(agent, "engines"):
agent_info += f" [dim]({len(agent.engines)} sub-engines)[/dim]"
elif hasattr(agent, "components"):
agent_info += f" [dim]({len(agent.components)} components)[/dim]"
agent_branch.add(agent_info)
# Vector Stores
if vector_stores := self.get_vector_stores():
vs_branch = tree.add(
f"💾 [bold blue]Vector Stores[/bold blue] ({len(vector_stores)})"
)
for vs in vector_stores:
vs_info = f"[green]{vs.name}[/green]"
if hasattr(vs, "vector_store_provider"):
vs_info += f" [dim]({vs.vector_store_provider})[/dim]"
vs_branch.add(vs_info)
# Tools
if tools := self.get_tools():
tool_branch = tree.add(f"🔨 [bold green]Tools[/bold green] ({len(tools)})")
for tool in tools:
tool_info = f"[green]{tool.name}[/green]"
if hasattr(tool, "description"):
tool_info += f" [dim]{tool.description[:50]}...[/dim]"
tool_branch.add(tool_info)
# Embeddings
if embeddings := self.get_embeddings():
emb_branch = tree.add(
f"📊 [bold cyan]Embeddings[/bold cyan] ({len(embeddings)})"
)
for emb in embeddings:
emb_info = f"[green]{emb.name}[/green]"
if hasattr(emb, "model"):
emb_info += f" [dim]({emb.model})[/dim]"
emb_branch.add(emb_info)
console.print(tree)
# Show summary statistics
total_engines = len(self.engines)
console.print(f"\n[bold]Total Engines:[/bold] {total_engines}")
if show_metadata or show_performance:
# Show top accessed engines
sorted_by_access = sorted(
self.engine_metadata.items(),
key=lambda x: x[1].get("access_count", 0),
reverse=True,
)[:5]
if sorted_by_access and sorted_by_access[0][1].get("access_count", 0) > 0:
console.print("\n[bold]Top Accessed Engines:[/bold]")
for name, meta in sorted_by_access:
if meta.get("access_count", 0) > 0:
console.print(f" {name}: {meta['access_count']} accesses")
[docs]
def display_engine_details(self, name: str) -> None:
"""Display detailed information about a specific engine.
This method creates a rich visual representation of all the details
for a specific engine, including its configuration, metadata,
and performance metrics.
Args:
name: Engine name.
"""
console = Console()
engine = self.get_engine(name)
if not engine:
console.print(f"[red]Engine '{name}' not found[/red]")
return
# Create detailed panel
details = Table(show_header=False, box=None, padding=(0, 1))
details.add_column(style="bold cyan", width=20)
details.add_column()
# Basic info
details.add_row("Name:", engine.name)
details.add_row("Type:", engine.engine_type.value)
details.add_row("ID:", engine.id)
# Type-specific info
if hasattr(engine, "model"):
details.add_row("Model:", str(engine.model))
if hasattr(engine, "provider"):
details.add_row("Provider:", str(engine.provider))
if hasattr(engine, "temperature"):
details.add_row("Temperature:", str(engine.temperature))
if hasattr(engine, "retriever_type"):
details.add_row("Retriever Type:", str(engine.retriever_type))
if hasattr(engine, "vector_store_provider"):
details.add_row("Vector Store:", str(engine.vector_store_provider))
# Tools info
tools = self.get_engine_tools(name)
if tools:
details.add_row("Tools:", f"{len(tools)} available")
for i, tool in enumerate(tools[:3]):
tool_name = getattr(tool, "name", f"Tool {i + 1}")
details.add_row("", f" - {tool_name}")
if len(tools) > 3:
details.add_row("", f" ... and {len(tools) - 3} more")
# Routes info
routes = self.get_engine_routes(name)
if routes:
details.add_row("Routes:", f"{len(routes)} defined")
for route_name, route_target in list(routes.items())[:3]:
details.add_row("", f" {route_name} → {route_target}")
if len(routes) > 3:
details.add_row("", f" ... and {len(routes) - 3} more")
# Metadata
if name in self.engine_metadata:
meta = self.engine_metadata[name]
details.add_row("", "") # Spacer
details.add_row("[bold]Metadata[/bold]", "")
details.add_row("Added At:", meta.get("added_at", "Unknown")[:19])
details.add_row("Access Count:", str(meta.get("access_count", 0)))
details.add_row(
"Last Accessed:",
(
meta.get("last_accessed", "Never")[:19]
if meta.get("last_accessed")
else "Never"
),
)
# Performance metrics
if "performance" in meta:
perf = meta["performance"]
details.add_row("", "") # Spacer
details.add_row("[bold]Performance[/bold]", "")
details.add_row("Total Calls:", str(perf.get("total_calls", 0)))
details.add_row("Avg Time:", f"{perf.get('avg_time', 0):.3f}s")
panel = Panel(
details, title=f"[bold]Engine: {name}[/bold]", border_style="cyan"
)
console.print(panel)
# Show configuration if available
if hasattr(engine, "model_dump"):
config_data = engine.model_dump(exclude={"id", "name", "engine_type"})
if config_data:
config_json = json.dumps(config_data, indent=2)
syntax = Syntax(config_json, "json", theme="monokai", line_numbers=True)
console.print("\n[bold]Configuration:[/bold]")
console.print(syntax)
[docs]
def debug_engine_access(self, limit: int = 10) -> None:
"""Show debug information about engine access patterns.
This method displays access statistics and a recent access log
to help with debugging engine usage patterns and performance.
Args:
limit: Maximum number of access log entries to show.
"""
console = Console()
console.print("[bold cyan]Engine Access Debug Info[/bold cyan]\n")
# Access statistics table
stats_table = Table(title="Access Statistics")
stats_table.add_column("Engine", style="cyan")
stats_table.add_column("Type", style="magenta")
stats_table.add_column("Access Count", justify="right")
stats_table.add_column("Last Accessed")
stats_table.add_column("Avg Time (ms)", justify="right")
for name, metadata in sorted(
self.engine_metadata.items(),
key=lambda x: x[1].get("access_count", 0),
reverse=True,
):
engine = self.engines.get(name)
if engine:
perf = metadata.get("performance", {})
stats_table.add_row(
name,
engine.engine_type.value,
str(metadata.get("access_count", 0)),
(
metadata.get("last_accessed", "Never")[:19]
if metadata.get("last_accessed")
else "Never"
),
(
f"{perf.get('avg_time', 0) * 1000:.1f}"
if perf.get("avg_time", 0) > 0
else "-"
),
)
console.print(stats_table)
# Recent access log
if self._engine_access_log:
console.print(f"\n[bold]Recent Access Log[/bold] (last {limit} entries):")
log_table = Table()
log_table.add_column("Timestamp", style="dim")
log_table.add_column("Engine", style="cyan")
log_table.add_column("Operation")
log_table.add_column("Duration (ms)", justify="right")
for entry in self._engine_access_log[-limit:]:
log_table.add_row(
entry["timestamp"][:19],
entry["engine"],
entry["operation"],
f"{entry.get('duration_ms', 0):.1f}",
)
console.print(log_table)
[docs]
def get_engine_summary(self) -> dict[str, Any]:
"""Get a summary of all engines.
This method generates a summary of engine statistics including
counts by type, access patterns, and performance information.
Returns:
Dictionary with engine statistics.
"""
summary = {
"total_engines": len(self.engines),
"engines_by_type": {
engine_type.value: len(self.get_engines_by_type(engine_type))
for engine_type in EngineType
},
"most_accessed": None,
"least_accessed": None,
"total_accesses": 0,
}
if self.engine_metadata:
# Find most and least accessed
sorted_by_access = sorted(
self.engine_metadata.items(), key=lambda x: x[1].get("access_count", 0)
)
if sorted_by_access:
least = sorted_by_access[0]
most = sorted_by_access[-1]
summary["least_accessed"] = {
"name": least[0],
"count": least[1].get("access_count", 0),
}
summary["most_accessed"] = {
"name": most[0],
"count": most[1].get("access_count", 0),
}
# Total accesses
summary["total_accesses"] = sum(
meta.get("access_count", 0)
for meta in self.engine_metadata.values()
)
return summary