Source code for haive.core.schema.field_extractor
"""Field extractor utility for the Haive Schema System.
This module provides the FieldExtractor class, which offers a standardized way to
extract field definitions from various sources including Pydantic models, engines,
and dictionary specifications. It ensures consistent field handling throughout the
Haive Schema System, serving as a key component for dynamic schema composition.
The FieldExtractor enables automatic discovery of fields and their metadata from
existing components, making it possible to build schemas that properly integrate
with those components without manual field specification. This is particularly
valuable when working with complex systems where fields need to be shared across
multiple components or where field specifications are distributed across different
parts of the system.
Key capabilities include:
- Extracting field definitions from Pydantic models (including annotations)
- Discovering input and output fields from engine components
- Identifying shared fields and reducer functions
- Mapping engine I/O relationships for state management
- Handling structured output models
Examples:
from haive.core.schema import FieldExtractor
# Extract fields from a list of components
field_defs, engine_io_mappings, structured_model_fields, structured_models = (
FieldExtractor.extract_from_components([
retriever_engine,
llm_engine,
memory_component
])
)
# Fields are returned as FieldDefinition objects
for name, field_def in field_defs.items():
print(f"Field: {name}, Type: {field_def.field_type}")
# Engine I/O mappings show which fields are used by which engines
for engine, mapping in engine_io_mappings.items():
print(f"Engine: {engine}")
print(f" Inputs: {mapping['inputs']}")
print(f" Outputs: {mapping['outputs']}")
"""
import logging
from collections import defaultdict
from collections.abc import Callable, Sequence
from typing import Any, Optional, TypeVar
from pydantic import BaseModel, Field
from haive.core.schema.field_definition import FieldDefinition
from haive.core.schema.field_utils import extract_type_metadata, infer_field_type
# Module logger: records are emitted under this module's dotted import name.
logger: logging.Logger = logging.getLogger(__name__)
# Module-level generic type variable, kept for return-type annotations.
T = TypeVar("T")
[docs]
class FieldExtractor:
    """Unified utility for extracting field definitions from various sources.

    The FieldExtractor class provides static methods for extracting field definitions,
    shared fields, reducer functions, and engine I/O mappings from various components
    in the Haive ecosystem. It is designed to work with:

    1. Pydantic models and model classes
    2. Engine components with get_input_fields/get_output_fields methods
    3. Components with structured_output_model attributes
    4. Dictionary-based field specifications

    This class is a key component of the Haive Schema System's composition
    capabilities, enabling automatic discovery and integration of fields from
    different parts of an application. By standardizing field extraction, it
    ensures consistent handling of field metadata throughout the framework.

    The extraction methods gather not only basic field information like types and
    defaults, but also Haive-specific metadata such as:

    - Whether fields are shared between parent and child graphs
    - Reducer functions for combining field values during updates
    - Input/output relationships with specific engines
    - Structured output model associations

    All methods are static and don't require instantiation of the class.
    """

    @staticmethod
    def extract_from_model(
        model_cls: type[BaseModel],
    ) -> tuple[
        dict[str, tuple[Any, Any]],  # fields
        dict[str, str],  # descriptions
        set[str],  # shared_fields
        dict[str, str],  # reducer_names
        dict[str, Callable],  # reducer_functions
        dict[str, dict[str, list[str]]],  # engine_io_mappings
        dict[str, set[str]],  # input_fields
        dict[str, set[str]],  # output_fields
    ]:
        """Extract all field information from a Pydantic model class.

        Besides the standard per-field information, this reads the Haive class
        attributes ``__shared_fields__``, ``__serializable_reducers__``,
        ``__reducer_fields__``, ``__engine_io_mappings__``, ``__input_fields__``
        and ``__output_fields__`` when they are present on the class.

        Fields whose name starts with ``"__"`` and the ``runnable_config`` field
        are skipped.

        Args:
            model_cls: Pydantic model class to extract from. Anything that is not
                a ``BaseModel`` subclass is logged as a warning and yields empty
                results.

        Returns:
            Tuple of (fields, descriptions, shared_fields, reducer_names,
            reducer_functions, engine_io_mappings, input_fields,
            output_fields). ``fields`` maps each field name to the
            ``(field_type, field_info)`` pair from
            ``FieldDefinition.to_field_info()``.
        """
        fields: dict[str, tuple[Any, Any]] = {}
        descriptions: dict[str, str] = {}
        shared_fields: set[str] = set()
        reducer_names: dict[str, str] = {}
        reducer_functions: dict[str, Callable] = {}
        engine_io_mappings: dict[str, dict[str, list[str]]] = {}
        input_fields: dict[str, set[str]] = defaultdict(set)
        output_fields: dict[str, set[str]] = defaultdict(set)

        if not (isinstance(model_cls, type) and issubclass(model_cls, BaseModel)):
            logger.warning(f"Not a Pydantic model: {model_cls}")
            return (
                fields,
                descriptions,
                shared_fields,
                reducer_names,
                reducer_functions,
                engine_io_mappings,
                input_fields,
                output_fields,
            )

        # Class-level Haive metadata.
        if hasattr(model_cls, "__shared_fields__"):
            shared_fields.update(model_cls.__shared_fields__)
        if hasattr(model_cls, "__serializable_reducers__"):
            reducer_names.update(model_cls.__serializable_reducers__)
        if hasattr(model_cls, "__reducer_fields__"):
            reducer_functions.update(model_cls.__reducer_fields__)
        if hasattr(model_cls, "__engine_io_mappings__"):
            engine_io_mappings = model_cls.__engine_io_mappings__.copy()
        if hasattr(model_cls, "__input_fields__"):
            for engine, fields_list in model_cls.__input_fields__.items():
                input_fields[engine].update(fields_list)
        if hasattr(model_cls, "__output_fields__"):
            for engine, fields_list in model_cls.__output_fields__.items():
                output_fields[engine].update(fields_list)

        # Per-field information (Pydantic v2 ``model_fields``).
        for field_name, field_info in model_cls.model_fields.items():
            if field_name.startswith("__") or field_name == "runnable_config":
                continue

            field_def = FieldDefinition.extract_from_model_field(
                name=field_name, field_type=field_info.annotation, field_info=field_info
            )

            if field_name in shared_fields:
                field_def.shared = True
            if field_name in reducer_functions:
                field_def.reducer = reducer_functions[field_name]

            # Link the field to every engine whose mapping lists it.
            for engine_name, mapping in engine_io_mappings.items():
                if field_name in mapping.get("inputs", []):
                    if engine_name not in field_def.input_for:
                        field_def.input_for.append(engine_name)
                    input_fields[engine_name].add(field_name)
                if field_name in mapping.get("outputs", []):
                    if engine_name not in field_def.output_from:
                        field_def.output_from.append(engine_name)
                    output_fields[engine_name].add(field_name)

            fields[field_name] = field_def.to_field_info()
            if field_def.description:
                descriptions[field_name] = field_def.description

            if field_def.reducer:
                reducer_name = field_def.get_reducer_name()
                if reducer_name:
                    reducer_names[field_name] = reducer_name
                reducer_functions[field_name] = field_def.reducer

        return (
            fields,
            descriptions,
            shared_fields,
            reducer_names,
            reducer_functions,
            engine_io_mappings,
            input_fields,
            output_fields,
        )

    @staticmethod
    def extract_from_engine(
        engine: Any,
    ) -> tuple[
        dict[str, tuple[Any, Any]],  # fields
        dict[str, str],  # descriptions
        dict[str, dict[str, list[str]]],  # engine_io_mappings
        dict[str, set[str]],  # input_fields
        dict[str, set[str]],  # output_fields
    ]:
        """Extract all field information from an engine.

        Fields are discovered from two sources:

        1. ``engine.get_input_fields()`` / ``engine.get_output_fields()``, each
           unpacked as a mapping of field name to ``(field_type, field_info)``.
           Output fields are only read when the engine has no
           ``structured_output_model``, so that the structured model is not
           duplicated as individual fields. An output field that is already an
           input keeps the input's type/info and is additionally marked as output.
        2. ``engine.structured_output_model``: a single field (named by
           ``get_field_info_from_model``) holds the whole model.

        Each source is best-effort: an exception raised while reading it is
        logged as a warning and extraction continues.

        Fields whose name starts with ``"__"`` and ``runnable_config`` are skipped.

        Args:
            engine: Engine instance to extract from. Its ``name`` attribute (or
                ``str(engine)`` when it has none) keys the returned mappings.

        Returns:
            Tuple of (fields, descriptions, engine_io_mappings, input_fields,
            output_fields). ``engine_io_mappings`` holds one entry for this
            engine whose ``"inputs"``/``"outputs"`` lists are sorted so the
            result does not depend on set iteration order.
        """
        fields: dict[str, tuple[Any, Any]] = {}
        descriptions: dict[str, str] = {}
        input_fields: dict[str, set[str]] = defaultdict(set)
        output_fields: dict[str, set[str]] = defaultdict(set)

        engine_name = getattr(engine, "name", str(engine))
        structured_model = getattr(engine, "structured_output_model", None)

        def _is_internal(field_name: str) -> bool:
            # Internal/special fields never become schema fields.
            return field_name.startswith("__") or field_name == "runnable_config"

        # Source 1a: declared input fields.
        if hasattr(engine, "get_input_fields") and callable(engine.get_input_fields):
            try:
                for field_name, (field_type, field_info) in engine.get_input_fields().items():
                    if _is_internal(field_name):
                        continue
                    field_def = FieldDefinition(
                        name=field_name,
                        field_type=field_type,
                        field_info=field_info,
                        source=engine_name,
                        input_for=[engine_name],
                    )
                    fields[field_name] = field_def.to_field_info()
                    if field_def.description:
                        descriptions[field_name] = field_def.description
                    input_fields[engine_name].add(field_name)
            except Exception as e:
                logger.warning(f"Error getting input_fields from {engine_name}: {e}")

        # Source 1b: declared output fields (skipped when a structured model
        # represents the output, to avoid duplicating it field by field).
        if hasattr(engine, "get_output_fields") and callable(engine.get_output_fields):
            try:
                if structured_model is None:
                    for field_name, (field_type, field_info) in engine.get_output_fields().items():
                        if _is_internal(field_name):
                            continue
                        output_fields[engine_name].add(field_name)

                        if field_name in fields:
                            # Already registered as an input of this engine: keep
                            # the input's type/info and provenance, and also mark
                            # it as an output.
                            existing_type, existing_info = fields[field_name]
                            field_def = FieldDefinition(
                                name=field_name,
                                field_type=existing_type,
                                field_info=existing_info,
                                source=engine_name,
                                input_for=[engine_name],
                            )
                            if engine_name not in field_def.output_from:
                                field_def.output_from.append(engine_name)
                            fields[field_name] = field_def.to_field_info()
                            continue

                        field_def = FieldDefinition(
                            name=field_name,
                            field_type=field_type,
                            field_info=field_info,
                            source=engine_name,
                            output_from=[engine_name],
                        )
                        fields[field_name] = field_def.to_field_info()
                        if field_def.description:
                            descriptions[field_name] = field_def.description
            except Exception as e:
                logger.warning(f"Error getting output_fields from {engine_name}: {e}")

        # Source 2: structured output model, exposed as one field.
        if structured_model is not None:
            try:
                from haive.core.schema.field_utils import get_field_info_from_model

                field_info_dict = get_field_info_from_model(structured_model)
                model_name = field_info_dict["field_name"]
                field_description = field_info_dict.get(
                    "description", f"Output in {structured_model.__name__} format"
                )
                field_type = field_info_dict.get("field_type", Optional[structured_model])
                logger.info(
                    "Found structured_output_model in %s: %s -> %s",
                    engine_name,
                    structured_model.__name__,
                    model_name,
                )

                field_def = FieldDefinition(
                    name=model_name,
                    field_type=field_type,
                    field_info=Field(default=None, description=field_description),
                    source=engine_name,
                    output_from=[engine_name],
                    structured_model=structured_model.__name__,
                )
                fields[model_name] = field_def.to_field_info()
                descriptions[model_name] = field_def.description
                output_fields[engine_name].add(model_name)
            except Exception as e:
                logger.warning(f"Error extracting structured_output_model from {engine_name}: {e}")

        # Sorted for deterministic output (sets have no stable order for str).
        engine_io_mappings: dict[str, dict[str, list[str]]] = {
            engine_name: {
                "inputs": sorted(input_fields[engine_name]),
                "outputs": sorted(output_fields[engine_name]),
            }
        }
        return fields, descriptions, engine_io_mappings, input_fields, output_fields

    @staticmethod
    def extract_from_dict(
        data: dict[str, Any],
    ) -> tuple[
        dict[str, tuple[Any, Any]],  # fields
        dict[str, str],  # descriptions
        set[str],  # shared_fields
        dict[str, str],  # reducer_names
        dict[str, Callable],  # reducer_functions
        dict[str, dict[str, list[str]]],  # engine_io_mappings
        dict[str, set[str]],  # input_fields
        dict[str, set[str]],  # output_fields
    ]:
        """Extract fields from a dictionary definition.

        The recognized metadata keys are ``shared_fields``, ``reducer_names``,
        ``serializable_reducers``, ``reducer_functions``, ``field_descriptions``,
        ``engine_io_mappings``, ``input_fields`` and ``output_fields``. They are
        applied before any field is processed, regardless of their position in
        ``data``, so every field sees the complete metadata.

        Every other key defines a field, given as one of:

        - ``(type, default)`` or ``(type, default, {field kwargs})``. A callable
          ``default`` that is not a class is used as ``default_factory``. A
          ``"description"`` in the kwargs overrides ``field_descriptions``.
        - a bare default value, whose type is inferred with ``infer_field_type``.

        Args:
            data: Dictionary containing field definitions and optional metadata.

        Returns:
            Tuple of (fields, descriptions, shared_fields, reducer_names,
            reducer_functions, engine_io_mappings, input_fields,
            output_fields)
        """
        fields: dict[str, tuple[Any, Any]] = {}
        descriptions: dict[str, str] = {}
        shared_fields: set[str] = set()
        reducer_names: dict[str, str] = {}
        reducer_functions: dict[str, Callable] = {}
        engine_io_mappings: dict[str, dict[str, list[str]]] = {}
        input_fields: dict[str, set[str]] = defaultdict(set)
        output_fields: dict[str, set[str]] = defaultdict(set)

        metadata_keys = {
            "shared_fields",
            "reducer_names",
            "serializable_reducers",
            "reducer_functions",
            "field_descriptions",
            "engine_io_mappings",
            "input_fields",
            "output_fields",
        }

        # Pass 1: metadata, so field handling does not depend on key order.
        for key, value in data.items():
            if key == "shared_fields":
                shared_fields.update(value)
            elif key in {"reducer_names", "serializable_reducers"}:
                reducer_names.update(value)
            elif key == "reducer_functions":
                reducer_functions.update(value)
            elif key == "field_descriptions":
                descriptions.update(value)
            elif key == "engine_io_mappings":
                engine_io_mappings.update(value)
            elif key == "input_fields":
                for engine, fields_list in value.items():
                    input_fields[engine].update(fields_list)
            elif key == "output_fields":
                for engine, fields_list in value.items():
                    output_fields[engine].update(fields_list)

        # Pass 2: field definitions.
        for key, value in data.items():
            if key in metadata_keys:
                continue

            if isinstance(value, tuple) and len(value) >= 2:
                field_type, default = value[0], value[1]
                field_metadata: dict[str, Any] = {}
                if len(value) >= 3 and isinstance(value[2], dict):
                    field_metadata = value[2]
                    if "description" in field_metadata:
                        descriptions[key] = field_metadata["description"]
                # Non-class callables are factories; classes are literal defaults.
                if callable(default) and not isinstance(default, type):
                    field_info = Field(default_factory=default, **field_metadata)
                else:
                    field_info = Field(default=default, **field_metadata)
            else:
                # Bare value: it is the default, and its type is inferred.
                field_type = infer_field_type(value)
                field_info = Field(default=value)

            field_def = FieldDefinition(
                name=key,
                field_type=field_type,
                field_info=field_info,
                shared=key in shared_fields,
                reducer=reducer_functions.get(key),
            )

            for engine_name, mapping in engine_io_mappings.items():
                if key in mapping.get("inputs", []):
                    if engine_name not in field_def.input_for:
                        field_def.input_for.append(engine_name)
                    input_fields[engine_name].add(key)
                if key in mapping.get("outputs", []):
                    if engine_name not in field_def.output_from:
                        field_def.output_from.append(engine_name)
                    output_fields[engine_name].add(key)

            fields[key] = field_def.to_field_info()

            if field_def.reducer:
                reducer_name = field_def.get_reducer_name()
                if reducer_name:
                    reducer_names[key] = reducer_name

        return (
            fields,
            descriptions,
            shared_fields,
            reducer_names,
            reducer_functions,
            engine_io_mappings,
            input_fields,
            output_fields,
        )

    @staticmethod
    def _attach_engine_io(
        field_def: FieldDefinition,
        field_name: str,
        in_fields: dict[str, set[str]],
        out_fields: dict[str, set[str]],
    ) -> None:
        """Record on ``field_def`` every engine that lists ``field_name`` as input/output.

        Engine names already present in ``input_for``/``output_from`` are not
        appended again.
        """
        for engine_name, field_set in in_fields.items():
            if field_name in field_set and engine_name not in field_def.input_for:
                field_def.input_for.append(engine_name)
        for engine_name, field_set in out_fields.items():
            if field_name in field_set and engine_name not in field_def.output_from:
                field_def.output_from.append(engine_name)

    @staticmethod
    def _register_field(
        field_definitions: dict[str, FieldDefinition],
        field_name: str,
        field_type: Any,
        field_info: Any,
        in_fields: dict[str, set[str]],
        out_fields: dict[str, set[str]],
        **definition_kwargs: Any,
    ) -> None:
        """Add a new FieldDefinition, or merge engine I/O into an existing one.

        The first component to define a field wins its type/info/shared/reducer.
        Engine input/output links from later components are merged into it.
        """
        field_def = field_definitions.get(field_name)
        if field_def is None:
            field_def = FieldDefinition(
                name=field_name,
                field_type=field_type,
                field_info=field_info,
                **definition_kwargs,
            )
            field_definitions[field_name] = field_def
        FieldExtractor._attach_engine_io(field_def, field_name, in_fields, out_fields)

    @staticmethod
    def _merge_io_mappings(
        target: dict[str, dict[str, list[str]]],
        source: dict[str, dict[str, list[str]]],
    ) -> None:
        """Union ``source`` engine I/O mappings into ``target`` in place.

        Lists are copied rather than shared, so later merges never mutate a
        caller's or a class attribute's mapping. Duplicate names are dropped
        while first-seen order is kept.
        """
        for engine_name, mapping in source.items():
            merged = target.setdefault(engine_name, {})
            for direction, names in mapping.items():
                bucket = merged.setdefault(direction, [])
                for name in names or ():
                    if name not in bucket:
                        bucket.append(name)

    @staticmethod
    def extract_from_components(
        components: list[Any], include_messages_field: bool = True
    ) -> tuple[
        dict[str, FieldDefinition],  # All field definitions
        dict[str, dict[str, list[str]]],  # Engine I/O mappings
        dict[str, set[str]],  # Structured model fields
        dict[str, type],  # Structured models
    ]:
        """Extract field definitions from a list of heterogeneous components.

        This is a high-level method that extracts field definitions from various
        component types (engines, models, dictionaries) and returns them in a
        consistent format. It is the primary entry point for schema composition.

        Each component is processed according to its type (``None`` and
        unsupported types are skipped):

        - Objects with an ``engine_type`` attribute: ``extract_from_engine``,
          plus registration of their ``structured_output_model``
        - Pydantic model instances and classes: ``extract_from_model``
        - Dictionaries: ``extract_from_dict``

        When the same field appears in several components, the first definition
        is kept and the engine input/output links of later components are merged
        into it. Engine I/O mappings for the same engine are merged as well
        (union, duplicates dropped).

        Args:
            components (List[Any]): List of components to extract fields from. Can
                include engine instances, Pydantic models, model classes, and
                dictionaries.
            include_messages_field (bool, optional): Whether to add a ``messages``
                field with a reducer when no component defines one. It uses
                langgraph's ``add_messages`` when importable and otherwise a
                list-concatenation reducer. Defaults to True.

        Returns:
            Tuple containing:
            - Dict[str, FieldDefinition]: field name -> complete FieldDefinition
            - Dict[str, Dict[str, List[str]]]: engine I/O mappings showing which
              fields are inputs/outputs for which engines
            - Dict[str, Set[str]]: structured model fields, mapping lower-cased
              model class names to the model's field names
            - Dict[str, Type]: lower-cased model class names -> model classes

        Examples:
            components = [
                retriever_engine,         # Engine with get_input/output_fields
                ConversationMemory(),     # Pydantic model instance
                ResponseGeneratorConfig,  # Pydantic model class
                {                         # Dictionary-based field definition
                    "custom_field": (str, "", {"description": "Custom field"}),
                    "shared_fields": ["messages"],
                },
            ]
            field_defs, io_mappings, model_fields, models = (
                FieldExtractor.extract_from_components(components)
            )
            composer = SchemaComposer(name="AgentState")
            for name, field_def in field_defs.items():
                composer.add_field_definition(field_def)
            AgentState = composer.build()
        """
        field_definitions: dict[str, FieldDefinition] = {}
        engine_io_mappings: dict[str, dict[str, list[str]]] = {}
        structured_model_fields: dict[str, set[str]] = defaultdict(set)
        structured_models: dict[str, type] = {}

        for component in components:
            if component is None:
                continue

            if hasattr(component, "engine_type"):
                fields, _descriptions, io_mappings, in_fields, out_fields = (
                    FieldExtractor.extract_from_engine(component)
                )
                source = getattr(component, "name", str(component))
                for field_name, (field_type, field_info) in fields.items():
                    FieldExtractor._register_field(
                        field_definitions,
                        field_name,
                        field_type,
                        field_info,
                        in_fields,
                        out_fields,
                        source=source,
                    )
                FieldExtractor._merge_io_mappings(engine_io_mappings, io_mappings)

                # Remember the structured output model and its field names.
                model = getattr(component, "structured_output_model", None)
                if model:
                    model_name = model.__name__.lower()
                    structured_models[model_name] = model
                    if hasattr(model, "model_fields"):
                        structured_model_fields[model_name].update(model.model_fields)
                continue

            if isinstance(component, BaseModel):
                extracted = FieldExtractor.extract_from_model(component.__class__)
                source = component.__class__.__name__
            elif isinstance(component, type) and issubclass(component, BaseModel):
                extracted = FieldExtractor.extract_from_model(component)
                source = component.__name__
            elif isinstance(component, dict):
                extracted = FieldExtractor.extract_from_dict(component)
                source = "dictionary"
            else:
                continue

            (
                fields,
                _descriptions,
                shared,
                _reducer_names,
                reducer_funcs,
                io_mappings,
                in_fields,
                out_fields,
            ) = extracted
            for field_name, (field_type, field_info) in fields.items():
                FieldExtractor._register_field(
                    field_definitions,
                    field_name,
                    field_type,
                    field_info,
                    in_fields,
                    out_fields,
                    shared=field_name in shared,
                    reducer=reducer_funcs.get(field_name),
                    source=source,
                )
            FieldExtractor._merge_io_mappings(engine_io_mappings, io_mappings)

        if include_messages_field and "messages" not in field_definitions:
            from langchain_core.messages import BaseMessage

            reducer = None
            try:
                from langgraph.graph import add_messages

                reducer = add_messages
            except ImportError:

                def concat_lists(a, b) -> Any:
                    """Concatenate two lists, treating ``None`` as an empty list.

                    Fallback reducer used when langgraph's ``add_messages`` cannot
                    be imported.
                    """
                    return (a or []) + (b or [])

                reducer = concat_lists

            field_definitions["messages"] = FieldDefinition(
                name="messages",
                field_type=Sequence[BaseMessage],
                description="Messages for conversation",
                reducer=reducer,
            )

        return (
            field_definitions,
            engine_io_mappings,
            structured_model_fields,
            structured_models,
        )