# src/haive/core/graph/node/engine_node.py
import logging
from typing import Any
from langgraph.types import Command, RetryPolicy, Send
from pydantic import Field
from haive.core.engine.base import Engine
from haive.core.engine.base.types import EngineType
from haive.core.graph.common.types import ConfigLike, StateLike
from haive.core.graph.node.base_config import NodeConfig
from haive.core.graph.node.types import NodeType
# Get module logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
[docs]
class EngineNodeConfig(NodeConfig):
"""Engine-based node with intelligent I/O handling and schema support.
This node config extends the base NodeConfig with engine-specific functionality
while maintaining the new input/output schema pattern for better state utilization.
"""
# Core identity
node_type: NodeType = Field(default=NodeType.ENGINE)
engine: Engine | None = Field(default=None)
# Legacy field mappings (backwards compatibility)
input_fields: list[str] | dict[str, str] | None = Field(default=None)
output_fields: list[str] | dict[str, str] | None = Field(default=None)
# Options
retry_policy: RetryPolicy | None = Field(default=None)
use_send: bool = Field(default=False)
debug: bool = Field(default=True)
[docs]
def model_post_init(self, __context) -> None:
"""Post-initialization to setup engine-specific configurations."""
# Set engine_name from engine if not provided
if self.engine and not self.engine_name:
self.engine_name = self.engine.name
# Setup default field definitions based on engine type if not provided
if self.engine and not self.input_field_defs and not self.output_field_defs:
self._setup_default_field_defs_from_engine()
# Call parent post_init to handle schema setup
super().model_post_init(__context)
def _setup_default_field_defs_from_engine(self):
"""Setup default field definitions based on engine type."""
if not self.engine:
return
from haive.core.schema.field_registry import StandardFields
# Set defaults based on engine type
if self.engine.engine_type == EngineType.LLM:
if not self.input_field_defs:
# Use the engine's derived input fields instead of hardcoding
if hasattr(self.engine, "get_input_fields"):
engine_input_fields = self.engine.get_input_fields()
logger.debug(f"Engine derived input fields: {list(engine_input_fields.keys())}")
# Convert engine fields to field definitions using
# StandardFields when possible
self.input_field_defs = []
for field_name, (
type_hint,
_field_info,
) in engine_input_fields.items():
# Try to use StandardFields for known field types
try:
if field_name == "messages":
field_def = StandardFields.messages(use_enhanced=True)
elif field_name == "query":
field_def = StandardFields.query()
elif field_name == "context":
field_def = StandardFields.context()
else:
# For other fields, try StandardFields first,
# then fallback
field_method = getattr(StandardFields, field_name, None)
if field_method and callable(field_method):
field_def = field_method()
else:
# Create generic field definition
from haive.core.schema.field_definition import (
FieldDefinition,
)
field_def = FieldDefinition(
name=field_name, type_hint=type_hint
)
self.input_field_defs.append(field_def)
except Exception as e:
logger.debug(f"Failed to create field definition for {field_name}: {e}")
# Fallback to generic field definition
from haive.core.schema.field_definition import (
FieldDefinition,
)
field_def = FieldDefinition(name=field_name, type_hint=type_hint)
self.input_field_defs.append(field_def)
else:
# Fallback to messages only
self.input_field_defs = [StandardFields.messages(use_enhanced=True)]
if not self.output_field_defs:
# LLM engines should ONLY output to messages field
# V2 structured output: Tool calls in AIMessage are extracted by downstream validation nodes
# V1 regular output: AI response appended to messages
self.output_field_defs = [
# ONLY messages field
StandardFields.messages(use_enhanced=True),
]
elif self.engine.engine_type == EngineType.RETRIEVER:
if not self.input_field_defs:
self.input_field_defs = [
StandardFields.query(),
StandardFields.messages(use_enhanced=True),
]
if not self.output_field_defs:
# Use the actual output schema from the retriever engine
# The retriever output schema has 'retrieved_documents' field
self.output_field_defs = [] # Let it use the engine's actual output schema
# Note: For LLM engines with structured output, we do NOT add structured fields to node output
# V2 structured output uses tool calls embedded in AIMessage, extracted by downstream validation nodes
# Only add structured output fields for non-LLM engines that need
# separate structured fields
if (
hasattr(self.engine, "structured_output_model")
and self.engine.structured_output_model
and self.engine.engine_type != EngineType.LLM # Skip for LLM engines
):
structured_field = StandardFields.structured_output(self.engine.structured_output_model)
if structured_field not in self.output_field_defs:
self.output_field_defs.append(structured_field)
[docs]
def __call__(self, state: StateLike, config: ConfigLike | None = None) -> Command | Send:
"""Execute engine node with schema-aware I/O handling."""
logger.info("=" * 80)
logger.info(f"ENGINE NODE EXECUTION: {self.name}")
logger.info("=" * 80)
logger.debug(f"Starting execution of node {self.name}")
try:
# Get engine and validate
logger.info("Step 1: Getting Engine")
engine = self._get_engine(state)
if not engine:
logger.error(f"No engine available for node '{self.name}'")
logger.error(f" engine_name: {self.engine_name}")
logger.error(f" direct engine: {self.engine}")
raise ValueError(f"No engine available for node '{self.name}'")
logger.info(f"✅ Got engine: {engine.name} (type: {engine.engine_type.value})")
# Extract input using schema-aware method
logger.info("Step 2: Extracting Input")
logger.debug(f"Node input_schema: {self.input_schema}")
logger.debug(f"Node input_field_defs: {self.input_field_defs}")
logger.debug(
f"State fields available: {
[field for field in dir(state) if not field.startswith('_')]
}"
)
if self.input_schema or self.input_field_defs:
input_data = self.extract_input_from_state(state)
logger.info(
f"Using schema-based input extraction: {
list(input_data.keys())
if isinstance(input_data, dict)
else type(input_data)
}"
)
logger.debug(f"Extracted input_data: {input_data}")
else:
input_data = self._extract_smart_input(state, engine)
logger.info("Using legacy smart input extraction")
logger.debug(f"Smart extracted input_data: {input_data}")
logger.debug(f"Input data type: {type(input_data).__name__}")
if isinstance(input_data, dict):
logger.debug(f"Input keys: {list(input_data.keys())}")
for key, value in input_data.items():
logger.debug(f" {key}: {type(value).__name__} = {str(value)[:100]}...")
else:
logger.debug(f"Input value: {str(input_data)[:200]}...")
# Execute with merged config
logger.info("Step 3: Executing Engine")
result = self._execute_with_config(engine, input_data, config)
# Log result details
logger.debug(f"Result type: {type(result).__name__}")
self._log_result_details(result)
# Wrap result using schema-aware method
logger.info("Step 4: Creating Update")
if self.output_schema or self.output_field_defs:
wrapped = self.create_output_for_state(result)
logger.info(
f"Using schema-based output creation: {
list(wrapped.keys()) if isinstance(wrapped, dict) else type(wrapped)
}"
)
else:
wrapped = self._wrap_smart_result(result, state, engine)
logger.info("Using legacy smart result wrapping")
# Log final update
self._log_final_update(wrapped)
logger.info(f"✅ ENGINE NODE COMPLETED: {self.name}")
return wrapped
except Exception as e:
self._log_error(e)
raise
def _get_engine(self, state: StateLike | None = None) -> Engine | None:
"""Get engine from direct reference or state's engines dict."""
logger.debug("Getting engine...")
# Priority 1: Direct engine reference
if self.engine:
logger.debug(f"Using direct engine reference: {self.engine.name}")
return self.engine
# Priority 2: Get from state's engines dict using engine_name
if self.engine_name and state:
logger.debug(f"Looking for engine_name: {self.engine_name}")
# Try to get from engines dict in state
if hasattr(state, "engines"):
engines_dict = getattr(state, "engines", {})
logger.debug(f"Found state.engines dict with {len(engines_dict)} engines")
if isinstance(engines_dict, dict):
# Log available engines
if engines_dict:
logger.debug("Available engines in state:")
for name, eng in engines_dict.items():
logger.debug(f" - {name}: {type(eng).__name__}")
if self.engine_name in engines_dict:
engine = engines_dict[self.engine_name]
if engine:
logger.info(f"✅ Found engine '{self.engine_name}' in state.engines!")
self.engine = engine # Cache it
return engine
logger.error(f"Engine '{self.engine_name}' exists but is None")
else:
logger.error(f"Engine '{self.engine_name}' not found in state.engines")
logger.error(f"Available engines: {list(engines_dict.keys())}")
else:
logger.error(f"state.engines is not a dict: {type(engines_dict)}")
else:
logger.debug("State has no 'engines' attribute")
# Also check if state is a dict with engines key
if isinstance(state, dict) and "engines" in state:
logger.debug("State is a dict, checking state['engines']...")
engines_dict = state["engines"]
if isinstance(engines_dict, dict) and self.engine_name in engines_dict:
engine = engines_dict[self.engine_name]
if engine:
logger.info(
f"✅ Found engine '{self.engine_name}' in state['engines']!"
)
self.engine = engine # Cache it
return engine
logger.error("No engine found!")
return None
def _log_result_details(self, result: Any):
"""Log details about the result."""
try:
from langchain_core.messages import AIMessage, BaseMessage
if isinstance(result, BaseMessage):
logger.info(f"✅ Result is a {type(result).__name__}")
logger.debug(f" Content: {result.content[:200]}...")
if (
isinstance(result, AIMessage)
and hasattr(result, "tool_calls")
and result.tool_calls
):
logger.debug(f" Tool Calls: {len(result.tool_calls)}")
except ImportError:
pass
if isinstance(result, dict):
logger.debug("Result is a dictionary:")
for key, value in result.items():
logger.debug(f" {key}: {type(value).__name__}")
elif isinstance(result, str):
logger.debug(f"Result is string: {result[:200]}...")
else:
logger.debug(f"Result: {str(result)[:200]}...")
def _log_final_update(self, wrapped: Command | Send):
"""Log the final wrapped update."""
logger.info("Final Update:")
if isinstance(wrapped, Command):
logger.info(" Type: Command")
logger.info(f" Goto: {wrapped.goto}")
if wrapped.update:
logger.debug(" Update dict:")
for key, value in wrapped.update.items():
if key == "messages":
if isinstance(value, list):
logger.debug(f" {key}: List with {len(value)} messages")
if value:
last_msg = value[-1]
logger.debug(f" Last message type: {type(last_msg).__name__}")
else:
logger.debug(f" {key}: {type(value).__name__}")
else:
logger.debug(f" {key}: {str(value)[:100]}...")
elif isinstance(wrapped, Send):
logger.info(" Type: Send")
logger.info(f" Node: {wrapped.node}")
logger.debug(f" Arg type: {type(wrapped.arg).__name__}")
def _extract_smart_input(self, state: StateLike, engine: Engine) -> Any:
"""Extract input using the most appropriate strategy."""
logger.debug(f"Extracting input for {engine.engine_type.value} engine...")
# Strategy 1: Explicit mapping
if self.input_fields:
logger.debug("Using explicit input field mapping")
return self._extract_mapped_input(state, self._normalize_mapping(self.input_fields))
# Strategy 2: Schema-defined inputs
schema_inputs = self._get_schema_inputs(state, engine.name)
if schema_inputs:
logger.debug(f"Using schema-defined inputs: {schema_inputs}")
return self._extract_typed_input(state, schema_inputs, engine.engine_type)
# Strategy 3: Engine-defined inputs
engine_inputs = self._get_engine_inputs(engine)
if engine_inputs:
logger.debug(f"Using engine-defined inputs: {engine_inputs}")
return self._extract_typed_input(state, engine_inputs, engine.engine_type)
# Strategy 4: Type-based defaults
logger.debug("Using type-based default extraction")
return self._extract_default_input(state, engine.engine_type)
def _extract_typed_input(
self, state: StateLike, fields: list[str], engine_type: EngineType
) -> dict[str, Any]:
"""Extract fields with type-specific intelligence."""
logger.debug(f"Extracting typed input for {engine_type.value}")
extractors = {
EngineType.RETRIEVER: self._extract_retriever_fields,
EngineType.LLM: self._extract_llm_fields,
EngineType.VECTOR_STORE: self._extract_vectorstore_fields,
EngineType.EMBEDDINGS: self._extract_embeddings_fields,
EngineType.AGENT: self._extract_agent_fields,
}
extractor = extractors.get(engine_type, self._extract_generic_fields)
return extractor(state, fields)
def _extract_retriever_fields(self, state: StateLike, fields: list[str]) -> dict[str, Any]:
"""Retriever-specific extraction: always include query, filter None values."""
logger.debug("Extracting retriever fields")
input_data = {}
for field in fields:
value = self._get_state_value(state, field)
if field == "query":
# Always include query, even if empty
input_data[field] = value or ""
logger.debug(f"Retriever query: '{value or ''}'")
elif value is not None:
# Only include other fields if they have values
input_data[field] = value
logger.debug(f"Retriever {field}: {value}")
else:
logger.debug(f"Skipping None value for {field}")
return input_data
def _extract_llm_fields(self, state: StateLike, fields: list[str]) -> dict[str, Any]:
"""LLM-specific extraction: include all fields."""
logger.debug("Extracting LLM fields")
result = {field: self._get_state_value(state, field) for field in fields}
logger.debug(f"LLM input fields: {list(result.keys())}")
return result
def _extract_vectorstore_fields(self, state: StateLike, fields: list[str]) -> dict[str, Any]:
"""Vector store extraction: filter None values except query."""
logger.debug("Extracting vector store fields")
input_data = {}
for field in fields:
value = self._get_state_value(state, field)
if field == "query" or value is not None:
input_data[field] = value
return input_data
def _extract_embeddings_fields(self, state: StateLike, fields: list[str]) -> Any:
"""Embeddings extraction: often just needs text."""
logger.debug("Extracting embeddings fields")
# Try to get text/query field first
for field in ["query", "text", "content"]:
if field in fields:
value = self._get_state_value(state, field)
if value:
logger.debug(f"Using {field} field for embeddings: {str(value)[:100]}...")
return value
# Fall back to all fields as dict
return {field: self._get_state_value(state, field) for field in fields}
def _extract_agent_fields(self, state: StateLike, fields: list[str]) -> dict[str, Any]:
"""Agent-specific extraction: include all fields, prioritize messages."""
logger.debug("Extracting agent fields")
result = {}
# Always include messages if it's in the fields
if "messages" in fields:
messages = self._get_state_value(state, "messages", [])
result["messages"] = messages
logger.debug(f"Agent messages: {len(messages) if messages else 0} messages")
# Include all other fields
for field in fields:
if field != "messages": # Already handled
value = self._get_state_value(state, field)
if value is not None:
result[field] = value
logger.debug(f"Agent {field}: {type(value).__name__}")
return result
def _extract_generic_fields(self, state: StateLike, fields: list[str]) -> dict[str, Any]:
"""Generic extraction: include non-None values."""
logger.debug("Extracting generic fields")
return {
field: value
for field in fields
if (value := self._get_state_value(state, field)) is not None
}
def _extract_default_input(self, state: StateLike, engine_type: EngineType) -> Any:
"""Default extraction when no fields are specified."""
logger.debug(f"Using default extraction for {engine_type.value}")
defaults = {
EngineType.RETRIEVER: lambda: self._extract_retriever_fields(
state, ["query", "k", "filter", "search_type", "score_threshold"]
),
EngineType.LLM: lambda: {"messages": self._get_state_value(state, "messages", [])},
EngineType.VECTOR_STORE: lambda: self._extract_vectorstore_fields(
state, ["query", "k", "filter"]
),
EngineType.EMBEDDINGS: lambda: self._get_state_value(state, "query", ""),
EngineType.AGENT: lambda: self._state_as_dict(
state
), # Agents typically need full state
}
extractor = defaults.get(engine_type, lambda: self._state_as_dict(state))
return extractor()
def _wrap_smart_result(self, result: Any, state: StateLike, engine: Engine) -> Command | Send:
"""Intelligently wrap result based on type and configuration."""
logger.debug("Wrapping result...")
# Already wrapped? Return as-is
if isinstance(result, Command | Send):
logger.debug("Result already wrapped, returning as-is")
return result
# Generate update dictionary
update = self._create_update_dict(result, state, engine)
# Return appropriate wrapper
if self.use_send and self.command_goto:
logger.debug(f"Creating Send to {self.command_goto}")
return Send(node=self.command_goto, arg=update)
logger.debug(f"Creating Command with goto={self.command_goto}")
return Command(update=update, goto=self.command_goto)
def _create_update_dict(self, result: Any, state: StateLike, engine: Engine) -> dict[str, Any]:
"""Create state update dictionary from result."""
logger.debug("Creating update dictionary...")
# Strategy 1: Explicit output mapping
if self.output_fields:
logger.debug("Using explicit output field mapping")
return self._apply_output_mapping(result)
# Strategy 2: Schema-defined outputs
schema_outputs = self._get_schema_outputs(state, engine.name)
if schema_outputs:
logger.debug(f"Using schema-defined outputs: {schema_outputs}")
return self._map_to_outputs(result, schema_outputs)
# Strategy 3: Smart type-based mapping
logger.debug("Using smart type-based result mapping")
return self._smart_result_mapping(result, state, engine.engine_type)
def _smart_result_mapping(
self, result: Any, state: StateLike, engine_type: EngineType
) -> dict[str, Any]:
"""Smart result mapping based on result type and engine type."""
logger.debug(f"Smart result mapping for {engine_type.value} engine...")
# Check if it's a message first
if self._is_message_like(result):
logger.info("✅ Result is message-like, updating messages")
return self._update_messages(result, state)
# Check if it's a string that might be a response from LLM
if isinstance(result, str) and engine_type == EngineType.LLM:
logger.info("LLM returned string, converting to AIMessage and updating messages")
try:
from langchain_core.messages import AIMessage
ai_msg = AIMessage(content=result)
return self._update_messages(ai_msg, state)
except ImportError:
logger.exception("Could not import AIMessage")
# Agent results are typically full state updates
if engine_type == EngineType.AGENT and isinstance(result, dict):
logger.info("Agent returned dict state update")
return result
# Dictionary results
if isinstance(result, dict):
logger.debug("Result is dict, returning as-is")
if "messages" in result:
logger.debug("Dict contains 'messages' key")
return result
# Engine-specific single value mapping
field_map = {
EngineType.RETRIEVER: "documents",
EngineType.LLM: "response",
EngineType.EMBEDDINGS: "embeddings",
EngineType.VECTOR_STORE: "documents",
EngineType.AGENT: "agent_output", # Generic field for agent outputs
}
field = field_map.get(engine_type, "result")
logger.debug(f"Mapping result to field: {field}")
return {field: result}
def _update_messages(self, result: Any, state: StateLike) -> dict[str, Any]:
"""Update messages list with new message(s)."""
logger.info("Updating messages list...")
# Get existing messages
existing = self._get_state_value(state, "messages", [])
logger.debug(f"Existing messages: {len(existing) if existing else 0}")
# Create new list
messages = list(existing) if existing else []
# Add new messages with engine attribution
if isinstance(result, list):
# Process each message to add engine attribution
processed_messages = []
for msg in result:
processed_msg = self._add_engine_attribution_to_message(msg)
processed_messages.append(processed_msg)
messages.extend(processed_messages)
logger.debug(f"Added {len(result)} messages with engine attribution")
else:
# Process single message to add engine attribution
processed_msg = self._add_engine_attribution_to_message(result)
messages.append(processed_msg)
logger.debug(f"Added 1 message: {type(result).__name__} with engine attribution")
logger.info(f"✅ Total messages after update: {len(messages)}")
return {"messages": messages}
def _add_engine_attribution_to_message(self, message: Any) -> Any:
"""Add engine attribution to a message if it's an AI message."""
try:
from langchain_core.messages import AIMessage
# Only add attribution to AI messages
if isinstance(message, AIMessage) and self.engine:
logger.debug(f"Adding engine attribution '{self.engine.name}' to AIMessage")
# Get existing additional_kwargs or create new dict
additional_kwargs = getattr(message, "additional_kwargs", {}).copy()
# Add engine attribution
additional_kwargs["engine_name"] = self.engine.name
# Create new AIMessage with attribution
attributed_message = AIMessage(
content=message.content,
additional_kwargs=additional_kwargs,
tool_calls=getattr(message, "tool_calls", None),
id=getattr(message, "id", None),
)
logger.debug(f"✅ Added engine attribution: {self.engine.name}")
return attributed_message
except ImportError:
logger.debug("Could not import AIMessage for attribution")
except Exception as e:
logger.debug(f"Failed to add engine attribution: {e}")
# Return original message if attribution failed or not applicable
return message
def _is_message_like(self, obj: Any) -> bool:
"""Check if object is message-like."""
try:
from langchain_core.messages import BaseMessage
is_msg = isinstance(obj, BaseMessage)
if is_msg:
logger.debug(f"✅ Object is a BaseMessage: {type(obj).__name__}")
return is_msg
except ImportError:
logger.debug("Could not import BaseMessage, checking attributes")
has_attrs = hasattr(obj, "content") and hasattr(obj, "type")
if has_attrs:
logger.debug("Object has message-like attributes")
return has_attrs
# ... rest of utility methods remain the same ...
def _get_state_value(self, state: StateLike, key: str, default: Any = None) -> Any:
"""Get value from state with fallback."""
if hasattr(state, key):
return getattr(state, key)
if isinstance(state, dict):
return state.get(key, default)
return default
def _get_schema_inputs(self, state: StateLike, engine_name: str) -> list[str] | None:
"""Get engine inputs from schema."""
if not hasattr(state, "__engine_io_mappings__") or not engine_name:
return None
return getattr(state, "__engine_io_mappings__", {}).get(engine_name, {}).get("inputs")
def _get_schema_outputs(self, state: StateLike, engine_name: str) -> list[str] | None:
"""Get engine outputs from schema."""
if not hasattr(state, "__engine_io_mappings__") or not engine_name:
return None
return getattr(state, "__engine_io_mappings__", {}).get(engine_name, {}).get("outputs")
def _get_engine_inputs(self, engine: Engine) -> list[str] | None:
"""Get input fields from engine definition."""
if hasattr(engine, "get_input_fields"):
return list(engine.get_input_fields().keys())
return None
def _extract_mapped_input(self, state: StateLike, mapping: dict[str, str]) -> dict[str, Any]:
"""Extract using explicit state->input mapping."""
logger.debug(f"Extracting with mapping: {mapping}")
return {
input_key: self._get_state_value(state, state_key)
for state_key, input_key in mapping.items()
if self._get_state_value(state, state_key) is not None
}
def _apply_output_mapping(self, result: Any) -> dict[str, Any]:
"""Apply explicit output mapping."""
mapping = self._normalize_mapping(self.output_fields)
logger.debug(f"Applying output mapping: {mapping}")
if isinstance(result, dict):
return {
state_key: result.get(result_key)
for result_key, state_key in mapping.items()
if result_key in result
}
# Single value to first mapped field
if mapping:
first_state_key = next(iter(mapping.values()))
return {first_state_key: result}
return {"result": result}
def _map_to_outputs(self, result: Any, output_fields: list[str]) -> dict[str, Any]:
"""Map result to schema output fields."""
if isinstance(result, dict):
return {field: result.get(field) for field in output_fields if field in result}
return {output_fields[0]: result} if output_fields else {"result": result}
def _execute_with_config(
self, engine: Engine, input_data: Any, config: ConfigLike | None
) -> Any:
"""Execute engine with merged configuration."""
merged_config = self._build_merged_config(config, engine)
# ================= DETAILED PRE-INVOKE LOGGING =================
logger.info("🔍 DETAILED PRE-INVOKE ANALYSIS")
logger.info(f"Engine: {engine.name} (type: {getattr(engine, 'engine_type', 'unknown')})")
# Log input data comprehensively
logger.info(f"Input data type: {type(input_data).__name__}")
if isinstance(input_data, dict):
logger.info(f"Input dict keys: {list(input_data.keys())}")
for key, value in input_data.items():
logger.info(f" 🔑 {key}: {type(value).__name__}")
if hasattr(value, "__len__") and len(value) < 200:
logger.info(f" Value: {value}")
else:
logger.info(f" Value preview: {str(value)[:100]}...")
else:
logger.info(f"Input value: {str(input_data)[:200]}...")
# Log engine-specific details
if hasattr(engine, "prompt_template") and engine.prompt_template:
template = engine.prompt_template
logger.info(f"🎯 Engine has prompt_template: {type(template).__name__}")
# Check for input variables
input_vars = getattr(template, "input_variables", [])
optional_vars = getattr(template, "optional_variables", [])
partial_vars = getattr(template, "partial_variables", {})
logger.info(f" Required input_variables: {input_vars}")
logger.info(f" Optional variables: {optional_vars}")
logger.info(f" Partial variables: {list(partial_vars.keys()) if partial_vars else []}")
# Check if input data provides the required variables
if isinstance(input_data, dict):
missing_vars = [
var for var in input_vars if var not in input_data and var not in partial_vars
]
extra_vars = [
key
for key in input_data
if key not in input_vars + optional_vars and key != "messages"
]
if missing_vars:
logger.warning(f"⚠️ MISSING template variables: {missing_vars}")
if extra_vars:
logger.info(f"📦 Extra input keys (not in template): {extra_vars}")
logger.info(
f"✅ Available template variables: {
[var for var in input_vars if var in input_data or var in partial_vars]
}"
)
# Log merged config
if merged_config:
logger.info(f"Merged config keys: {list(merged_config.keys())}")
else:
logger.info("No merged config")
logger.info("🚀 CALLING engine.invoke() NOW...")
# ================= END DETAILED LOGGING =================
# Special handling for retrievers - they need string queries
if hasattr(engine, "engine_type") and engine.engine_type.value == "retriever":
logger.info("RETRIEVER DETECTED - Special handling")
if isinstance(input_data, dict):
if "query" in input_data:
query_str = str(input_data["query"])
logger.debug(f"Extracting query string: '{query_str}'")
logger.debug(f"Other params: {[k for k in input_data if k != 'query']}")
return engine.invoke(query_str, merged_config)
logger.debug("No 'query' key in dict, using whole dict as string")
return engine.invoke(str(input_data), merged_config)
logger.debug(f"Input is not dict, converting to string: '{input_data!s}'")
return engine.invoke(str(input_data), merged_config)
logger.debug("Standard engine invoke")
return engine.invoke(input_data, merged_config)
def _build_merged_config(
self, runtime_config: ConfigLike | None, engine: Engine
) -> dict[str, Any] | None:
"""Build merged configuration."""
if not runtime_config and not self.config_overrides:
return None
config = dict(runtime_config or {})
# Ensure configurable section
config.setdefault("configurable", {})
# Apply node-level overrides
config["configurable"].update(self.config_overrides)
# Apply engine-specific overrides
engine_id = getattr(engine, "id", None)
if engine_id and self.config_overrides:
config["configurable"].setdefault("engine_configs", {})
config["configurable"]["engine_configs"].setdefault(engine_id, {})
config["configurable"]["engine_configs"][engine_id].update(self.config_overrides)
return config
def _normalize_mapping(self, fields: list[str] | dict[str, str] | None) -> dict[str, str]:
"""Normalize field mapping to dict."""
if isinstance(fields, list):
return {field: field for field in fields}
return fields or {}
def _state_as_dict(self, state: StateLike) -> dict[str, Any]:
"""Convert state to dictionary."""
if isinstance(state, dict):
return state
if hasattr(state, "model_dump"):
return state.model_dump()
if hasattr(state, "__dict__"):
return state.__dict__
return {"value": state}
def _log_error(self, error: Exception) -> None:
"""Log error with full context."""
logger.exception(f"Engine node '{self.name}' failed: {error}")
[docs]
def __repr__(self) -> str:
"""Clean string representation."""
engine_ref = self.engine.name if self.engine else self.engine_name or "None"
return f"EngineNode(name='{self.name}', engine='{engine_ref}')"
# Factory function for creating appropriate node configs
def create_engine_node_config(engine: Engine, name: str, **kwargs) -> NodeConfig:
"""Factory function to create appropriate node config based on engine type.
Routes agents to AgentNodeConfig if available, otherwise uses EngineNodeConfig.
Args:
engine: The engine to create a node for
name: Name for the node
**kwargs: Additional configuration parameters
Returns:
Appropriate NodeConfig subclass instance
"""
# Check if it's an agent
if hasattr(engine, "engine_type") and engine.engine_type == EngineType.AGENT:
try:
# Try to import AgentNodeConfig
from haive.core.graph.node.agent_node import AgentNodeConfig
logger.debug(f"Creating AgentNodeConfig for agent: {name}")
return AgentNodeConfig(name=name, engine=engine, **kwargs)
except ImportError:
# Fall back to regular EngineNodeConfig
logger.debug("AgentNodeConfig not available, using EngineNodeConfig for agent")
# Default to EngineNodeConfig
return EngineNodeConfig(name=name, engine=engine, node_type=NodeType.ENGINE, **kwargs)