"""StateSchemaManager for creating and manipulating state schemas.
This module provides the StateSchemaManager class, which offers a low-level API for
dynamically creating, modifying, and managing state schemas at runtime. Unlike the
SchemaComposer, which provides a higher-level builder-style API focused on composition,
the StateSchemaManager offers fine-grained control over schema construction and
modification with support for advanced features like computed properties, validators,
and custom methods.
The StateSchemaManager is particularly useful for:
- Programmatically building schemas with complex interdependencies
- Adding validators, properties, and methods to schemas
- Performing schema transformations and modifications at runtime
- Providing a programmatic interface for schema manipulation
- Creating specialized schema variants with custom behaviors
Key capabilities include:
- Field creation and manipulation with comprehensive type handling
- Support for field sharing, reducers, and engine I/O relationships
- Addition of validators, properties, and computed properties
- Dynamic method addition (instance, class, and static methods)
- Schema finalization with proper metadata configuration
- Integration with SchemaComposer for seamless conversion
Examples:
from haive.core.schema import StateSchemaManager
from typing import List
from langchain_core.messages import BaseMessage
# Create a manager
manager = StateSchemaManager(name="ConversationState")
# Add fields
manager.add_field(
"messages",
List[BaseMessage],
default_factory=list,
description="Conversation history",
shared=True
)
# Add a computed property
def get_last_message(self):
if not self.messages:
return None
return self.messages[-1]
manager.add_computed_property("last_message", get_last_message)
# Add a method
def add_message(self, message):
self.messages.append(message)
manager.add_method("add_message", add_message)
# Build the schema
ConversationState = manager.build()
This module is part of the Haive Schema System, providing the lower-level foundation
for schema manipulation that complements the higher-level SchemaComposer.
"""
from __future__ import annotations
import inspect
import logging
from collections import defaultdict
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar
from pydantic import BaseModel, create_model
from haive.core.schema.field_definition import FieldDefinition
from haive.core.schema.field_utils import (
create_field,
infer_field_type,
resolve_reducer,
)
from haive.core.schema.state_schema import StateSchema
if TYPE_CHECKING:
from haive.core.schema.schema_composer import SchemaComposer
logger = logging.getLogger(__name__)
T = TypeVar("T")
class StateSchemaManager:
"""Manager for dynamically creating and manipulating state schemas.
The StateSchemaManager provides a comprehensive, low-level API for dynamically
creating, modifying, and managing state schemas at runtime. It offers granular
control over schema construction and modification with advanced features like
computed properties, validators, and custom methods.
This class serves as a layer over Pydantic's model creation functionality with
additional features specific to the Haive framework, including field sharing,
reducer functions, and engine I/O tracking. Unlike SchemaComposer, which provides
a higher-level builder-style API focused on composition from components,
StateSchemaManager offers fine-grained control over schema construction.
Key capabilities include:
- Field management: Add, modify, and remove fields with comprehensive type handling
- Field sharing: Configure which fields are shared between parent and child graphs
- Reducer functions: Set up field-specific reducers for state merging
- Engine I/O: Track which fields are inputs and outputs for which engines
- Validators: Add custom validation functions for field values
- Properties and computed properties: Define dynamic properties with getters/setters
- Methods: Add instance, class, and static methods to the schema
- Schema finalization: Build and finalize the schema with proper metadata
This class is particularly useful for programmatically building schemas with
complex interdependencies, adding validation logic, and creating specialized
schema variants with custom behaviors.
"""
def __init__(
self,
data: (
dict[str, Any] | type[BaseModel] | BaseModel | SchemaComposer | None
) = None,
name: str | None = None,
config: dict[str, Any] | None = None,
):
"""Initialize a new StateSchemaManager.
Args:
data: Source data to initialize from
name: Custom schema name
config: Optional configuration dictionary
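
Example (a minimal sketch; ExistingModel stands in for any Pydantic model):

    manager = StateSchemaManager(ExistingModel, name="ExtendedState")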
"""
self.fields: dict[str, tuple[type, Any]] = {}
self.field_definitions: dict[str, FieldDefinition] = (
{}
) # New field definitions store
self.validators: dict[str, Callable] = {}
self.properties: dict[str, Callable] = {}
self.computed_properties: dict[str, tuple[Callable, Callable]] = {}
self.field_descriptions: dict[str, str] = {}
self.instance_methods: dict[str, Callable] = {}
self.class_methods: dict[str, Callable] = {}
self.static_methods: dict[str, Callable] = {}
self.locked = False
self.config = config or {}
# Special field sets for state schema
self._shared_fields: set[str] = set()
self._reducer_names: dict[str, str] = (
{}
) # Field name -> reducer name (serializable)
self._reducer_functions: dict[str, Callable] = (
{}
) # Field name -> reducer function
# Input/output tracking for engines
self._input_fields: dict[str, set[str]] = defaultdict(
set
) # Engine name -> input fields
self._output_fields: dict[str, set[str]] = defaultdict(
set
) # Engine name -> output fields
self._engine_io_mappings: dict[str, dict[str, list[str]]] = (
{}
) # Engine name -> IO mapping
# Set default name
if data is None:
self.name = name or self.config.get("default_schema_name", "UnnamedSchema")
return
# Extract name from data if not provided
if name is None:
if hasattr(data, "name"):
self.name = data.name
elif isinstance(data, type):
self.name = data.__name__
else:
self.name = "CustomState"
else:
self.name = name
# Load data based on type
if isinstance(data, dict):
self._load_from_dict(data)
elif isinstance(data, type) and issubclass(data, BaseModel):
self._load_from_model(data)
elif isinstance(data, BaseModel):
self._load_from_model(data.__class__)
elif hasattr(data, "fields") and (
hasattr(data, "shared_fields") or hasattr(data, "engine_io_mappings")
):
self._load_from_composer(data)
else:
raise TypeError(f"Unsupported data type: {type(data)}")
def _load_from_model(self, model_cls: type[BaseModel]) -> None:
"""Load fields and methods from a Pydantic model.
Args:
model_cls: Pydantic model class to load from
"""
try:
# Extract fields from model_fields
if hasattr(model_cls, "model_fields"):
for field_name, field_info in model_cls.model_fields.items():
# Skip special fields
if field_name.startswith("__") or field_name == "runnable_config":
continue
# Create a field definition
field_def = FieldDefinition.extract_from_model_field(
name=field_name,
field_type=field_info.annotation,
field_info=field_info,
)
# Add to field definitions
self.field_definitions[field_name] = field_def
# Add to fields for backward compatibility
self.fields[field_name] = field_def.to_field_info()
# Set description if available
if field_def.description:
self.field_descriptions[field_name] = field_def.description
# Extract shared fields
if hasattr(model_cls, "__shared_fields__"):
self._shared_fields.update(model_cls.__shared_fields__)
# Update field definition shared flag
for field_name in model_cls.__shared_fields__:
if field_name in self.field_definitions:
self.field_definitions[field_name].shared = True
# Extract reducer information
if hasattr(model_cls, "__serializable_reducers__"):
self._reducer_names.update(model_cls.__serializable_reducers__)
if hasattr(model_cls, "__reducer_fields__"):
self._reducer_functions.update(model_cls.__reducer_fields__)
# Update field definition reducers
for field_name, reducer in model_cls.__reducer_fields__.items():
if field_name in self.field_definitions:
self.field_definitions[field_name].reducer = reducer
# Extract engine I/O mappings
if hasattr(model_cls, "__engine_io_mappings__"):
for engine_name, mapping in model_cls.__engine_io_mappings__.items():
self._engine_io_mappings[engine_name] = mapping.copy()
# Extract input fields
for field_name in mapping.get("inputs", []):
self._input_fields[engine_name].add(field_name)
# Update field definition
if field_name in self.field_definitions and (
engine_name
not in self.field_definitions[field_name].input_for
):
self.field_definitions[field_name].input_for.append(
engine_name
)
# Extract output fields
for field_name in mapping.get("outputs", []):
self._output_fields[engine_name].add(field_name)
# Update field definition
if field_name in self.field_definitions and (
engine_name
not in self.field_definitions[field_name].output_from
):
self.field_definitions[field_name].output_from.append(
engine_name
)
# Extract input/output fields (alternative format)
if hasattr(model_cls, "__input_fields__"):
for engine_name, fields_list in model_cls.__input_fields__.items():
self._input_fields[engine_name].update(fields_list)
# Update field definitions
for field_name in fields_list:
if field_name in self.field_definitions and (
engine_name
not in self.field_definitions[field_name].input_for
):
self.field_definitions[field_name].input_for.append(
engine_name
)
if hasattr(model_cls, "__output_fields__"):
for engine_name, fields_list in model_cls.__output_fields__.items():
self._output_fields[engine_name].update(fields_list)
# Update field definitions
for field_name in fields_list:
if field_name in self.field_definitions and (
engine_name
not in self.field_definitions[field_name].output_from
):
self.field_definitions[field_name].output_from.append(
engine_name
)
# Load methods and validators
for name, attr in inspect.getmembers(model_cls):
# Skip private/special methods
if name.startswith("_") and name != "__validator__":
continue
if isinstance(attr, property):
# Handle property
getter = attr.fget
setter = attr.fset
if getter and setter:
self.computed_properties[name] = (getter, setter)
elif getter:
self.properties[name] = getter
elif inspect.ismethod(attr) or inspect.isfunction(attr):
# Handle methods
if name.startswith("validate_") or getattr(
attr, "__validator__", False
):
self.validators[name] = attr
except Exception as e:
logger.exception(f"Error loading from model {model_cls.__name__}: {e}")
# Add a placeholder field as fallback
self.fields["placeholder"] = (str, create_field(str, "")[1])
self.field_definitions["placeholder"] = FieldDefinition(
name="placeholder", field_type=str, default=""
)
def _load_from_composer(self, composer: Any) -> None:
"""Load fields and metadata from a SchemaComposer.
Args:
composer: SchemaComposer instance
"""
try:
# Add field definitions directly
if hasattr(composer, "fields") and isinstance(composer.fields, dict):
for field_name, field_def in composer.fields.items():
# Add to field definitions
self.field_definitions[field_name] = field_def
# Add to fields for backward compatibility
self.fields[field_name] = field_def.to_field_info()
# Set description if available
if field_def.description:
self.field_descriptions[field_name] = field_def.description
# Set shared flag
if field_def.shared:
self._shared_fields.add(field_name)
# Set reducer
if field_def.reducer:
self._reducer_functions[field_name] = field_def.reducer
reducer_name = field_def.get_reducer_name()
if reducer_name:
self._reducer_names[field_name] = reducer_name
# Add shared fields
if hasattr(composer, "shared_fields"):
self._shared_fields.update(composer.shared_fields)
# Add engine I/O mappings
if hasattr(composer, "engine_io_mappings"):
for engine_name, mapping in composer.engine_io_mappings.items():
self._engine_io_mappings[engine_name] = mapping.copy()
# Add input/output fields
if hasattr(composer, "input_fields"):
for engine_name, fields in composer.input_fields.items():
self._input_fields[engine_name].update(fields)
if hasattr(composer, "output_fields"):
for engine_name, fields in composer.output_fields.items():
self._output_fields[engine_name].update(fields)
except Exception as e:
logger.exception(f"Error loading from composer: {e}")
# Add a placeholder field as fallback
self.fields["placeholder"] = (str, create_field(str, "")[1])
self.field_definitions["placeholder"] = FieldDefinition(
name="placeholder", field_type=str, default=""
)
def _load_from_dict(self, data: dict[str, Any]) -> None:
"""Load fields from a dictionary.
Args:
data: Dictionary containing field data
"""
try:
# Handle special metadata keys
if "shared_fields" in data:
self._shared_fields.update(data["shared_fields"])
if "reducer_names" in data or "serializable_reducers" in data:
reducer_names = data.get(
"reducer_names", data.get("serializable_reducers", {})
)
self._reducer_names.update(reducer_names)
if "reducer_functions" in data:
self._reducer_functions.update(data["reducer_functions"])
if "field_descriptions" in data:
self.field_descriptions.update(data["field_descriptions"])
if "engine_io_mappings" in data:
for engine_name, mapping in data["engine_io_mappings"].items():
self._engine_io_mappings[engine_name] = mapping.copy()
if "input_fields" in data:
for engine_name, fields in data["input_fields"].items():
self._input_fields[engine_name].update(fields)
if "output_fields" in data:
for engine_name, fields in data["output_fields"].items():
self._output_fields[engine_name].update(fields)
# Process field definitions
for key, value in data.items():
# Skip special metadata keys
if key in [
"shared_fields",
"reducer_names",
"serializable_reducers",
"reducer_functions",
"field_descriptions",
"engine_io_mappings",
"input_fields",
"output_fields",
]:
continue
# Handle different field formats
if isinstance(value, tuple) and len(value) >= 2:
# Handle (type, default) format
field_type, default = value[0:2]
# Check for extra metadata
extra = {}
if len(value) >= 3 and isinstance(value[2], dict):
extra = value[2]
# Extract metadata
description = extra.pop("description", None)
shared = extra.pop("shared", False) or key in self._shared_fields
reducer = None
if "reducer" in extra:
reducer_value = extra.pop("reducer")
if callable(reducer_value):
reducer = reducer_value
elif isinstance(reducer_value, str):
reducer = resolve_reducer(reducer_value)
elif key in self._reducer_functions:
reducer = self._reducer_functions[key]
# Check if default is a factory function
default_factory = None
if callable(default) and not isinstance(default, type):
default_factory = default
default = None
# Create field definition
field_def = FieldDefinition(
name=key,
field_type=field_type,
default=default,
default_factory=default_factory,
description=description,
shared=shared,
reducer=reducer,
**extra,
)
# Add field
self.field_definitions[key] = field_def
self.fields[key] = field_def.to_field_info()
# Set description
if description:
self.field_descriptions[key] = description
else:
# Field with value only - infer type
field_type = infer_field_type(value)
# Create field definition
field_def = FieldDefinition(
name=key,
field_type=field_type,
default=value,
shared=key in self._shared_fields,
reducer=self._reducer_functions.get(key),
)
# Add field
self.field_definitions[key] = field_def
self.fields[key] = field_def.to_field_info()
except Exception as e:
logger.exception(f"Error loading from dict: {e}")
# Add a placeholder field as fallback
self.fields["placeholder"] = (str, create_field(str, "")[1])
self.field_definitions["placeholder"] = FieldDefinition(
name="placeholder", field_type=str, default=""
)
def add_field(
self,
name: str,
field_type: type[T],
default: Any = None,
default_factory: Callable[[], T] | None = None,
description: str | None = None,
shared: bool = False,
reducer: Callable | None = None,
input_for: list[str] | None = None,
output_from: list[str] | None = None,
optional: bool = True,
**kwargs,
) -> StateSchemaManager:
"""Add a field to the schema with comprehensive options.
Args:
name: Field name
field_type: Type of the field
default: Default value for the field
default_factory: Optional factory function for default value
description: Optional field description
shared: Whether field is shared with parent graph
reducer: Optional reducer function for this field
input_for: List of engines this field serves as input for
output_from: List of engines this field is output from
optional: Whether to make the field optional (default: True)
**kwargs: Additional field parameters
Returns:
Self for chaining
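
Example (a minimal sketch; the "llm" engine name is illustrative):

    manager.add_field(
        "messages",
        List[BaseMessage],
        default_factory=list,
        description="Conversation history",
        shared=True,
        input_for=["llm"],
        output_from=["llm"],
    )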
"""
if self.locked:
raise ValueError("Schema is locked and cannot be modified.")
# Create field definition
field_def = FieldDefinition(
name=name,
field_type=field_type,
default=default,
default_factory=default_factory,
description=description,
shared=shared,
reducer=reducer,
input_for=input_for,
output_from=output_from,
**kwargs,
)
# Store the field definition
self.field_definitions[name] = field_def
# Store the field info for backward compatibility
self.fields[name] = field_def.to_field_info()
# Track additional metadata
if description:
self.field_descriptions[name] = description
if shared:
self._shared_fields.add(name)
if reducer:
# Store reducer function
self._reducer_functions[name] = reducer
# Store serializable reducer name
reducer_name = field_def.get_reducer_name()
if reducer_name:
self._reducer_names[name] = reducer_name
# Track engine I/O
if input_for:
for engine_name in input_for:
self._input_fields[engine_name].add(name)
# Update engine I/O mapping
if engine_name not in self._engine_io_mappings:
self._engine_io_mappings[engine_name] = {
"inputs": [],
"outputs": [],
}
if name not in self._engine_io_mappings[engine_name]["inputs"]:
self._engine_io_mappings[engine_name]["inputs"].append(name)
if output_from:
for engine_name in output_from:
self._output_fields[engine_name].add(name)
# Update engine I/O mapping
if engine_name not in self._engine_io_mappings:
self._engine_io_mappings[engine_name] = {
"inputs": [],
"outputs": [],
}
if name not in self._engine_io_mappings[engine_name]["outputs"]:
self._engine_io_mappings[engine_name]["outputs"].append(name)
return self
def remove_field(self, name: str) -> StateSchemaManager:
"""Remove a field from the schema.
Args:
name: Name of the field to remove
Returns:
Self for chaining
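
Example (a minimal sketch):

    manager.remove_field("placeholder")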
"""
if self.locked:
raise ValueError("Schema is locked and cannot be modified.")
# Remove from fields
if name in self.fields:
del self.fields[name]
# Remove from field definitions
if name in self.field_definitions:
del self.field_definitions[name]
# Remove from metadata
if name in self.field_descriptions:
del self.field_descriptions[name]
if name in self._shared_fields:
self._shared_fields.remove(name)
if name in self._reducer_names:
del self._reducer_names[name]
if name in self._reducer_functions:
del self._reducer_functions[name]
# Remove from engine I/O
for engine_name, fields in self._input_fields.items():
if name in fields:
fields.remove(name)
# Update mapping
if (
engine_name in self._engine_io_mappings
and name in self._engine_io_mappings[engine_name]["inputs"]
):
self._engine_io_mappings[engine_name]["inputs"].remove(name)
for engine_name, fields in self._output_fields.items():
if name in fields:
fields.remove(name)
# Update mapping
if (
engine_name in self._engine_io_mappings
and name in self._engine_io_mappings[engine_name]["outputs"]
):
self._engine_io_mappings[engine_name]["outputs"].remove(name)
return self
def modify_field(
self,
name: str,
new_type: type | None = None,
new_default: Any = None,
new_default_factory: Callable | None = None,
new_description: str | None = None,
new_shared: bool | None = None,
new_reducer: Callable | None = None,
add_input_for: list[str] | None = None,
add_output_from: list[str] | None = None,
remove_input_for: list[str] | None = None,
remove_output_from: list[str] | None = None,
**kwargs,
) -> StateSchemaManager:
"""Modify an existing field's properties.
Args:
name: Name of the field to modify
new_type: New type for the field
new_default: New default value
new_default_factory: New default factory function
new_description: New description
new_shared: Whether field should be shared
new_reducer: New reducer function
add_input_for: List of engines to add as input consumers
add_output_from: List of engines to add as output producers
remove_input_for: List of engines to remove as input consumers
remove_output_from: List of engines to remove as output producers
**kwargs: Additional field parameters
Returns:
Self for chaining
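
Example (a minimal sketch; the "retriever" engine name is illustrative):

    manager.modify_field(
        "messages",
        new_description="Full conversation history",
        new_shared=True,
        add_input_for=["retriever"],
    )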
"""
if self.locked:
raise ValueError("Schema is locked and cannot be modified.")
# Check if field exists
if name not in self.field_definitions:
logger.warning(f"Field {name} does not exist, cannot be modified")
return self
# Get current field definition
field_def = self.field_definitions[name]
# Update type if provided
if new_type is not None:
field_def.field_type = new_type
# Update default or default_factory if provided
if new_default_factory is not None:
field_def = FieldDefinition(
name=name,
field_type=field_def.field_type,
default_factory=new_default_factory,
description=field_def.description,
shared=field_def.shared,
reducer=field_def.reducer,
input_for=field_def.input_for,
output_from=field_def.output_from,
)
elif new_default is not None:
field_def = FieldDefinition(
name=name,
field_type=field_def.field_type,
default=new_default,
description=field_def.description,
shared=field_def.shared,
reducer=field_def.reducer,
input_for=field_def.input_for,
output_from=field_def.output_from,
)
# Update description if provided
if new_description is not None:
field_def.description = new_description
self.field_descriptions[name] = new_description
# Update shared status if specified
if new_shared is not None:
field_def.shared = new_shared
if new_shared:
self._shared_fields.add(name)
elif name in self._shared_fields:
self._shared_fields.remove(name)
# Update reducer if provided
if new_reducer is not None:
field_def.reducer = new_reducer
self._reducer_functions[name] = new_reducer
# Update reducer name
reducer_name = field_def.get_reducer_name()
if reducer_name:
self._reducer_names[name] = reducer_name
elif name in self._reducer_names:
del self._reducer_names[name]
# Update engine I/O
if add_input_for:
for engine_name in add_input_for:
if engine_name not in field_def.input_for:
field_def.input_for.append(engine_name)
self._input_fields[engine_name].add(name)
# Update mapping
if engine_name not in self._engine_io_mappings:
self._engine_io_mappings[engine_name] = {
"inputs": [],
"outputs": [],
}
if name not in self._engine_io_mappings[engine_name]["inputs"]:
self._engine_io_mappings[engine_name]["inputs"].append(name)
if add_output_from:
for engine_name in add_output_from:
if engine_name not in field_def.output_from:
field_def.output_from.append(engine_name)
self._output_fields[engine_name].add(name)
# Update mapping
if engine_name not in self._engine_io_mappings:
self._engine_io_mappings[engine_name] = {
"inputs": [],
"outputs": [],
}
if name not in self._engine_io_mappings[engine_name]["outputs"]:
self._engine_io_mappings[engine_name]["outputs"].append(name)
if remove_input_for:
for engine_name in remove_input_for:
if engine_name in field_def.input_for:
field_def.input_for.remove(engine_name)
if name in self._input_fields[engine_name]:
self._input_fields[engine_name].remove(name)
# Update mapping
if (
engine_name in self._engine_io_mappings
and name in self._engine_io_mappings[engine_name]["inputs"]
):
self._engine_io_mappings[engine_name]["inputs"].remove(name)
if remove_output_from:
for engine_name in remove_output_from:
if engine_name in field_def.output_from:
field_def.output_from.remove(engine_name)
if name in self._output_fields[engine_name]:
self._output_fields[engine_name].remove(name)
# Update mapping
if (
engine_name in self._engine_io_mappings
and name in self._engine_io_mappings[engine_name]["outputs"]
):
self._engine_io_mappings[engine_name]["outputs"].remove(name)
# Update additional parameters
if kwargs:
# Create a new field definition with updated parameters
field_info_kwargs = field_def._extract_field_kwargs()
field_info_kwargs.update(kwargs)
# Create new field definition
if field_def.default_factory is not None:
field_def = FieldDefinition(
name=name,
field_type=field_def.field_type,
default_factory=field_def.default_factory,
description=field_def.description,
shared=field_def.shared,
reducer=field_def.reducer,
input_for=field_def.input_for,
output_from=field_def.output_from,
**field_info_kwargs,
)
else:
field_def = FieldDefinition(
name=name,
field_type=field_def.field_type,
default=field_def.default,
description=field_def.description,
shared=field_def.shared,
reducer=field_def.reducer,
input_for=field_def.input_for,
output_from=field_def.output_from,
**field_info_kwargs,
)
# Update field definition
self.field_definitions[name] = field_def
# Update field for backward compatibility
self.fields[name] = field_def.to_field_info()
return self
def has_field(self, name: str) -> bool:
"""Check if the schema has a specific field.
Args:
name: Field name to check
Returns:
True if field exists, False otherwise
"""
return name in self.field_definitions
def get_model(
self,
lock: bool = False,
as_state_schema: bool = True,
name: str | None = None,
use_annotated: bool = True,
) -> type[BaseModel]:
"""Create a Pydantic model with all configured options.
Args:
lock: Whether to lock the schema against further modifications
as_state_schema: Whether to use StateSchema as the base class
name: Optional name for the schema class
use_annotated: Whether to use Annotated types for metadata
Returns:
Created model class
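
Example (a minimal sketch):

    # Build a StateSchema subclass and lock against further edits
    ConversationState = manager.get_model(lock=True)
    # Instantiation with no arguments assumes every field has a default
    state = ConversationState()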
"""
if lock:
self.locked = True
# Use provided name or default
model_name = name or self.name
# Choose the base class
base_class = StateSchema if as_state_schema else BaseModel
# Create field definitions for the model
field_defs = {}
for name, field_def in self.field_definitions.items():
if use_annotated:
field_type, field_info = field_def.to_annotated_field()
else:
field_type, field_info = field_def.to_field_info()
field_defs[name] = (field_type, field_info)
# Create the model with fields
model = create_model(model_name, __base__=base_class, **field_defs)
# Add shared fields metadata if using StateSchema
if as_state_schema:
model.__shared_fields__ = list(self._shared_fields)
# Add reducer metadata
model.__serializable_reducers__ = dict(self._reducer_names)
# Add actual reducer functions
if not hasattr(model, "__reducer_fields__"):
model.__reducer_fields__ = {}
model.__reducer_fields__.update(self._reducer_functions)
# Add engine I/O mappings metadata if available
if self._engine_io_mappings:
model.__engine_io_mappings__ = {
k: v.copy() for k, v in self._engine_io_mappings.items()
}
# Add input/output field tracking
if self._input_fields:
model.__input_fields__ = {
k: list(v) for k, v in self._input_fields.items()
}
if self._output_fields:
model.__output_fields__ = {
k: list(v) for k, v in self._output_fields.items()
}
# Add validators
for name, validator in self.validators.items():
setattr(model, name, validator)
# Add properties
for prop_name, (getter, setter) in self.computed_properties.items():
prop = property(getter, setter)
setattr(model, prop_name, prop)
# Add regular properties
for prop_name, getter in self.properties.items():
prop = property(getter)
setattr(model, prop_name, prop)
# Add instance methods
for name, method in self.instance_methods.items():
setattr(model, name, method)
# Add class methods
for name, method in self.class_methods.items():
setattr(model, name, classmethod(method))
# Add static methods
for name, method in self.static_methods.items():
setattr(model, name, staticmethod(method))
return model
def mark_as_output_field(
self, field_name: str, engine_name: str
) -> StateSchemaManager:
"""Mark a field as an output field for an engine.
Args:
field_name: Name of the field
engine_name: Name of the engine
Returns:
Self for chaining
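
Example (a minimal sketch; "summary" and "summarizer" are illustrative
field and engine names):

    manager.mark_as_output_field("summary", "summarizer")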
"""
if field_name in self.field_definitions:
# Update field definition
if engine_name not in self.field_definitions[field_name].output_from:
self.field_definitions[field_name].output_from.append(engine_name)
# Update output fields tracking
self._output_fields[engine_name].add(field_name)
# Update engine I/O mapping
if engine_name not in self._engine_io_mappings:
self._engine_io_mappings[engine_name] = {"inputs": [], "outputs": []}
if field_name not in self._engine_io_mappings[engine_name]["outputs"]:
self._engine_io_mappings[engine_name]["outputs"].append(field_name)
return self
def _update_engine_io_mapping(self, engine_name: str) -> None:
"""Update engine I/O mapping for a specific engine.
Args:
engine_name: Engine name to update mapping for
"""
# Create mapping if it doesn't exist
if engine_name not in self._engine_io_mappings:
self._engine_io_mappings[engine_name] = {"inputs": [], "outputs": []}
# Update inputs from tracked fields
if engine_name in self._input_fields:
self._engine_io_mappings[engine_name]["inputs"] = list(
self._input_fields[engine_name]
)
# Update outputs from tracked fields
if engine_name in self._output_fields:
self._engine_io_mappings[engine_name]["outputs"] = list(
self._output_fields[engine_name]
)
def to_composer(self) -> SchemaComposer:
"""Convert to a SchemaComposer instance.
Returns:
SchemaComposer with current fields and metadata
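
Example (a minimal sketch):

    composer = manager.to_composer()
    # Round-trip back to a manager if needed
    manager = composer.to_manager()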
"""
from haive.core.schema.schema_composer import SchemaComposer
# Create a new composer
composer = SchemaComposer(name=self.name)
# Add field definitions
for _name, field_def in self.field_definitions.items():
composer.add_field_definition(field_def)
return composer
def merge(
self,
other: StateSchemaManager | type[BaseModel] | BaseModel | SchemaComposer,
) -> StateSchemaManager:
"""Merge with another schema, preserving first occurrences.
Args:
other: Another object to merge with
Returns:
New merged StateSchemaManager
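
Example (a minimal sketch; OtherState stands in for any Pydantic model,
StateSchemaManager, or SchemaComposer):

    merged = manager.merge(OtherState)
    MergedState = merged.get_model()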
"""
# Create a new manager with the same name
merged = StateSchemaManager(name=f"{self.name}_merged")
# Copy all fields and metadata from self
for name, field_def in self.field_definitions.items():
merged.field_definitions[name] = field_def
merged.fields[name] = field_def.to_field_info()
if field_def.description:
merged.field_descriptions[name] = field_def.description
if field_def.shared:
merged._shared_fields.add(name)
if field_def.reducer:
merged._reducer_functions[name] = field_def.reducer
reducer_name = field_def.get_reducer_name()
if reducer_name:
merged._reducer_names[name] = reducer_name
for engine_name in field_def.input_for:
merged._input_fields[engine_name].add(name)
for engine_name in field_def.output_from:
merged._output_fields[engine_name].add(name)
# Copy engine I/O mappings
for engine_name, mapping in self._engine_io_mappings.items():
merged._engine_io_mappings[engine_name] = mapping.copy()
# Copy methods
merged.validators = self.validators.copy()
merged.properties = self.properties.copy()
merged.computed_properties = self.computed_properties.copy()
merged.instance_methods = self.instance_methods.copy()
merged.class_methods = self.class_methods.copy()
merged.static_methods = self.static_methods.copy()
# Convert other to StateSchemaManager if needed
other_manager = other
if not isinstance(other, StateSchemaManager):
other_manager = StateSchemaManager(other)
# Add fields from other (skip if already exists)
for name, field_def in other_manager.field_definitions.items():
if name not in merged.field_definitions:
merged.field_definitions[name] = field_def
merged.fields[name] = field_def.to_field_info()
if field_def.description:
merged.field_descriptions[name] = field_def.description
if field_def.shared:
merged._shared_fields.add(name)
if field_def.reducer:
merged._reducer_functions[name] = field_def.reducer
reducer_name = field_def.get_reducer_name()
if reducer_name:
merged._reducer_names[name] = reducer_name
for engine_name in field_def.input_for:
merged._input_fields[engine_name].add(name)
for engine_name in field_def.output_from:
merged._output_fields[engine_name].add(name)
# Add engine I/O mappings (don't overwrite existing)
for engine_name, mapping in other_manager._engine_io_mappings.items():
if engine_name not in merged._engine_io_mappings:
merged._engine_io_mappings[engine_name] = mapping.copy()
# Add validators (don't overwrite existing)
for name, validator in other_manager.validators.items():
if name not in merged.validators:
merged.validators[name] = validator
# Add properties (don't overwrite existing)
for name, prop in other_manager.properties.items():
if name not in merged.properties:
merged.properties[name] = prop
# Add computed properties (don't overwrite existing)
for name, (getter, setter) in other_manager.computed_properties.items():
if name not in merged.computed_properties:
merged.computed_properties[name] = (getter, setter)
# Add methods (don't overwrite existing)
for name, method in other_manager.instance_methods.items():
if name not in merged.instance_methods:
merged.instance_methods[name] = method
for name, method in other_manager.class_methods.items():
if name not in merged.class_methods:
merged.class_methods[name] = method
for name, method in other_manager.static_methods.items():
if name not in merged.static_methods:
merged.static_methods[name] = method
# Update engine I/O mappings for consistency
for engine_name in merged._input_fields.keys() | merged._output_fields.keys():
merged._update_engine_io_mapping(engine_name)
return merged
def add_method(
self,
method: Callable,
method_name: str | None = None,
method_type: str = "instance",
) -> StateSchemaManager:
"""Add a method to the schema.
Args:
method: Method callable to add
method_name: Optional method name (defaults to method.__name__)
method_type: Type of method: "instance", "class", or "static"
Returns:
Self for chaining
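
Example (a minimal sketch):

    def add_message(self, message):
        self.messages.append(message)

    manager.add_method(add_message)  # registered as "add_message"

    def from_messages(cls, messages):
        return cls(messages=messages)

    # Wrapped with classmethod() when the model is built
    manager.add_method(from_messages, method_type="class")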
"""
if self.locked:
raise ValueError("Schema is locked and cannot be modified.")
name = method_name or method.__name__
if method_type == "instance":
self.instance_methods[name] = method
elif method_type == "class":
self.class_methods[name] = method
elif method_type == "static":
self.static_methods[name] = method
else:
raise ValueError(f"Unknown method type: {method_type}")
return self
@classmethod
def from_components(
cls,
components: list[Any],
name: str = "ComponentSchema",
include_messages_field: bool = True,
) -> StateSchemaManager:
"""Create a schema manager from a list of components.
Args:
components: List of components to extract fields from
name: Name for the resulting schema
include_messages_field: Whether to ensure a messages field exists
Returns:
StateSchemaManager with extracted fields
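
Example (a minimal sketch; llm_engine and retriever_engine are
illustrative component objects):

    manager = StateSchemaManager.from_components(
        [llm_engine, retriever_engine],
        name="AgentState",
    )
    AgentState = manager.get_model()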
"""
from haive.core.schema.schema_composer import SchemaComposer
# Create a composer and extract fields
composer = SchemaComposer(name=name)
composer.add_fields_from_components(
components, include_messages_field=include_messages_field
)
# Convert to manager
return composer.to_manager()
@classmethod
def create_message_state(
cls,
additional_fields: dict[str, Any] | None = None,
name: str = "MessageState",
) -> type[StateSchema]:
"""Create a schema with messages field and additional fields.
Args:
additional_fields: Optional dictionary of additional fields to add
name: Name for the schema
Returns:
StateSchema subclass with messages field
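
Example (a minimal sketch; the (type, default) tuple format for
additional_fields is assumed to match this module's dictionary
convention):

    QAState = StateSchemaManager.create_message_state(
        additional_fields={"query": (str, "")},
        name="QAState",
    )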
"""
from haive.core.schema.schema_composer import SchemaComposer
# Use SchemaComposer to create the schema
return SchemaComposer.create_message_state(
additional_fields=additional_fields, name=name
)