Source code for haive.core.graph.state_graph.serializable

"""from typing import Any, Dict.
Serialization support for BaseGraph.

This module provides serialization and deserialization utilities for BaseGraph,
enabling graphs to be stored, loaded, and exchanged with minimal information loss.
"""

import importlib
import json
import logging
import uuid
from datetime import datetime
from typing import Any

from pydantic import BaseModel, ConfigDict, Field

from haive.core.graph.common.references import CallableReference
from haive.core.utils.serialization import ensure_json_serializable

# Get logger
logger = logging.getLogger(__name__)


# Type references
[docs] class TypeReference(BaseModel): """Reference to a type that can be serialized.""" module_path: str | None = None name: str is_generic: bool = False generic_params: list["TypeReference"] | None = None model_config = ConfigDict(extra="allow")
[docs] @classmethod def from_type(cls, type_obj) -> Any: """From Type. Args: type_obj: [TODO: Add description] Returns: [TODO: Add return description] """ if type_obj is None: return None ref = cls( name=getattr(type_obj, "__name__", str(type_obj)), module_path=getattr(type_obj, "__module__", None), ) # TODO: Add support for complex generic types return ref
[docs] def resolve(self) -> Any: """Resolve the reference back to a type.""" if not self.module_path or not self.name: return None try: module = importlib.import_module(self.module_path) return getattr(module, self.name) except (ImportError, AttributeError) as e: logger.warning(f"Error resolving type: {e}") return None
# Node serialization
[docs] class SerializableNode(BaseModel): """Serializable representation of a Node.""" id: str name: str node_type: str # Configuration input_mapping: dict[str, str] | None = None output_mapping: dict[str, str] | None = None command_goto: str | list[str] | None = None retry_policy: dict[str, Any] | None = None # Metadata description: str | None = None metadata: dict[str, Any] = Field(default_factory=dict) created_at: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True)
# Branch serialization
[docs] class SerializableBranch(BaseModel): """Serializable representation of a Branch.""" id: str name: str source_node: str mode: str | None = None # Core branch fields key: str | None = None value: Any | None = None comparison: str | None = None default: str | None = None allow_none: bool | None = None message_key: str | None = None # Function references function_ref: CallableReference | None = None condition_ref: CallableReference | None = None # Routing information destinations: dict[str, str] = Field(default_factory=dict) # Advanced branch features send_mappings: list[dict[str, Any]] | None = None send_generators: list[dict[str, Any]] | None = None dynamic_mapping: dict[str, Any] | None = None # Chain branches - these need special handling chain_branches_ids: list[str] | None = None true_branch_id: str | None = None false_branch_id: str | None = None # Metadata description: str | None = None metadata: dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(arbitrary_types_allowed=True)
[docs] class SerializableGraph(BaseModel): """Serializable representation of the BaseGraph.""" # Basic identification id: str = Field(default_factory=lambda: str(uuid.uuid4())) name: str description: str | None = None # Graph structure nodes: dict[str, SerializableNode] = Field(default_factory=dict) # Edges stored as lists for easier serialization direct_edges: list[tuple[str, str]] = Field(default_factory=list) # Branches branches: dict[str, SerializableBranch] = Field(default_factory=dict) # Schema reference state_schema: TypeReference | None = None # Metadata metadata: dict[str, Any] = Field(default_factory=dict) created_at: str updated_at: str # Configuration config: dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(arbitrary_types_allowed=True)
[docs] @classmethod def from_graph(cls, graph) -> Any: """Create a serializable representation from a BaseGraph.""" from haive.core.graph.state_graph.base_graph import BaseGraph if not isinstance(graph, BaseGraph): raise TypeError("Expected a BaseGraph instance") # Initialize basic fields serializable = cls( id=graph.id, name=graph.name, description=graph.description, metadata=graph.metadata.copy(), created_at=graph.created_at.isoformat(), updated_at=graph.updated_at.isoformat(), ) # Convert schema if present if graph.state_schema: serializable.state_schema = TypeReference.from_type(graph.state_schema) # Convert nodes for name, node in graph.nodes.items(): serializable.nodes[name] = SerializableNode( id=node.id, name=node.name, node_type=node.node_type.value, input_mapping=node.input_mapping, output_mapping=node.output_mapping, command_goto=node.command_goto, retry_policy=node.retry_policy._asdict() if node.retry_policy else None, description=node.description, metadata=node.metadata.copy(), created_at=node.created_at.isoformat() if node.created_at else None, ) # Convert direct edges for edge in graph.edges: serializable.direct_edges.append(edge) # Convert branches for branch_id, branch in graph.branches.items(): # Basic properties all branches have branch_data = { "id": branch.id, "name": branch.name, "source_node": branch.source_node, "default": getattr(branch, "default", "END"), "description": getattr(branch, "description", None), "metadata": getattr(branch, "metadata", {}).copy(), } # Add mode if hasattr(branch, "mode"): mode = branch.mode branch_data["mode"] = ( mode.value if hasattr(mode, "value") else str(mode) ) # Add standard fields that most branches have if hasattr(branch, "key"): branch_data["key"] = branch.key if hasattr(branch, "value"): branch_data["value"] = branch.value if hasattr(branch, "comparison"): comp = branch.comparison branch_data["comparison"] = ( comp.value if hasattr(comp, "value") else str(comp) ) if hasattr(branch, "allow_none"): branch_data["allow_none"] = branch.allow_none if hasattr(branch, "message_key"): branch_data["message_key"] = branch.message_key # Handle destinations/routes if hasattr(branch, "destinations"): # Convert any non-string keys to strings destinations = {} for k, v in branch.destinations.items(): destinations[str(k)] = v branch_data["destinations"] = destinations # Handle function references if hasattr(branch, "function_ref"): branch_data["function_ref"] = branch.function_ref elif hasattr(branch, "function") and callable(branch.function): branch_data["function_ref"] = CallableReference.from_callable( branch.function ) if hasattr(branch, "condition_ref"): branch_data["condition_ref"] = branch.condition_ref elif hasattr(branch, "condition") and callable(branch.condition): branch_data["condition_ref"] = CallableReference.from_callable( branch.condition ) # Handle advanced branch features if hasattr(branch, "send_mappings") and branch.send_mappings: # Simplified serialization of send mappings branch_data["send_mappings"] = [ { "condition": m.condition if hasattr(m, "condition") else None, "node": m.node if hasattr(m, "node") else None, "fields": m.fields if hasattr(m, "fields") else {}, } for m in branch.send_mappings ] if hasattr(branch, "send_generators") and branch.send_generators: # Simplified serialization of send generators branch_data["send_generators"] = [ { "collection_field": ( g.collection_field if hasattr(g, "collection_field") else None ), "target_node": ( g.target_node if hasattr(g, "target_node") else None ), "item_field": ( g.item_field if hasattr(g, "item_field") else None ), } for g in branch.send_generators ] # Handle dynamic mapping - this part is fixed to correctly map key # to field if hasattr(branch, "dynamic_mapping") and branch.dynamic_mapping: # Serialize dynamic mapping - The important fix: use key for # field mapping = branch.dynamic_mapping branch_data["dynamic_mapping"] = { "field": mapping.key, # Store the key attribute under 'field' key "mappings": ( mapping.mappings if hasattr(mapping, "mappings") else {} ), "default_node": ( mapping.default_node if hasattr(mapping, "default_node") else "END" ), } # Handle chain branches if hasattr(branch, "chain_branches") and branch.chain_branches: # Store IDs instead of actual objects branch_data["chain_branches_ids"] = [] for chain_branch in branch.chain_branches: # Use ID if it exists in our branches for bid, b in graph.branches.items(): if b is chain_branch: branch_data["chain_branches_ids"].append(bid) break # Handle true/false branches if hasattr(branch, "true_branch"): true_branch = branch.true_branch if isinstance(true_branch, str): # Direct node name branch_data["true_branch_id"] = true_branch elif hasattr(true_branch, "id"): branch_data["true_branch_id"] = true_branch.id if hasattr(branch, "false_branch"): false_branch = branch.false_branch if isinstance(false_branch, str): # Direct node name branch_data["false_branch_id"] = false_branch elif hasattr(false_branch, "id"): branch_data["false_branch_id"] = false_branch.id # Create serializable branch serializable.branches[branch_id] = SerializableBranch(**branch_data) return serializable
[docs] def to_graph(self) -> Any: """Convert serializable representation back to a BaseGraph.""" from haive.core.graph.branches import BranchMode, ComparisonType from haive.core.graph.state_graph.base_graph import BaseGraph, Node, NodeType # Create basic graph graph = BaseGraph( id=self.id, name=self.name, description=self.description, metadata=self.metadata.copy(), ) # Set timestamps try: graph.created_at = datetime.fromisoformat(self.created_at) graph.updated_at = datetime.fromisoformat(self.updated_at) except (ValueError, TypeError): # Use current time if iso format parsing fails now = datetime.now() graph.created_at = now graph.updated_at = now # Resolve schema if present if self.state_schema: graph.state_schema = self.state_schema.resolve() # Add nodes for name, node_data in self.nodes.items(): try: # Convert string node_type to enum node_type = NodeType(node_data.node_type) # Create the node node = Node( id=node_data.id, name=node_data.name, node_type=node_type, input_mapping=node_data.input_mapping, output_mapping=node_data.output_mapping, command_goto=node_data.command_goto, retry_policy=node_data.retry_policy, # TODO: Convert dict to RetryPolicy description=node_data.description, metadata=node_data.metadata.copy(), ) # Parse created_at if node_data.created_at: try: node.created_at = datetime.fromisoformat(node_data.created_at) except (ValueError, TypeError): node.created_at = datetime.now() # Add to graph graph.nodes[name] = node except Exception as e: logger.exception(f"Error reconstructing node {name}: {e}") # Add direct edges for source, target in self.direct_edges: graph.edges.append((source, target)) # First pass to create all branches (without references to other # branches) temp_branches = {} for branch_id, branch_data in self.branches.items(): try: # Resolve function reference if available function = None if branch_data.function_ref: function = branch_data.function_ref.resolve() # Resolve condition reference if available if branch_data.condition_ref: branch_data.condition_ref.resolve() # Convert comparison enum comparison = branch_data.comparison if comparison and not isinstance(comparison, ComparisonType): try: comparison = ComparisonType(comparison) except (ValueError, TypeError): pass # Keep as string if not an enum value # Convert mode enum mode = branch_data.mode if mode and not isinstance(mode, BranchMode): try: mode = BranchMode(mode) except (ValueError, TypeError): pass # Keep as string if not an enum value # Create dynamic mapping if provided - Fixed to correctly map # field to key dynamic_mapping = None if branch_data.dynamic_mapping: # The important fix: use field for key dynamic_mapping = DynamicMapping( key=branch_data.dynamic_mapping.get( "field" ), # Map 'field' to 'key' mappings=branch_data.dynamic_mapping.get("mappings", {}), default_node=branch_data.dynamic_mapping.get( "default_node", "END" ), ) # Create send mappings if provided send_mappings = [] if branch_data.send_mappings: for mapping_data in branch_data.send_mappings: mapping = SendMapping( condition=mapping_data.get("condition"), node=mapping_data.get("node"), fields=mapping_data.get("fields", {}), ) send_mappings.append(mapping) # Create send generators if provided send_generators = [] if branch_data.send_generators: for generator_data in branch_data.send_generators: generator = SendGenerator( collection_field=generator_data.get("collection_field"), target_node=generator_data.get("target_node"), item_field=generator_data.get("item_field"), ) send_generators.append(generator) # Convert string keys in destinations to proper types # (bool/int) destinations = {} if branch_data.destinations: for k, v in branch_data.destinations.items(): # Try to convert string keys back to original types if k.lower() == "true": destinations[True] = v elif k.lower() == "false": destinations[False] = v else: try: # Try as number num = int(k) destinations[num] = v except ValueError: # Keep as string destinations[k] = v # Create branch with available data branch = Branch( id=branch_data.id, name=branch_data.name, source_node=branch_data.source_node, key=branch_data.key, value=branch_data.value, comparison=comparison, function=function, function_ref=branch_data.function_ref, destinations=destinations, default=branch_data.default or "END", allow_none=branch_data.allow_none or False, message_key=branch_data.message_key or "messages", mode=mode or BranchMode.DIRECT, # Dynamic mapping dynamic_mapping=dynamic_mapping, # Send mapping send_mappings=send_mappings, send_generators=send_generators, # For condition branches condition_ref=branch_data.condition_ref, ) # Store for second pass linking temp_branches[branch_id] = branch # Add to graph graph.branches[branch_id] = branch except Exception as e: logger.exception(f"Error reconstructing branch {branch_id}: {e}") # Second pass to link branches to each other for branch_id, branch_data in self.branches.items(): if branch_id not in temp_branches: continue branch = temp_branches[branch_id] # Link chain branches if branch_data.chain_branches_ids: branch.chain_branches = [] for chain_id in branch_data.chain_branches_ids: if chain_id in temp_branches: branch.chain_branches.append(temp_branches[chain_id]) # Link true/false branches if branch_data.true_branch_id: if branch_data.true_branch_id in temp_branches: branch.true_branch = temp_branches[branch_data.true_branch_id] else: # Might be a node name branch.true_branch = branch_data.true_branch_id if branch_data.false_branch_id: if branch_data.false_branch_id in temp_branches: branch.false_branch = temp_branches[branch_data.false_branch_id] else: # Might be a node name branch.false_branch = branch_data.false_branch_id return graph
[docs] def to_dict(self) -> Any: """Convert to a dictionary for serialization.""" return self.model_dump()
[docs] @classmethod def from_dict(cls, data: dict[str, Any]): """Create from a dictionary.""" return cls.model_validate(data)
# @classmethod
[docs] def to_json(self, **kwargs) -> Any: """Convert to JSON string.""" # First convert to a dict data = self.to_dict() # Ensure all data is JSON serializable (handle functions, etc.) serializable_data = ensure_json_serializable(data) # Convert to JSON string return json.dumps(serializable_data, **kwargs)
[docs] @classmethod def from_json(cls, json_str) -> Any: """Create from JSON string.""" return cls.model_validate_json(json_str)