Source code for haive.core.graph.branches.utils

"""Utility functions for working with branches."""

import inspect
import logging
import re
from typing import Any

logger = logging.getLogger(__name__)


[docs] def extract_field(state: Any, field_path: str) -> Any: """Extract a field value from state using a path. Supports: - Simple keys: "fieldname" - Nested paths: "nested.field.path" - Array indexing: "array.0.field" - Special paths: "messages.last.content" Args: state: State object field_path: Path to the field Returns: Extracted value or None if not found """ if field_path is None: return None # Handle special keys if field_path == "messages.last.content": messages = get_field_value(state, "messages") if not messages or not isinstance(messages, list | tuple) or not messages: return None last_message = messages[-1] return get_field_value(last_message, "content") # Handle dot notation if "." in field_path: parts = field_path.split(".") current = state for part in parts: # Try to convert numeric indices if part.isdigit(): part = int(part) elif part == "last" and isinstance(current, list | tuple) and current: part = -1 # Get next level current = get_field_value(current, part) if current is None: return None return current # Simple key return get_field_value(state, field_path)
[docs] def get_field_value(obj: Any, key: Any) -> Any: """Get a field value using various access methods. Args: obj: Object to extract from key: Key or index to extract Returns: Extracted value or None if not found """ # Handle indexing for lists and tuples if isinstance(obj, list | tuple): try: if isinstance(key, int) and ( 0 <= key < len(obj) or (key < 0 and abs(key) <= len(obj)) ): return obj[key] except (IndexError, TypeError): return None # Try attribute access (for objects and Pydantic models) if hasattr(obj, key): return getattr(obj, key) # Try dictionary access try: return obj[key] except (KeyError, TypeError, AttributeError): pass # Try get method for dictionaries if hasattr(obj, "get") and callable(obj.get): try: return obj.get(key) except Exception: pass return None
[docs] def extract_fields_from_function(func: callable) -> set[str]: """Extract field references from a function. Args: func: Function to analyze Returns: Set of field names referenced """ fields = set() if not func: return fields try: source = inspect.getsource(func) # Look for state["field"] or state.field patterns dict_refs = re.findall(r'state\[[\'"]([\w]+)[\'"]', source) attr_refs = re.findall(r"state\.([\w]+)", source) fields.update(dict_refs) fields.update(attr_refs) except (OSError, TypeError): pass return fields
[docs] def extract_base_field(field_path: str) -> str: """Extract the base field from a path.""" if not field_path: return field_path return field_path.split(".")[0] if "." in field_path else field_path