# Source code for haive.core.common.structures.generic_tree

"""Generic Tree/Leaf model with auto-indexing and type safety.

This module provides a powerful generic tree structure that can be used for
any hierarchical data, with automatic indexing, computed properties, and
full type safety through generics.
"""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Generic, TypeVar, Union

from pydantic import BaseModel, Field, PrivateAttr, computed_field

# Type variables for content and result types
ContentType = TypeVar("ContentType")
ResultType = TypeVar("ResultType")


class NodeStatus(str, Enum):
    """Status for tree nodes."""

    # Mixing in ``str`` makes every member a string instance, so a member
    # compares equal to its lowercase value (e.g. ``NodeStatus.FAILED == "failed"``).
    # The docstring above is left untouched: pydantic may surface it as the
    # enum's description in generated JSON schema.
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    SKIPPED = "skipped"
class TreeNode(BaseModel, Generic[ContentType, ResultType], ABC):
    """Abstract base class for tree nodes.

    Provides common functionality for both leaf and branch nodes.
    """

    # The class docstring is intentionally unchanged: pydantic uses a model's
    # docstring as its description in generated JSON schema.

    # Public model fields (validated and serialized by pydantic).
    content: ContentType = Field(default=..., description="The node's content/objective")
    result: ResultType | None = Field(default=None, description="The execution result")
    status: NodeStatus = Field(
        default=NodeStatus.PENDING, description="Current node status"
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata"
    )

    # Private position bookkeeping. These are not model fields; they are
    # written by ``Branch.add_child`` / ``Branch.add_parallel_children`` when
    # the node is attached to a parent.
    _index: int = PrivateAttr(default=0)
    _parent_index: int | None = PrivateAttr(default=None)
    _depth: int = PrivateAttr(default=0)
    _path: list[int] = PrivateAttr(default_factory=list)

    @abstractmethod
    def is_leaf(self) -> bool:
        """Tell whether this node is a leaf.

        Returns:
            ``True`` for leaf nodes, ``False`` for nodes that hold children.
        """

    def set_result(
        self, result: ResultType, status: NodeStatus = NodeStatus.COMPLETED
    ) -> None:
        """Record an execution result and the status that goes with it.

        Args:
            result: Value to store in ``self.result``.
            status: Status to store in ``self.status``; defaults to
                ``NodeStatus.COMPLETED``.
        """
        self.result = result
        self.status = status

    def mark_failed(self, error: str) -> None:
        """Flag the node as failed and remember why.

        Args:
            error: Message stored under the ``"error"`` key of ``metadata``.
        """
        self.status = NodeStatus.FAILED
        self.metadata["error"] = error
class Leaf(TreeNode[ContentType, ResultType]):
    """Leaf node - contains content but no children."""

    # Class and computed-field docstrings are kept verbatim because pydantic
    # uses them as descriptions in generated JSON schema.

    def is_leaf(self) -> bool:
        """Report that this node is a leaf.

        Returns:
            Always ``True``.
        """
        return True

    @computed_field
    @property
    def is_complete(self) -> bool:
        """Whether this leaf is complete."""
        # A leaf has nothing beneath it, so completion is just its own status.
        finished = self.status == NodeStatus.COMPLETED
        return finished
