Source code for haive.core.graph.node.composer.protocols
"""Protocols for extract, update, and transform functions in NodeSchemaComposer.

This module defines the structural (``typing.Protocol``) interfaces that
extract, update, and field-transform callables must satisfy to be compatible
with the NodeSchemaComposer system.
"""
from typing import Any, Protocol, TypeVar
from pydantic import BaseModel
# Type variables for protocol generics.
#
# Variance is chosen to match how each variable is used by the protocols in
# this module, as PEP 544 requires for generic protocols. Mypy reports
# "Invariant type variable ... used in protocol where contravariant/covariant
# one is expected" when these are left invariant:
#   * TState  - only ever a parameter (ExtractFunction, UpdateFunction) -> contravariant
#   * TInput  - only ever a return value (ExtractFunction)              -> covariant
#   * TOutput - only ever a parameter (UpdateFunction)                  -> contravariant
# NOTE(review): if other modules import these TypeVars for their own generic
# classes, confirm the variance is compatible with those uses.
TState = TypeVar("TState", bound=BaseModel, contravariant=True)
TInput = TypeVar("TInput", covariant=True)
TOutput = TypeVar("TOutput", contravariant=True)
[docs]
class ExtractFunction(Protocol[TState, TInput]):
    """Callable protocol for pulling a node's input out of graph state.

    An extract function receives the current state object together with a
    configuration mapping. It returns whatever value the node should process.

    Examples:
        def extract_messages(state: MessagesState, config: dict[str, Any]) -> list[BaseMessage]:
            field_name = config.get("field_name", "messages")
            return getattr(state, field_name, [])

        def extract_with_projection(state: MultiAgentState, config: dict[str, Any]) -> dict[str, Any]:
            # Complex projection logic
            return projected_state
    """

    def __call__(
        self,
        state: TState,
        config: dict[str, Any],
    ) -> TInput:
        """Return the input value extracted from ``state``.

        Args:
            state: State object to read from, typed as ``TState`` (bound to
                ``pydantic.BaseModel``). NOTE(review): earlier docs also
                mentioned plain dict states; confirm whether callers pass
                dicts, and widen the bound if so.
            config: Extraction parameters, for example which field to read.

        Returns:
            The data handed to the node's processing step.
        """
[docs]
class UpdateFunction(Protocol[TState, TOutput]):
    """Callable protocol for turning a node's result into state updates.

    An update function receives three things: the value produced by the node,
    the state the node ran against, and a configuration mapping. It returns a
    dictionary of state field updates.

    Examples:
        def update_messages(result: AIMessage, state: MessagesState, config: dict[str, Any]) -> dict[str, Any]:
            messages = list(getattr(state, "messages", []))
            messages.append(result)
            return {"messages": messages}

        def update_type_aware(result: Any, state: BaseModel, config: dict[str, Any]) -> dict[str, Any]:
            # Smart type-based updates
            return update_dict
    """

    def __call__(
        self,
        result: TOutput,
        state: TState,
        config: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the state update produced by ``result``.

        Args:
            result: Output of the node's processing (message, dict, model, etc.).
            state: The original state object, available for context.
            config: Update parameters.

        Returns:
            Mapping of state field names to the values to apply.
        """
[docs]
class TransformFunction(Protocol):
    """Callable protocol for single-value transforms.

    Field mapping pipelines use transform functions to rewrite a value while
    it is being extracted from, or written back to, the state.

    Examples:
        def uppercase(value: str) -> str:
            return value.upper() if isinstance(value, str) else str(value).upper()

        def parse_json(value: str) -> Any:
            import json
            return json.loads(value)
    """

    def __call__(self, value: Any) -> Any:
        """Return the transformed form of ``value``.

        Args:
            value: The value to transform.

        Returns:
            The transformed value.
        """