"""Callable Node - Wrap any callable as a graph node.

This module provides a way to wrap any Python callable (function, method, lambda)
as a proper graph node that returns a LangGraph ``Command`` (an optional state
update plus the name of the next node to go to).
"""
import inspect
import logging
from collections.abc import Callable
from typing import Any, Self, TypeVar
from langgraph.types import Command
from pydantic import BaseModel, Field, model_validator
from haive.core.graph.common.types import ConfigLike, NodeType, StateLike
from haive.core.graph.node.base_node_config import BaseNodeConfig
from haive.core.schema.field_definition import FieldDefinition
# Type variables bound to callables and to pydantic models respectively.
# NOTE(review): neither is referenced in the code visible in this module;
# presumably kept for external typing helpers — confirm before removing.
TCallable = TypeVar("TCallable", bound=Callable)
TState = TypeVar("TState", bound=BaseModel)

# Module logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class CallableNodeConfig(BaseNodeConfig):
    """Configuration for wrapping a callable as a node.

    This allows any function to be used as a graph node by:

    1. Extracting required parameters from state
    2. Calling the function
    3. Wrapping the result in a ``Command`` (optional state update plus ``goto``)

    Examples:
        Simple boolean check::

            def check_threshold(messages: List[BaseMessage], threshold: int = 100) -> bool:
                total_length = sum(len(msg.content) for msg in messages)
                return total_length > threshold

            node = CallableNodeConfig(
                name="check_threshold",
                callable_func=check_threshold,
                goto_on_true="summarize",
                goto_on_false="continue"
            )

        State function (the validator requires at least one routing target,
        hence ``default_goto``)::

            def needs_summarization(state: MessagesState) -> bool:
                return state.token_count > 1000

            node = CallableNodeConfig(
                name="check_summary",
                callable_func=needs_summarization,
                extract_full_state=True,
                default_goto="summarize",
            )
    """

    node_type: NodeType = Field(default=NodeType.CALLABLE)
    callable_func: Callable = Field(..., description="The function to wrap as a node")

    # --- Result storage -----------------------------------------------------
    result_key: str | None = Field(
        default=None,
        description="State key to store result in. If None, result is not stored.",
    )

    # --- Routing (see _determine_goto for precedence) -----------------------
    goto_on_true: str | None = Field(
        default=None, description="Node to go to if callable returns True"
    )
    goto_on_false: str | None = Field(
        default=None, description="Node to go to if callable returns False"
    )
    goto_mapping: dict[Any, str] | None = Field(
        default=None, description="Map function results to node names"
    )
    default_goto: str | None = Field(
        default=None, description="Default node if no mapping matches"
    )

    # --- Parameter extraction -------------------------------------------------
    extract_full_state: bool = Field(
        default=False, description="Pass the full state object as first parameter"
    )
    parameter_mapping: dict[str, str] | None = Field(
        default=None,
        description="Map function parameters to state fields. {'param': 'state.field'}",
    )
    extraction_paths: dict[str, str] | None = Field(
        default=None,
        description="Advanced extraction paths like 'param': 'state.nested.field[0].value'",
    )

    # --- Error handling ---------------------------------------------------------
    # Unrecognised values are treated like 'raise' (see __call__).
    on_error: str = Field(
        default="raise",
        description="What to do on error: 'raise', 'return_none', 'goto_error'",
    )
    error_goto: str | None = Field(
        default=None, description="Node to go to on error (if on_error='goto_error')"
    )

    @model_validator(mode="after")
    def validate_config(self) -> Self:
        """Ensure at least one routing target is configured.

        Raises:
            ValueError: If none of ``goto_on_true``, ``goto_on_false``,
                ``goto_mapping``, ``default_goto`` or ``command_goto`` is set.
        """
        routing_targets = (
            self.goto_on_true,
            self.goto_on_false,
            self.goto_mapping,
            self.default_goto,
            self.command_goto,  # inherited from BaseNodeConfig
        )
        if not any(routing_targets):
            raise ValueError(
                "Must specify at least one of: goto_on_true/false, goto_mapping, default_goto, or command_goto"
            )
        return self

    def __call__(self, state: StateLike, config: ConfigLike | None = None) -> Command:
        """Execute the callable and return the appropriate ``Command``.

        Args:
            state: Graph state; passed whole when ``extract_full_state`` is set,
                otherwise used as the source of the callable's parameters.
            config: Accepted for the node calling convention; not used here.

        Returns:
            A ``Command`` whose ``update`` holds the result under ``result_key``
            (if set) and whose ``goto`` is chosen by ``_determine_goto``.

        Raises:
            Exception: Any exception from the callable is re-raised when
                ``on_error`` is ``'raise'`` or an unrecognised value.
        """
        try:
            if self.extract_full_state:
                result = self.callable_func(state)
            else:
                result = self.callable_func(**self._extract_parameters(state))
            update = {self.result_key: result} if self.result_key else {}
            return Command(update=update, goto=self._determine_goto(result))
        except Exception as e:
            logger.exception("Error in callable node '%s': %s", self.name, e)
            if self.on_error == "return_none":
                return Command(
                    update={self.result_key: None} if self.result_key else {},
                    goto=self.default_goto or self.command_goto,
                )
            if self.on_error == "goto_error":
                return Command(
                    update={"error": str(e)},
                    goto=self.error_goto or self.default_goto or self.command_goto,
                )
            # 'raise' and any unrecognised on_error value: propagate the error
            # instead of silently returning None (which is not a valid Command).
            raise

    def _extract_parameters(self, state: StateLike) -> dict[str, Any]:
        """Build keyword arguments for ``callable_func`` from ``state``.

        For each named parameter, the value is looked up, in priority order, via
        ``extraction_paths``, ``parameter_mapping``, or the parameter's own name.
        A missing (``None``) value falls back to the parameter default, if any.
        ``self`` and variadic ``*args`` / ``**kwargs`` parameters are skipped,
        since they cannot be filled by name.
        """
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        kwargs: dict[str, Any] = {}
        for param_name, param in inspect.signature(self.callable_func).parameters.items():
            if param_name == "self" or param.kind in variadic_kinds:
                continue
            if self.extraction_paths and param_name in self.extraction_paths:
                value = self._extract_by_path(state, self.extraction_paths[param_name])
            elif self.parameter_mapping and param_name in self.parameter_mapping:
                field_name = self.parameter_mapping[param_name]
                value = self._get_state_value(state, field_name)
                if value is None and ("." in field_name or "[" in field_name):
                    # Support the documented path form, e.g. {'param': 'state.field'}.
                    value = self._extract_by_path(state, field_name)
            else:
                value = self._get_state_value(state, param_name)
            if value is None and param.default is not inspect.Parameter.empty:
                value = param.default
            kwargs[param_name] = value
        return kwargs

    def _extract_by_path(self, state: StateLike, path: str) -> Any:
        """Extract a value using a path like ``'messages[0].content'``.

        A leading ``state.`` segment (the form used in the field descriptions,
        e.g. ``'state.nested.field[0].value'``) refers to the state object itself
        unless the state really has a field named ``state``. A segment may carry
        one or more integer indexes, e.g. ``'grid[0][1]'``.

        Returns:
            The resolved value, or ``None`` if any segment is missing/``None`` or
            cannot be indexed at the requested position.

        Raises:
            ValueError: If a bracketed index is not an integer.
        """
        if path.startswith("state.") and self._get_state_value(state, "state") is None:
            path = path[len("state."):]
        current: Any = state
        for part in path.split("."):
            field, bracket, index_spec = part.partition("[")
            current = self._get_state_value(current, field)
            if current is None:
                return None
            if not bracket:
                continue
            for raw_index in index_spec.rstrip("]").split("]["):
                index = int(raw_index)
                try:
                    current = current[index]
                except (IndexError, KeyError, TypeError):
                    # Out of range, missing key, or not subscriptable.
                    return None
                if current is None:
                    return None
        return current

    def _get_state_value(self, obj: Any, field: str) -> Any:
        """Get ``field`` from ``obj`` by key or attribute; ``None`` if absent.

        Dict keys are checked first so that fields named like dict methods
        (``items``, ``values``, ``keys``...) return the stored value rather than
        the bound method.
        """
        if isinstance(obj, dict) and field in obj:
            return obj[field]
        if hasattr(obj, field):
            return getattr(obj, field)
        if hasattr(obj, "__getitem__"):
            try:
                return obj[field]
            except (KeyError, TypeError):
                pass
        return None

    def _determine_goto(self, result: Any) -> str | None:
        """Determine which node to go to based on ``result``.

        Precedence: ``goto_on_true``/``goto_on_false`` for boolean results, then
        ``goto_mapping`` (skipped for unhashable results), then ``default_goto``,
        then ``command_goto``.
        """
        if isinstance(result, bool):
            if result and self.goto_on_true:
                return self.goto_on_true
            if not result and self.goto_on_false:
                return self.goto_on_false
        if self.goto_mapping:
            try:
                return self.goto_mapping[result]
            except (KeyError, TypeError):
                # No entry, or the result is unhashable (list, dict, ...).
                pass
        return self.default_goto or self.command_goto