class Branch(TreeNode[ContentType, ResultType], Generic[ContentType, ResultType]):
    """Branch node - contains content and children."""

    # Class and computed-field docstrings are kept verbatim because pydantic
    # uses them as descriptions in generated JSON schema; extra detail lives
    # in ``#`` comments instead.

    children: list[
        Union[Leaf[ContentType, ResultType], "Branch[ContentType, ResultType]"]
    ] = Field(default_factory=list, description="Child nodes (leaves or branches)")

    # Index handed to the next child attached through ``add_child`` (with
    # auto-indexing) or to the next group attached through
    # ``add_parallel_children``.
    _next_child_index: int = PrivateAttr(default=0)

    def is_leaf(self) -> bool:
        """Report that this node is not a leaf.

        Returns:
            Always ``False``.
        """
        return False

    def add_child(
        self,
        child: Union[Leaf[ContentType, ResultType], "Branch[ContentType, ResultType]"],
        auto_index: bool = True,
    ) -> Union[Leaf[ContentType, ResultType], "Branch[ContentType, ResultType]"]:
        """Append ``child`` to this branch, optionally stamping its position.

        Args:
            child: Leaf or branch to attach.
            auto_index: When true, set the child's private ``_index``,
                ``_parent_index``, ``_depth`` and ``_path`` from this branch
                and advance the child counter. When false, the child's
                private attributes are left as they are.

        Returns:
            The same ``child`` object that was passed in.
        """
        if auto_index:
            child._index = self._next_child_index
            child._parent_index = self._index
            child._depth = self._depth + 1
            child._path = self._path + [self._next_child_index]
            self._next_child_index += 1

        self.children.append(child)
        return child

    def add_leaf(self, content: ContentType) -> Leaf[ContentType, ResultType]:
        """Create a leaf holding ``content`` and attach it with auto-indexing.

        Args:
            content: Content for the new leaf.

        Returns:
            The newly created leaf.
        """
        leaf = Leaf[ContentType, ResultType](content=content)
        return self.add_child(leaf)

    def add_branch(self, content: ContentType) -> "Branch[ContentType, ResultType]":
        """Create a branch holding ``content`` and attach it with auto-indexing.

        Args:
            content: Content for the new branch.

        Returns:
            The newly created branch.
        """
        branch = Branch[ContentType, ResultType](content=content)
        return self.add_child(branch)

    def add_parallel_children(
        self,
        children: list[
            Union[Leaf[ContentType, ResultType], "Branch[ContentType, ResultType]"]
        ],
    ) -> list[Union[Leaf[ContentType, ResultType], "Branch[ContentType, ResultType]"]]:
        """Add multiple children that represent parallel execution.

        Every child in the group receives the same ``_index`` and a ``_path``
        with one extra trailing element (its position inside the group). The
        group index and position are also written to ``metadata`` under
        ``"parallel_group"`` and ``"parallel_index"``.

        Args:
            children: Nodes to attach as one parallel group.

        Returns:
            The ``children`` list that was passed in.
        """
        base_index = self._next_child_index

        for i, child in enumerate(children):
            child._index = base_index  # Same base index for parallel
            child._parent_index = self._index
            child._depth = self._depth + 1
            child._path = self._path + [base_index, i]  # Extra index for parallel
            child.metadata["parallel_group"] = base_index
            child.metadata["parallel_index"] = i
            self.children.append(child)

        # The whole group consumes a single index, so after this call a
        # node's ``_index``/``_path`` no longer has to match its position in
        # ``self.children``.
        self._next_child_index = base_index + 1
        return children

    # Computed properties for tree analysis

    @computed_field
    @property
    def total_nodes(self) -> int:
        """Total number of nodes in this subtree."""
        # 1 for this branch, plus each child's subtree (a leaf counts as 1).
        return 1 + sum(
            child.total_nodes if isinstance(child, Branch) else 1
            for child in self.children
        )

    @computed_field
    @property
    def leaf_count(self) -> int:
        """Number of leaf nodes in this subtree."""
        return sum(
            child.leaf_count if isinstance(child, Branch) else 1
            for child in self.children
        )

    @computed_field
    @property
    def completed_count(self) -> int:
        """Number of completed nodes."""
        # Includes this branch itself when its own status is COMPLETED.
        own = int(self.status == NodeStatus.COMPLETED)
        return own + sum(
            child.completed_count
            if isinstance(child, Branch)
            else int(child.status == NodeStatus.COMPLETED)
            for child in self.children
        )

    @computed_field
    @property
    def failed_count(self) -> int:
        """Number of failed nodes."""
        # Includes this branch itself when its own status is FAILED.
        own = int(self.status == NodeStatus.FAILED)
        return own + sum(
            child.failed_count
            if isinstance(child, Branch)
            else int(child.status == NodeStatus.FAILED)
            for child in self.children
        )

    @computed_field
    @property
    def progress_percentage(self) -> float:
        """Percentage of completion (0-100)."""
        total = self.total_nodes
        if total == 0:
            return 0.0
        return (self.completed_count / total) * 100

    @computed_field
    @property
    def current_active_nodes(
        self,
    ) -> list[Union[Leaf[ContentType, ResultType], "Branch[ContentType, ResultType]"]]:
        """Get all nodes currently in progress."""
        # NOTE(review): as a computed field this list is part of the model's
        # serialized output, and it can contain ``self``. Confirm that dumping
        # an IN_PROGRESS branch does not hit a circular reference.
        active = []
        if self.status == NodeStatus.IN_PROGRESS:
            active.append(self)

        for child in self.children:
            if isinstance(child, Branch):
                active.extend(child.current_active_nodes)
            elif child.status == NodeStatus.IN_PROGRESS:
                active.append(child)

        return active

    @computed_field
    @property
    def next_pending_leaf(self) -> Leaf[ContentType, ResultType] | None:
        """Get the next pending leaf node (depth-first)."""
        for child in self.children:
            if isinstance(child, Leaf):
                if child.status == NodeStatus.PENDING:
                    return child
            elif isinstance(child, Branch):
                candidate = child.next_pending_leaf
                # Explicit None check rather than relying on model truthiness.
                if candidate is not None:
                    return candidate
        return None

    @computed_field
    @property
    def is_complete(self) -> bool:
        """Whether all nodes in this subtree are complete."""
        if self.status != NodeStatus.COMPLETED:
            return False
        return all(child.is_complete for child in self.children)

    @computed_field
    @property
    def has_failures(self) -> bool:
        """Whether any node in this subtree has failed."""
        return self.failed_count > 0

    def get_all_leaves(self) -> list[Leaf[ContentType, ResultType]]:
        """Collect every non-branch node below this branch.

        Returns:
            Leaves in depth-first order, following the order of ``children``.
        """
        leaves = []
        for child in self.children:
            if isinstance(child, Branch):
                leaves.extend(child.get_all_leaves())
            else:
                leaves.append(child)
        return leaves

    def find_node_by_path(
        self, path: list[int]
    ) -> Union[Leaf[ContentType, ResultType], "Branch[ContentType, ResultType]"] | None:
        """Find a node by positional indices relative to this branch.

        Each element of ``path`` is used as a position in the ``children``
        list of the node reached so far. Nodes attached through
        ``add_parallel_children`` carry a ``_path`` that is not positional,
        so their stored ``_path`` is not guaranteed to resolve here.

        Args:
            path: Child positions to follow, outermost first. An empty path
                refers to this branch itself.

        Returns:
            The node at ``path``, or ``None`` when an index is out of range
            (including negative indices) or the path continues past a leaf.
        """
        if not path:
            return self

        head, *rest = path
        # Reject negative positions too: Python would otherwise index from
        # the end of the list (or raise IndexError) instead of "not found".
        if not 0 <= head < len(self.children):
            return None

        child = self.children[head]
        if not rest:
            return child
        if isinstance(child, Branch):
            return child.find_node_by_path(rest)
        return None
# Enable forward references Branch.model_rebuild()