"""Branch system for dynamic routing based on state values."""
from collections.abc import Callable
from typing import Any
from haive.core.graph.branches.branch import Branch
from haive.core.graph.branches.send_mapping import SendGenerator, SendMapping
from haive.core.graph.branches.types import BranchMode, ComparisonType
from haive.core.graph.common.references import CallableReference
# Import from common utilities
from haive.core.graph.common.types import StateLike
# Factory functions for common branch types
[docs]
def key_equals(
key: str, value: Any, true_dest: str = "continue", false_dest: str = "END"
) -> Branch:
"""Create a Branch that checks if a key equals a value."""
return Branch(
key=key,
value=value,
comparison=ComparisonType.EQUALS,
destinations={True: true_dest, False: false_dest},
)
[docs]
def key_exists(
key: str, true_dest: str = "continue", false_dest: str = "END"
) -> Branch:
"""Create a Branch that checks if a key exists in the state."""
return Branch(
key=key,
comparison=ComparisonType.EXISTS,
destinations={True: true_dest, False: false_dest},
)
[docs]
def from_function(
function: Callable[[StateLike], bool | str],
destinations: dict[bool | str, str] | None = None,
default: str = "END",
) -> Branch:
"""Create a Branch from a function."""
return Branch(
function_ref=CallableReference.from_callable(function),
destinations=destinations,
default=default,
mode=BranchMode.FUNCTION,
)
[docs]
def chain(*branches: Branch, default: str = "END") -> Branch:
"""Chain multiple branches together, evaluating them in sequence."""
return Branch(chain_branches=list(branches), mode=BranchMode.CHAIN, default=default)
[docs]
def conditional(
condition: Callable[[StateLike], bool],
if_true: str | Branch,
if_false: str | Branch,
default: str = "END",
) -> Branch:
"""Create a Branch with conditional evaluation."""
return Branch(
condition_ref=CallableReference.from_callable(condition),
true_branch=if_true,
false_branch=if_false,
mode=BranchMode.CONDITION,
default=default,
)
[docs]
def message_contains(
text: str,
true_dest: str = "continue",
false_dest: str = "END",
message_key: str = "messages",
) -> Branch:
"""Create a Branch that checks if the last message contains specific text."""
return Branch(
value=text,
comparison=ComparisonType.MESSAGE_CONTAINS,
destinations={True: true_dest, False: false_dest},
message_key=message_key,
)
[docs]
def send_mapper(
function: Callable[[StateLike], list[Any]] | None = None,
mappings: list[SendMapping] | None = None,
generators: list[SendGenerator] | None = None,
) -> Branch:
"""Create a Branch that generates Send objects."""
function_ref = CallableReference.from_callable(function) if function else None
return Branch(
function_ref=function_ref,
send_mappings=mappings or [],
send_generators=generators or [],
mode=BranchMode.SEND_MAPPER,
)
[docs]
def create_from_send_function(
mapper_function: Callable[[StateLike], list[Any]],
) -> Branch:
"""Create a Send mapper branch from a function that returns Send objects."""
return send_mapper(function=mapper_function)