Source code for haive.core.schema.utils

"""Utility functions for schema manipulation in the Haive framework.

This module provides the SchemaUtils class containing static methods for working with
schemas in the Haive Schema System. It includes utilities for formatting type
annotations, extracting field information, creating field definitions, and building
schemas programmatically.

The utilities in this module are primarily used by the SchemaComposer and
StateSchemaManager classes, but they can also be useful for custom schema
manipulation tasks or when working with schemas directly.

Key capabilities include:
- Type annotation formatting for readable representation of complex types
- Field information extraction from Pydantic field objects
- Pydantic field creation with proper metadata
- Support for special types like Optional, Union, and generics
- Helper functions for schema display and debug visualization

Examples:
    from haive.core.schema.utils import SchemaUtils
    from typing import List, Optional

    # Format a type annotation
    type_str = SchemaUtils.format_type_annotation(List[Optional[str]])
    print(type_str)  # "List[Optional[str]]"

    # Extract field info from a Pydantic model
    from pydantic import BaseModel, Field

    class MyModel(BaseModel):
        name: str = Field(default="default", description="User name")

    field_info = MyModel.model_fields["name"]
    default, default_repr, desc = SchemaUtils.extract_field_info(field_info)
    # default = "default", desc = "User name"
"""

import logging
from collections.abc import Callable
from typing import Any, TypeVar

from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined

# Type variables
T = TypeVar("T", bound=BaseModel)
FieldType = TypeVar("FieldType")
DefaultType = TypeVar("DefaultType")

# Logger setup
logger = logging.getLogger(__name__)


class SchemaUtils:
    """Utility functions for schema manipulation and formatting.

    This class provides static methods for working with schema-related tasks such as
    formatting type annotations, extracting field information, and building state
    schemas from components. The methods focus on common operations needed when
    programmatically working with schemas and Pydantic models in the Haive framework.

    Key methods include:

    - format_type_annotation: Creates readable string representations of type annotations
    - extract_field_info: Extracts default values and descriptions from Pydantic fields
    - format_schema_as_python: Renders a schema definition as readable Python code
    - build_state_schema: Builds a state schema class from field definitions
    - add_field_to_schema: Adds a field, with optional shared/reducer metadata, to an existing schema
    - get_reducer_name: Produces a serializable name for a reducer function

    These utilities are designed to work with both Pydantic v1 and v2, handling
    differences in field structures and model internals. They're used extensively
    throughout the Haive Schema System to provide consistent handling of types and
    field definitions.
    """
    @staticmethod
    def format_type_annotation(type_hint: Any) -> str:
        """Format a type hint for pretty printing.

        Creates a clean, readable string representation of a type annotation.

        Args:
            type_hint: The type hint to format.

        Returns:
            A clean string representation of the type.
        """
        # Handle primitive types
        if type_hint is str:
            return "str"
        if type_hint is int:
            return "int"
        if type_hint is float:
            return "float"
        if type_hint is bool:
            return "bool"
        if type_hint is list:
            return "list"
        if type_hint is dict:
            return "dict"
        if type_hint is None or type_hint is type(None):
            return "None"

        # Handle typing annotations
        type_str = str(type_hint)

        # Remove 'typing.' prefix
        type_str = type_str.replace("typing.", "")

        # Handle special case of Optional and Union
        if type_str.startswith(("Optional[", "Union[")):
            return type_str

        # Handle class representations such as "<class 'pkg.Type'>"
        if type_str.startswith("<class '"):
            # Extract the actual type name from <class 'type'>
            type_name = type_str.removeprefix("<class '").removesuffix("'>")
            # Handle built-in types
            if type_name.startswith("builtins."):
                return type_name[len("builtins."):]  # Remove the builtins. prefix
            return type_name

        return type_str
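    # Usage sketch for format_type_annotation (illustrative only; assumes
    # `Optional` has been imported from typing):
    #
    #     >>> SchemaUtils.format_type_annotation(str)
    #     'str'
    #     >>> SchemaUtils.format_type_annotation(Optional[int])
    #     'Optional[int]'
    #     >>> SchemaUtils.format_type_annotation(dict[str, int])
    #     'dict[str, int]'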
    @staticmethod
    def extract_field_info(field_info: FieldInfo) -> tuple[Any, str, str | None]:
        """Extract useful information from a Pydantic FieldInfo.

        Args:
            field_info: Pydantic field info object.

        Returns:
            Tuple of (default_value, default_string_representation, description).
        """
        description = getattr(field_info, "description", None)

        # Handle default_factory
        if (
            hasattr(field_info, "default_factory")
            and field_info.default_factory is not None
        ):
            factory = field_info.default_factory
            if factory is list:
                default_str = " = Field(default_factory=list)"
            elif hasattr(factory, "__name__"):
                default_str = f" = Field(default_factory={factory.__name__})"
            else:
                default_str = f" = Field(default_factory=lambda: {factory()})"
            return factory, default_str, description

        # Handle normal default (required fields carry Ellipsis or PydanticUndefined)
        if (
            hasattr(field_info, "default")
            and field_info.default is not ...
            and field_info.default is not PydanticUndefined
        ):
            default = field_info.default
            if default is None:
                default_str = " = None"
            elif isinstance(default, str):
                default_str = f' = "{default}"'
            elif isinstance(default, int | float | bool):
                default_str = f" = {default}"
            else:
                default_str = f" = {default!r}"
            return default, default_str, description

        # Handle required field
        return ..., "", description
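    # Usage sketch for extract_field_info; `User` and `Config` are throwaway
    # models defined only for this example, using the BaseModel/Field imports
    # at the top of this module:
    #
    #     >>> class User(BaseModel):
    #     ...     name: str = Field(default="guest", description="Display name")
    #     >>> SchemaUtils.extract_field_info(User.model_fields["name"])
    #     ('guest', ' = "guest"', 'Display name')
    #     >>> class Config(BaseModel):
    #     ...     tags: list[str] = Field(default_factory=list)
    #     >>> SchemaUtils.extract_field_info(Config.model_fields["tags"])[1]
    #     ' = Field(default_factory=list)'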
    @staticmethod
    def format_schema_as_python(
        schema_name: str,
        fields: dict[str, tuple[Any, FieldInfo]],
        properties: dict[str, Any] | None = None,
        computed_properties: dict[str, Any] | None = None,
        class_methods: dict[str, Any] | None = None,
        static_methods: dict[str, Any] | None = None,
        field_descriptions: dict[str, str] | None = None,
        shared_fields: set[str] | None = None,
        reducer_fields: dict[str, Callable] | None = None,
        base_class: str = "StateSchema",
    ) -> str:
        """Format a schema definition as Python code.

        Creates a string representation of a schema class with all its fields,
        properties, methods, and metadata.

        Args:
            schema_name: Name of the schema class.
            fields: Dictionary of field names to (type, field_info) tuples.
            properties: Optional dictionary of property names to property methods.
            computed_properties: Optional dictionary of computed property definitions.
            class_methods: Optional dictionary of class method names to methods.
            static_methods: Optional dictionary of static method names to methods.
            field_descriptions: Optional dictionary of field descriptions.
            shared_fields: Optional set of field names that are shared with parent.
            reducer_fields: Optional dictionary of fields with reducer functions.
            base_class: Base class name for the schema.

        Returns:
            String containing the Python code representation.
        """
        properties = properties or {}
        computed_properties = computed_properties or {}
        class_methods = class_methods or {}
        static_methods = static_methods or {}
        field_descriptions = field_descriptions or {}
        shared_fields = shared_fields or set()
        reducer_fields = reducer_fields or {}

        # Generate the schema representation
        output = []
        output.append(f"class {schema_name}({base_class}):")

        if not any(
            [fields, properties, computed_properties, class_methods, static_methods]
        ):
            output.append("    pass  # No fields defined\n")
            return "\n".join(output)

        # Add field definitions
        for field_name, (field_type, field_info) in fields.items():
            # Format the type string
            type_str = SchemaUtils.format_type_annotation(field_type)

            # Extract default info
            _, default_str, _ = SchemaUtils.extract_field_info(field_info)

            # Add description as comment if present
            description = field_descriptions.get(field_name, "")
            if description:
                output.append(f"    # {description}")

            # Add special annotations for shared or reducer fields
            annotations = []
            if field_name in shared_fields:
                annotations.append("shared")
            if field_name in reducer_fields:
                reducer_name = getattr(
                    reducer_fields[field_name], "__name__", "reducer"
                )
                annotations.append(f"reducer={reducer_name}")
            if annotations:
                output.append(f"    # {', '.join(annotations)}")

            # Add the field definition
            output.append(f"    {field_name}: {type_str}{default_str}")

        # Add properties
        for prop_name in properties:
            output.append("\n    @property")
            output.append(f"    def {prop_name}(self): ...")

        # Add computed properties
        for prop_name, (_getter, setter) in computed_properties.items():
            output.append("\n    @property")
            output.append(f"    def {prop_name}(self): ...")
            if setter:
                output.append(f"    @{prop_name}.setter")
                output.append(f"    def {prop_name}(self, value): ...")

        # Add class methods
        for method_name in class_methods:
            output.append("\n    @classmethod")
            output.append(f"    def {method_name}(cls): ...")

        # Add static methods
        for method_name in static_methods:
            output.append("\n    @staticmethod")
            output.append(f"    def {method_name}(): ...")

        # Return the output string
        return "\n".join(output)
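    # Usage sketch for format_schema_as_python; "SearchState" and "query" are
    # made-up names for illustration:
    #
    #     >>> fields = {"query": (str, Field(default="", description="Search query"))}
    #     >>> print(SchemaUtils.format_schema_as_python(
    #     ...     "SearchState", fields, field_descriptions={"query": "Search query"}
    #     ... ))
    #     class SearchState(StateSchema):
    #         # Search query
    #         query: str = ""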
    @staticmethod
    def build_state_schema(
        name: str,
        fields: dict[str, tuple[type, Any]],
        shared_fields: list[str] | None = None,
        reducers: dict[str, Callable] | None = None,
        base_class: type[BaseModel] | None = None,
    ) -> type[BaseModel]:
        """Build a state schema from field definitions.

        Args:
            name: Name for the schema class.
            fields: Dictionary mapping field names to (type, default) tuples.
            shared_fields: Optional list of fields shared with parent.
            reducers: Optional dictionary mapping field names to reducer functions.
            base_class: Optional base class (defaults to StateSchema).

        Returns:
            A new schema class.
        """
        # Import StateSchema if no base class provided
        if base_class is None:
            from haive.core.schema.state_schema import StateSchema

            base_class = StateSchema

        # Create the model with fields
        model = create_model(name, __base__=base_class, **fields)

        # Add shared fields
        if shared_fields:
            model.__shared_fields__ = shared_fields

        # Add reducers
        if reducers:
            # Create serializable_reducers dict
            serializable_reducers = {}
            for field, reducer in reducers.items():
                reducer_name = getattr(reducer, "__name__", str(reducer))
                serializable_reducers[field] = reducer_name

            # Store both the serializable names and the actual reducer functions
            model.__serializable_reducers__ = serializable_reducers
            model.__reducer_fields__ = reducers

        return model
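    # Usage sketch for build_state_schema; base_class=BaseModel keeps the example
    # self-contained (the real default base is StateSchema), and `import operator`
    # is assumed:
    #
    #     >>> Plan = SchemaUtils.build_state_schema(
    #     ...     "Plan",
    #     ...     fields={"steps": (list[str], Field(default_factory=list))},
    #     ...     shared_fields=["steps"],
    #     ...     reducers={"steps": operator.add},
    #     ...     base_class=BaseModel,
    #     ... )
    #     >>> Plan.__serializable_reducers__
    #     {'steps': 'add'}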
    @staticmethod
    def add_field_to_schema(
        schema: type[BaseModel],
        name: str,
        field_type: type,
        default: Any = None,
        description: str | None = None,
        shared: bool = False,
        reducer: Callable | None = None,
    ) -> type[BaseModel]:
        """Add a field to an existing schema class.

        Args:
            schema: Existing schema class.
            name: Field name to add.
            field_type: Type of the field.
            default: Default value.
            description: Optional field description.
            shared: Whether the field is shared with parent.
            reducer: Optional reducer function.

        Returns:
            Updated schema class with the new field.
        """
        # Create field dict for the new model
        field_dict = {}

        # Copy existing fields
        for field_name, field_info in schema.model_fields.items():
            field_dict[field_name] = (field_info.annotation, field_info)

        # Add the new field
        new_field_info = Field(default=default, description=description)
        field_dict[name] = (field_type, new_field_info)

        # Create new model with all fields
        new_model = create_model(
            schema.__name__, __base__=schema.__base__, **field_dict
        )

        # Copy shared fields
        if hasattr(schema, "__shared_fields__"):
            new_model.__shared_fields__ = list(schema.__shared_fields__)

        # Add new field to shared fields if needed
        if shared:
            if not hasattr(new_model, "__shared_fields__"):
                new_model.__shared_fields__ = []
            if name not in new_model.__shared_fields__:
                new_model.__shared_fields__.append(name)

        # Copy reducers
        if hasattr(schema, "__serializable_reducers__"):
            new_model.__serializable_reducers__ = dict(schema.__serializable_reducers__)
        if hasattr(schema, "__reducer_fields__"):
            new_model.__reducer_fields__ = dict(schema.__reducer_fields__)

        # Add reducer if provided
        if reducer:
            if not hasattr(new_model, "__serializable_reducers__"):
                new_model.__serializable_reducers__ = {}
            if not hasattr(new_model, "__reducer_fields__"):
                new_model.__reducer_fields__ = {}

            reducer_name = getattr(reducer, "__name__", str(reducer))
            new_model.__serializable_reducers__[name] = reducer_name
            new_model.__reducer_fields__[name] = reducer

        return new_model
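    # Usage sketch for add_field_to_schema, continuing the hypothetical Plan
    # schema from the build_state_schema sketch above:
    #
    #     >>> Extended = SchemaUtils.add_field_to_schema(
    #     ...     Plan, "owner", str, default="unassigned", description="Plan owner"
    #     ... )
    #     >>> sorted(Extended.model_fields)
    #     ['owner', 'steps']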
    @staticmethod
    def get_reducer_name(reducer: Callable) -> str:
        """Get a serializable name for a reducer function.

        Args:
            reducer: Reducer function.

        Returns:
            Serializable name for the reducer.
        """
        # Special handling for operator module functions
        if hasattr(reducer, "__module__") and reducer.__module__ == "operator":
            return f"operator.{reducer.__name__}"

        # Handle lambda functions
        if hasattr(reducer, "__name__") and reducer.__name__ == "<lambda>":
            return "<lambda>"

        # Handle standard functions with module and name
        if (
            hasattr(reducer, "__module__")
            and hasattr(reducer, "__name__")
            and reducer.__module__ != "__main__"
        ):
            # Use fully qualified name for imported functions
            return f"{reducer.__module__}.{reducer.__name__}"

        # Use just the name if it has one
        if hasattr(reducer, "__name__"):
            return reducer.__name__

        # Last resort: string representation
        return str(reducer)
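    # Usage sketch for get_reducer_name; `merge_lists` is a hypothetical reducer
    # defined at the interactive prompt (functions imported from a module resolve
    # to a fully qualified "module.name" string instead):
    #
    #     >>> def merge_lists(a, b):
    #     ...     return a + b
    #     >>> SchemaUtils.get_reducer_name(merge_lists)
    #     'merge_lists'
    #     >>> SchemaUtils.get_reducer_name(lambda a, b: a + b)
    #     '<lambda>'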