
# haive/core/graph/node/agent_node.py
"""Agent-specific node configurations for multi-agent systems.

This module provides node configurations that properly handle:
- Agent state isolation and merging
- Private state schema management
- Agent coordination through meta state
"""

import logging
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

if TYPE_CHECKING:
    from haive.agents.base.agent import Agent
else:
    # Placeholder for runtime
    Agent = Any

from langchain_core.messages import BaseMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field

from haive.core.engine.base.types import EngineType
from haive.core.graph.node.base_config import NodeConfig
from haive.core.graph.node.engine_node import EngineNodeConfig
from haive.core.graph.node.types import NodeType

logger = logging.getLogger(__name__)


class AgentNodeConfig(EngineNodeConfig):
    """Node configuration specifically for agents in multi-agent systems.

    This extends EngineNodeConfig to:
    - Properly handle an agent as the engine
    - Manage private agent state schemas
    - Coordinate through meta state
    - Handle state transformation between global and agent-specific schemas
    """

    # Override engine to be an agent
    engine: "Agent" = Field(description="The agent to execute")

    # Agent-specific fields
    private_state_schema: Optional[type[BaseModel]] = Field(
        default=None, description="Private state schema for this agent"
    )

    # State transformation
    extract_private_state: bool = Field(
        default=True,
        description="Whether to extract agent's private state before execution",
    )
    merge_agent_output: bool = Field(
        default=True, description="Whether to merge agent output back to global state"
    )

    # Meta state tracking
    update_meta_state: bool = Field(
        default=True,
        description="Whether to update meta state with agent execution info",
    )

    def __init__(self, **data) -> None:
        """Initialize with agent as engine."""
        # Ensure we have an agent
        if "agent" in data and "engine" not in data:
            data["engine"] = data.pop("agent")

        # Set node type to CALLABLE (agents are callable)
        if "node_type" not in data:
            data["node_type"] = NodeType.CALLABLE

        super().__init__(**data)

    def __call__(
        self,
        state: dict[str, Any] | BaseModel,
        config: Optional[RunnableConfig] = None,
    ) -> dict[str, Any]:
        """Execute the agent with proper state management.

        This method:
        1. Updates meta state to track agent start
        2. Extracts relevant fields for the agent's private schema
        3. Executes the agent
        4. Merges results back into the global state
        5. Updates meta state with completion info
        """
        logger.info("=" * 80)
        logger.info(f"AGENT NODE EXECUTION: {self.name}")
        logger.info("=" * 80)

        agent = self.engine
        agent_id = getattr(agent, "id", agent.name)
        agent_name = getattr(agent, "name", agent.__class__.__name__)

        logger.info("Step 1: Agent Details")
        logger.info(f"  Agent Name: {agent_name}")
        logger.info(f"  Agent ID: {agent_id}")
        logger.info(f"  Agent Type: {type(agent).__name__}")
        logger.info(
            f"  Has compiled graph: {hasattr(agent, '_app') and agent._app is not None}"
        )

        # Log incoming state
        logger.info("Step 2: Incoming State Analysis")
        logger.info(f"  State type: {type(state).__name__}")

        # Handle both dict and Pydantic model states
        if isinstance(state, dict):
            logger.info(f"  State keys: {list(state.keys())}")
            state_dict = state
        else:
            # It's a Pydantic model; extract actual messages
            logger.info("  State is Pydantic model, extracting messages")
            state_dict = state.model_dump()

            # IMPORTANT: For messages, keep the actual BaseMessage objects,
            # don't serialize them
            if hasattr(state, "messages"):
                original_messages = state.messages
                logger.info(f"INCOMING MESSAGES TYPE: {type(original_messages)}")

                # Get the actual message objects
                if hasattr(original_messages, "root"):
                    # MessageList with root attribute
                    actual_messages = original_messages.root
                    logger.info(f"  Extracted from .root: {type(actual_messages)}")
                elif isinstance(original_messages, (list, tuple)):
                    # Direct list/tuple of messages
                    actual_messages = list(original_messages)
                    logger.info(f"  Direct list/tuple: {type(actual_messages)}")
                else:
                    # Try to iterate
                    try:
                        actual_messages = list(original_messages)
                        logger.info(f"  Converted to list: {type(actual_messages)}")
                    except Exception:
                        logger.warning(
                            f"Cannot iterate over messages of type {type(original_messages)}"
                        )
                        actual_messages = []

                # Check what we actually got
                logger.info("FINAL MESSAGE TYPES:")
                for i, msg in enumerate(actual_messages):
                    logger.info(
                        f"  Message {i}: {type(msg)} (is BaseMessage: {isinstance(msg, BaseMessage)})"
                    )
                    if isinstance(msg, dict):
                        logger.warning(f"    DICT MESSAGE: {list(msg.keys())}")
                        if "tool_call_id" in msg:
                            logger.info(
                                f"    Dict has tool_call_id: {msg['tool_call_id']}"
                            )
                        else:
                            logger.warning("    Dict missing tool_call_id!")
                    elif hasattr(msg, "tool_call_id"):
                        logger.info(
                            f"    BaseMessage tool_call_id: {getattr(msg, 'tool_call_id', 'None')}"
                        )

                # Keep actual BaseMessage objects - don't serialize them!
                state_dict["messages"] = actual_messages
                logger.info(
                    f"  Extracted {len(actual_messages)} actual message objects"
                )

                # Log message types for debugging
                for i, msg in enumerate(actual_messages):
                    logger.info(f"  Message {i}: {type(msg).__name__}")
                    if isinstance(msg, ToolMessage):
                        logger.info(
                            f"    ToolMessage name: {getattr(msg, 'name', 'None')}"
                        )
                        logger.info(
                            f"    ToolMessage tool_call_id: {getattr(msg, 'tool_call_id', 'None')}"
                        )
                    elif hasattr(msg, "__dict__"):
                        logger.info(f"    Message data: {msg.__dict__}")

            logger.info(f"  State keys: {list(state_dict.keys())}")

            # Update state to be the dict for the rest of the method
            state = state_dict

        for key, value in state_dict.items():
            if isinstance(value, list) and key == "messages":
                logger.info(f"  {key}: {len(value)} messages")
            else:
                value_str = (
                    str(value)[:100] + "..." if len(str(value)) > 100 else str(value)
                )
                logger.info(f"  {key}: {type(value).__name__} = {value_str}")

        # 1. Update meta state - agent starting
        if self.update_meta_state and "meta_state" in state:
            meta_state = state.get("meta_state")
            if meta_state and hasattr(meta_state, "record_agent_start"):
                meta_state.record_agent_start(agent_id, agent_name)

        try:
            # 2. Prepare agent input using agent's own state schema
            logger.info("Step 3: Preparing Agent Input with Agent's Own State Schema")

            # Use the agent's own state schema instead of the multi-agent
            # composed schema
            if hasattr(agent, "state_schema") and agent.state_schema:
                logger.info(
                    f"  Using agent's own state schema: {agent.state_schema.__name__}"
                )

                # Create agent-specific state from multi-agent state
                agent_state_fields = {}

                # Extract fields that the agent's state schema expects
                for field_name, _field_info in agent.state_schema.model_fields.items():
                    if field_name in state:
                        agent_state_fields[field_name] = state[field_name]
                        logger.debug(
                            f"  Extracted field '{field_name}' for agent state"
                        )

                # IMPORTANT: Ensure engines are included if the agent schema
                # expects them
                if (
                    "engines" in agent.state_schema.model_fields
                    and "engines" not in agent_state_fields
                ):
                    # Use the engines from the parent state if available
                    if state.get("engines"):
                        agent_state_fields["engines"] = state["engines"]
                        logger.debug(
                            f"  Using engines from parent state: {list(state['engines'].keys())}"
                        )
                    elif hasattr(agent, "engines") and agent.engines:
                        # Otherwise use the agent's own engines
                        agent_state_fields["engines"] = agent.engines
                        logger.debug(
                            f"  Using agent's own engines: {list(agent.engines.keys())}"
                        )
                    else:
                        # Empty dict as last resort
                        agent_state_fields["engines"] = {}
                        logger.warning("  No engines found for agent state")

                # Create instance of agent's own state schema
                try:
                    agent_specific_state = agent.state_schema(**agent_state_fields)

                    # Convert to dict for agent input
                    agent_input = agent_specific_state.model_dump()

                    # IMPORTANT: Do NOT override engines with actual engine objects.
                    # Keep the serialized version from state to avoid msgpack errors;
                    # the engines in agent_input are already properly serialized.
                    if agent_input.get("engines"):
                        logger.info(
                            f"  Using serialized engines from state: {list(agent_input['engines'].keys())}"
                        )

                    # IMPORTANT: Also serialize tools to avoid msgpack errors.
                    # Tools often contain Pydantic classes in args_schema that
                    # can't be serialized.
                    if agent_input.get("tools"):
                        serialized_tools = []
                        for tool in agent_input["tools"]:
                            if tool is not None:
                                # Serialize tools to avoid Pydantic class issues
                                if hasattr(tool, "model_dump"):
                                    try:
                                        tool_dict = tool.model_dump(
                                            mode="json", exclude_none=True
                                        )
                                        # Clean up args_schema - it's usually a
                                        # Pydantic class
                                        if tool_dict.get("args_schema"):
                                            if hasattr(
                                                tool_dict["args_schema"], "__name__"
                                            ):
                                                tool_dict["args_schema"] = (
                                                    f"<PydanticModel:{tool_dict['args_schema'].__name__}>"
                                                )
                                            else:
                                                tool_dict["args_schema"] = None
                                        serialized_tools.append(tool_dict)
                                    except Exception as e:
                                        logger.warning(
                                            f"Failed to serialize tool: {e}"
                                        )
                                        # Fallback: basic tool info
                                        serialized_tools.append(
                                            {
                                                "name": getattr(
                                                    tool, "name", str(tool)
                                                ),
                                                "description": getattr(
                                                    tool, "description", ""
                                                ),
                                                "type": "tool",
                                            }
                                        )
                                else:
                                    # Tool doesn't have model_dump, create a
                                    # basic dict
                                    serialized_tools.append(
                                        {
                                            "name": getattr(tool, "name", str(tool)),
                                            "description": getattr(
                                                tool, "description", ""
                                            ),
                                            "type": "tool",
                                        }
                                    )
                        agent_input["tools"] = serialized_tools
                        logger.info(
                            f"  Serialized {len(serialized_tools)} tools for agent input"
                        )

                    logger.info(
                        "  Created agent-specific state with agent's own tools/schemas"
                    )
                except Exception as e:
                    logger.warning(f"  Could not create agent-specific state: {e}")
                    logger.warning(
                        "  Falling back to prepared input from multi-agent state"
                    )
                    agent_input = self._prepare_agent_input(state, agent)
            else:
                logger.info("  Agent has no state_schema, using prepared input")
                agent_input = self._prepare_agent_input(state, agent)

            logger.info(f"  Final agent input keys: {list(agent_input.keys())}")
            logger.info(f"  Agent input type: {type(agent_input)}")
            logger.info(f"  Is dict: {isinstance(agent_input, dict)}")
            logger.info(f"  Is BaseModel: {isinstance(agent_input, BaseModel)}")
            for key, value in agent_input.items():
                if isinstance(value, list) and key == "messages":
                    logger.info(f"  {key}: {len(value)} messages")
                elif key == "engines" and isinstance(value, dict):
                    logger.info(f"  {key}: dict with {len(value)} engines")
                    for eng_name, eng in value.items():
                        logger.info(
                            f"    - {eng_name}: {type(eng)} (is dict: {isinstance(eng, dict)})"
                        )
                        if hasattr(eng, "tools"):
                            logger.info(
                                f"      Has tools attribute: {len(getattr(eng, 'tools', []))} tools"
                            )
                            # Show first 2 tools
                            for tool in getattr(eng, "tools", [])[:2]:
                                logger.info(f"      Tool type: {type(tool)}")
                else:
                    value_str = (
                        str(value)[:100] + "..."
                        if len(str(value)) > 100
                        else str(value)
                    )
                    logger.info(f"  {key}: {type(value).__name__} = {value_str}")

            # 3. Clean agent's engine tools before execution to prevent
            # contamination
            logger.info("Step 4: Cleaning Agent Tools (preventing contamination)")
            original_tools = None
            original_tool_routes = None

            if hasattr(agent, "engine") and agent.engine:
                engine = agent.engine

                # Backup original tools
                if hasattr(engine, "tools"):
                    original_tools = engine.tools.copy() if engine.tools else []
                    logger.info(
                        f"  Original tools: {[getattr(t, 'name', str(t)) for t in original_tools]}"
                    )
                if hasattr(engine, "tool_routes"):
                    original_tool_routes = (
                        engine.tool_routes.copy() if engine.tool_routes else {}
                    )

                # Filter tools to only include legitimate ones (not Pydantic
                # models)
                if hasattr(engine, "tools") and engine.tools:
                    clean_tools = []
                    clean_routes = {}
                    for tool in engine.tools:
                        tool_name = getattr(
                            tool, "name", getattr(tool, "__name__", str(tool))
                        )

                        # Skip Pydantic model classes that shouldn't be in tools
                        if hasattr(tool, "__bases__") and any(
                            "BaseModel" in str(base) for base in tool.__bases__
                        ):
                            logger.info(
                                f"  Filtering OUT Pydantic model from engine tools: {tool_name}"
                            )
                            continue

                        # Skip non-callable items
                        if isinstance(tool, type) and hasattr(tool, "model_fields"):
                            logger.info(
                                f"  Filtering OUT Pydantic model class from engine tools: {tool_name}"
                            )
                            continue

                        # Keep legitimate tools
                        if callable(tool) or hasattr(tool, "invoke"):
                            clean_tools.append(tool)
                            if (
                                original_tool_routes
                                and tool_name in original_tool_routes
                            ):
                                clean_routes[tool_name] = original_tool_routes[
                                    tool_name
                                ]
                            logger.info(f"  Keeping legitimate tool: {tool_name}")

                    # Apply cleaned tools to engine
                    engine.tools = clean_tools
                    if hasattr(engine, "tool_routes"):
                        engine.tool_routes = clean_routes
                    logger.info(
                        f"  Cleaned tools: {[getattr(t, 'name', str(t)) for t in clean_tools]}"
                    )
                    logger.info(f"  Cleaned routes: {list(clean_routes.keys())}")

            # 4. Execute agent with clean tools
            logger.info("Step 5: Executing Agent")
            logger.info(
                f"  Method: {'compiled graph' if hasattr(agent, '_app') and agent._app else 'invoke method'}"
            )

            try:
                # Check if agent has a compiled graph
                if hasattr(agent, "_app") and agent._app:
                    # Use compiled graph
                    logger.info("  Using agent's compiled graph (_app)")
                    result = agent._app.invoke(agent_input, config)
                else:
                    # Use agent's invoke method
                    logger.info("  Using agent's invoke method")
                    logger.info(
                        f"  About to invoke agent with input type: {type(agent_input)}"
                    )
                    logger.info(f"  Input is dict: {isinstance(agent_input, dict)}")
                    logger.info(
                        f"  Input is BaseModel: {isinstance(agent_input, BaseModel)}"
                    )
                    if isinstance(agent_input, dict) and "engines" in agent_input:
                        logger.info(
                            f"  Engines in input: {list(agent_input['engines'].keys())}"
                        )
                        for eng_name, eng in agent_input["engines"].items():
                            logger.info(f"    Engine {eng_name} type: {type(eng)}")
                    result = agent.invoke(agent_input, config)
            finally:
                # Restore original tools after execution
                if (
                    original_tools is not None
                    and hasattr(agent, "engine")
                    and agent.engine
                ):
                    agent.engine.tools = original_tools
                    if original_tool_routes is not None and hasattr(
                        agent.engine, "tool_routes"
                    ):
                        agent.engine.tool_routes = original_tool_routes
                    logger.debug("  Restored original tools after execution")

            logger.info("Step 6: Agent Result Analysis")
            logger.info(f"  Result type: {type(result).__name__}")
            if isinstance(result, dict):
                logger.info(f"  Result keys: {list(result.keys())}")
                for key, value in result.items():
                    if isinstance(value, list) and key in [
                        "messages",
                        "retrieved_documents",
                    ]:
                        logger.info(f"  {key}: {len(value)} items")
                    else:
                        value_str = (
                            str(value)[:100] + "..."
                            if len(str(value)) > 100
                            else str(value)
                        )
                        logger.info(f"  {key}: {type(value).__name__}")

            # 5. Process agent output
            logger.info("Step 7: Processing Agent Output")
            state_update = self._process_agent_output(result, state, agent)
            logger.info(f"  State update keys: {list(state_update.keys())}")

            # 6. Update meta state - agent completed
            if self.update_meta_state and "meta_state" in state:
                meta_state = state.get("meta_state")
                if meta_state and hasattr(meta_state, "record_agent_completion"):
                    meta_state.record_agent_completion(agent_id, result)

            logger.info(f"✅ AGENT NODE COMPLETED: {self.name}")
            return state_update

        except Exception as e:
            logger.exception(f"❌ Error executing agent {agent_name}: {e}")
            logger.exception(f"Error type: {type(e).__name__}")
            import traceback

            logger.exception(f"Traceback:\n{traceback.format_exc()}")

            # Update meta state - agent error
            if self.update_meta_state and "meta_state" in state:
                meta_state = state.get("meta_state")
                if meta_state and hasattr(meta_state, "record_agent_error"):
                    meta_state.record_agent_error(agent_id, str(e))

            raise

    def _prepare_agent_input(
        self, state: dict[str, Any], agent: "Agent"
    ) -> dict[str, Any]:
        """Prepare input for agent execution.

        If the agent has a private state schema, extract only relevant fields.
        Otherwise, pass appropriate fields based on the agent's input schema.
        """
        logger.debug("=== _prepare_agent_input called ===")
        logger.debug(f"  Agent: {agent.name}")
        logger.debug(f"  Private state schema: {self.private_state_schema}")
        logger.debug(f"  Extract private state: {self.extract_private_state}")
        logger.debug(f"  Has input_schema: {hasattr(agent, 'input_schema')}")
        if hasattr(agent, "input_schema"):
            logger.debug(f"  Input schema type: {type(agent.input_schema)}")
            logger.debug(f"  Input schema: {agent.input_schema}")

        # If we have a private state schema, use it
        if self.private_state_schema and self.extract_private_state:
            logger.debug(f"Using private state extraction for {agent.name}")

            # Create instance of private schema from global state
            relevant_fields = {}
            for field_name in self.private_state_schema.model_fields:
                if field_name in state:
                    relevant_fields[field_name] = state[field_name]

            # Always include messages if available
            if "messages" in state and "messages" not in relevant_fields:
                relevant_fields["messages"] = state["messages"]

            return relevant_fields

        # Otherwise, use agent's input schema or heuristics
        if hasattr(agent, "input_schema") and agent.input_schema:
            logger.debug("Using agent's input_schema")
            logger.debug(
                f"  Input schema fields: {list(agent.input_schema.model_fields.keys())}"
            )

            # Extract fields based on input schema
            input_fields = {}
            for field_name in agent.input_schema.model_fields:
                if field_name in state:
                    input_fields[field_name] = state[field_name]
                    logger.debug(f"  Added field '{field_name}' from state")
                else:
                    logger.debug(f"  Field '{field_name}' not found in state")

            logger.debug(f"  Final input fields: {list(input_fields.keys())}")
            return input_fields

        # Default: pass messages and any fields the agent expects
        default_input = {}

        # Always include messages for agents
        if "messages" in state:
            default_input["messages"] = state["messages"]

        # Include common fields agents might need
        common_fields = ["query", "question", "input", "context", "documents"]
        for field in common_fields:
            if field in state:
                default_input[field] = state[field]

        # If agent has a get_input_fields method, use it
        if hasattr(agent, "get_input_fields"):
            try:
                expected_fields = agent.get_input_fields()
                for field_name in expected_fields:
                    if field_name in state:
                        default_input[field_name] = state[field_name]
            except Exception as e:
                logger.debug(f"Could not get input fields from agent: {e}")

        logger.debug(f"  Final default input: {list(default_input.keys())}")
        return default_input if default_input else state

    def _process_agent_output(
        self, result: Any, state: dict[str, Any], agent: "Agent"
    ) -> dict[str, Any]:
        """Process agent output and merge with global state.

        Handles various output formats and ensures proper state updates.
        """
        # Start with empty update
        state_update = {}

        # Handle different result types
        if isinstance(result, dict):
            # Direct dictionary result
            state_update = result
        elif isinstance(result, BaseModel):
            # Pydantic model result - preserve actual message objects
            state_update = result.model_dump()

            # CRITICAL: If the result has messages, preserve the actual
            # BaseMessage objects
            if hasattr(result, "messages") and result.messages:
                logger.info(
                    f"Preserving {len(result.messages)} actual message objects from agent result"
                )
                # Debug: check what types the messages actually are
                for i, msg in enumerate(result.messages):
                    logger.info(f"  Result message {i}: {type(msg).__name__}")
                    if isinstance(msg, ToolMessage):
                        logger.info(
                            f"    ToolMessage tool_call_id: {getattr(msg, 'tool_call_id', 'None')}"
                        )
                    elif isinstance(msg, dict):
                        logger.warning(f"  Message is dict, not BaseMessage: {msg}")

                # Keep the actual BaseMessage objects instead of serialized dicts
                state_update["messages"] = result.messages
                logger.info(
                    "STATE UPDATE: Setting messages to actual BaseMessage objects"
                )
                for i, msg in enumerate(result.messages):
                    if hasattr(msg, "tool_call_id"):
                        logger.info(
                            f"  Storing ToolMessage {i} with tool_call_id={getattr(msg, 'tool_call_id', 'None')}"
                        )
        elif isinstance(result, str):
            # String result - check if agent outputs to a specific field
            if hasattr(agent, "output_field_name"):
                state_update[agent.output_field_name] = result
            else:
                # Default to agent_output field
                state_update["agent_output"] = result
        elif result is None:
            # No output
            logger.debug(f"Agent {agent.name} returned None")
        else:
            # Unknown type - store as agent_output
            logger.warning(f"Unknown result type from agent: {type(result)}")
            state_update["agent_output"] = result

        # Store agent-specific output in meta state
        if self.update_meta_state and "meta_state" in state:
            agent_id = getattr(agent, "id", agent.name)
            # Store in agent_outputs field
            if "agent_outputs" not in state_update:
                state_update["agent_outputs"] = state.get("agent_outputs", {})
            state_update["agent_outputs"][agent_id] = result

        # Ensure messages are preserved/updated correctly
        if "messages" in state_update and "messages" in state:
            # If both have messages, we need to handle this carefully
            existing_messages = state.get("messages", [])
            new_messages = state_update.get("messages", [])

            # Simple approach: if new messages look like a complete replacement,
            # use them as-is; otherwise, assume we should append.
            if isinstance(new_messages, list) and isinstance(existing_messages, list):
                if len(new_messages) >= len(existing_messages):
                    # Looks like a complete replacement - use new messages as-is
                    state_update["messages"] = new_messages
                else:
                    # Looks like just new messages to append
                    state_update["messages"] = existing_messages + new_messages

        return state_update
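
# Usage sketch (illustrative, not from this module): wrapping an agent in an
# AgentNodeConfig and invoking it against a plain dict state. `research_agent`
# and the state fields shown are hypothetical placeholders.
#
#     node = AgentNodeConfig(
#         name="researcher",
#         agent=research_agent,  # hypothetical Agent; forwarded to `engine` in __init__
#         update_meta_state=True,
#     )
#     state_update = node({"messages": [], "query": "..."})
#     # `state_update` is a dict of fields to merge back into the global state,
#     # with messages kept as BaseMessage objects rather than serialized dicts.
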
class CoordinatorNodeConfig(NodeConfig):
    """Coordinator node for parallel agent execution.

    Handles fan-out and aggregation of parallel agent execution.
    """

    node_type: NodeType = Field(
        default=NodeType.CALLABLE, description="Coordinator is a callable node"
    )
    agents: list[Agent] = Field(description="Agents to coordinate")
    mode: Literal["fanout", "aggregate"] = Field(description="Coordination mode")

    def __call__(
        self, state: dict[str, Any], config: Optional[RunnableConfig] = None
    ) -> dict[str, Any] | list[dict[str, Any]]:
        """Execute coordination logic.

        For fanout: prepares the state for parallel agent execution (currently
        returns the state unchanged; graph edges handle the routing).
        For aggregate: combines results from all agents via ``agent_outputs``.
        """
        if self.mode == "fanout":
            # Create individual states for each agent
            logger.info("Fanning out to parallel agents")

            # Mark agents as ready for parallel execution
            if "meta_state" in state:
                meta_state = state.get("meta_state")
                if meta_state and hasattr(meta_state, "update_workflow_stage"):
                    meta_state.update_workflow_stage("parallel_execution")

            # For now, just return the state - the graph edges handle routing.
            # In the future, we can use Send commands with proper annotations.
            return state

        if self.mode == "aggregate":
            # Aggregate results from parallel execution
            logger.info("Aggregating results from parallel agents")

            # The state should have agent_outputs populated
            if "agent_outputs" in state:
                logger.debug(
                    f"Found outputs from {len(state['agent_outputs'])} agents"
                )

            # Update workflow stage
            if "meta_state" in state:
                meta_state = state.get("meta_state")
                if meta_state and hasattr(meta_state, "update_workflow_stage"):
                    meta_state.update_workflow_stage("aggregation_complete")

            return state

        raise ValueError(f"Unknown coordination mode: {self.mode}")
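
# Usage sketch (illustrative): a fan-out/aggregate pair around parallel agent
# nodes. The graph wiring shown is assumed, not defined in this module.
#
#     fanout = CoordinatorNodeConfig(name="fanout", agents=[a1, a2], mode="fanout")
#     gather = CoordinatorNodeConfig(name="gather", agents=[a1, a2], mode="aggregate")
#     # Wire fanout -> each agent node -> gather in the graph; each node's
#     # __call__ receives the shared state, and the aggregate node reads the
#     # `agent_outputs` field populated by the agent nodes.
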
# Update engine_node.py to route agents to AgentNodeConfig
def create_node_for_engine(
    engine: Union["Agent", Any], name: str, **kwargs
) -> Union[AgentNodeConfig, EngineNodeConfig]:
    """Factory function to create the appropriate node config for an engine/agent.

    Routes agents to AgentNodeConfig, everything else to EngineNodeConfig.
    """
    # Check if it's an agent. `Agent` is only imported under TYPE_CHECKING and
    # is `Any` at runtime, so import the real class lazily for the isinstance
    # check (isinstance(x, typing.Any) raises TypeError).
    try:
        from haive.agents.base.agent import Agent as _Agent

        is_agent = isinstance(engine, _Agent)
    except ImportError:
        is_agent = False

    if is_agent or (
        hasattr(engine, "engine_type") and engine.engine_type == EngineType.AGENT
    ):
        return AgentNodeConfig(name=name, engine=engine, **kwargs)

    return EngineNodeConfig(name=name, engine=engine, **kwargs)
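
# Usage sketch (illustrative): the factory picks the node class from the engine
# type, so callers don't branch on agent vs. plain engine themselves.
# `research_agent` and `retriever_engine` are hypothetical objects.
#
#     node = create_node_for_engine(research_agent, name="researcher")
#     assert isinstance(node, AgentNodeConfig)
#     node = create_node_for_engine(retriever_engine, name="retriever")
#     assert isinstance(node, EngineNodeConfig)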