Source code for haive.core.common.structures.tree_leaf.base

"""Base tree node classes with advanced generic support."""

from abc import ABC, abstractmethod
from typing import (
    Generic,
    Optional,
    overload,
)

from pydantic import BaseModel, Field, PrivateAttr, computed_field

from haive.core.common.structures.tree_leaf.generics import ChildT, ContentT, DefaultContent, DefaultResult, ResultT


[docs] class TreeNode(BaseModel, Generic[ContentT, ResultT], ABC): """Abstract base class for all tree nodes. Uses bounded TypeVars for better type safety and inference. """ content: ContentT result: ResultT | None = None # Auto-indexing (hidden) _index: int = PrivateAttr(default=0) _parent: Optional["TreeNode"] = PrivateAttr(default=None) _depth: int = PrivateAttr(default=0) _path: tuple[int, ...] = PrivateAttr(default=())
[docs] @abstractmethod def is_leaf(self) -> bool: """Check if this is a leaf node.""" ...
@computed_field @property def node_id(self) -> str: """Unique identifier based on path.""" if not self._path: return "root" return ".".join(str(i) for i in self._path) @computed_field @property def level(self) -> int: """Tree level (alias for depth).""" return self._depth
[docs] class Leaf(TreeNode[ContentT, ResultT], Generic[ContentT, ResultT]): """Leaf node - has content but no children. Examples: .. code-block:: python # With explicit types leaf: Leaf[TaskContent, TaskResult] = Leaf( content=TaskContent(name="Calculate", action="add", params={"a": 1, "b": 2}) ) # With default types simple_leaf = Leaf(content=DefaultContent(name="Task1")) """
[docs] def is_leaf(self) -> bool: """Is Leaf. Returns: [TODO: Add return description] """ return True
[docs] class Tree(TreeNode[ContentT, ResultT], Generic[ContentT, ChildT, ResultT]): """Tree node - has content and children. The ChildT parameter allows for heterogeneous trees where children can be of different types (but all extending the bound). Examples: .. code-block:: python # Homogeneous tree (all children same type) tree: Tree[PlanContent, PlanNode, PlanResult] = Tree( content=PlanContent(objective="Main Plan") ) # Heterogeneous tree (mixed children) mixed: Tree[DefaultContent, TreeNode, DefaultResult] = Tree( content=DefaultContent(name="Root") ) """ children: list[ChildT] = Field(default_factory=list) # Private counter for auto-indexing _child_counter: int = PrivateAttr(default=0)
[docs] def is_leaf(self) -> bool: """Is Leaf. Returns: [TODO: Add return description] """ return False
@overload def add_child(self, child: ChildT) -> ChildT: """Add a single child.""" @overload def add_child(self, *children: ChildT) -> list[ChildT]: """Add multiple children."""
[docs] def add_child(self, *children: ChildT) -> ChildT | list[ChildT]: """Add one or more children with auto-indexing.""" if len(children) == 1: child = children[0] self._index_child(child) self.children.append(child) return child else: indexed_children = [] for child in children: self._index_child(child) self.children.append(child) indexed_children.append(child) return indexed_children
def _index_child(self, child: ChildT) -> None: """Set up indexing for a child node.""" if hasattr(child, "_index"): child._index = self._child_counter if hasattr(child, "_parent"): child._parent = self if hasattr(child, "_depth"): child._depth = self._depth + 1 if hasattr(child, "_path"): child._path = self._path + (self._child_counter,) self._child_counter += 1 @computed_field @property def child_count(self) -> int: """Number of direct children.""" return len(self.children) @computed_field @property def descendant_count(self) -> int: """Total number of descendants.""" count = 0 for child in self.children: count += 1 if hasattr(child, "descendant_count"): count += child.descendant_count return count @computed_field @property def height(self) -> int: """Height of the subtree rooted at this node.""" if not self.children: return 0 max_child_height = 0 for child in self.children: child_height = 0 if hasattr(child, "height"): child_height = child.height max_child_height = max(max_child_height, child_height) return max_child_height + 1
[docs] def find_by_path(self, *indices: int) -> ChildT | None: """Find a descendant by path indices.""" if not indices: return self if indices[0] >= len(self.children): return None child = self.children[indices[0]] if len(indices) == 1: return child if hasattr(child, "find_by_path"): return child.find_by_path(*indices[1:]) return None
# Convenience type aliases for common patterns SimpleTree = Tree[ DefaultContent, TreeNode[DefaultContent, DefaultResult], DefaultResult ] SimpleLeaf = Leaf[DefaultContent, DefaultResult] SimpleBranch = Tree[ DefaultContent, TreeNode[DefaultContent, DefaultResult], DefaultResult ]