def wrap_callable(
    func: Callable, name: str | None = None, **kwargs
) -> CallableNodeConfig:
    """Convenience function to wrap a callable as a node.

    Args:
        func: The function to wrap.
        name: Node name. Defaults to ``func.__name__``. For callables without
            one, it falls back to the wrapped function's name (``.func``, as on
            ``functools.partial``) and finally to the callable's class name.
        **kwargs: Additional CallableNodeConfig parameters.

    Returns:
        Configured CallableNodeConfig.

    Examples:
        node = wrap_callable(
            check_threshold,
            goto_on_true="summarize",
            goto_on_false="continue"
        )
    """
    if name is None:
        name = getattr(func, "__name__", None)
        if name is None:
            # functools.partial and callable instances have no __name__.
            wrapped = getattr(func, "func", None)
            name = getattr(wrapped, "__name__", None) or type(func).__name__
    return CallableNodeConfig(name=name, callable_func=func, **kwargs)
def as_node(**kwargs) -> "Callable[[Callable], CallableNodeConfig]":
    """Decorator factory that turns a function into a node.

    Args:
        **kwargs: Keyword arguments forwarded to ``wrap_callable`` (and from
            there to ``CallableNodeConfig``), e.g. ``name`` or ``goto_on_true``.

    Returns:
        A decorator that replaces the decorated function with the
        ``CallableNodeConfig`` built by ``wrap_callable``.

    Examples:
        @as_node(goto_on_true="next", goto_on_false="retry")
        def should_continue(messages: List[BaseMessage]) -> bool:
            return len(messages) > 5
    """

    def decorator(func: Callable) -> CallableNodeConfig:
        """Wrap ``func`` as a node using the captured ``kwargs``.

        Args:
            func: The callable being decorated.

        Returns:
            The ``CallableNodeConfig`` wrapping ``func``. The decorated name is
            bound to this config object, not to the original function.
        """
        return wrap_callable(func, **kwargs)

    return decorator