"""Schema merging strategies for combining multiple schemas."""
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, create_model
from haive.core.schema.compatibility.analyzer import TypeAnalyzer
from haive.core.schema.compatibility.compatibility import CompatibilityChecker
from haive.core.schema.compatibility.types import FieldInfo, MergeStrategy, SchemaInfo
[docs]
class ConflictResolution(str, Enum):
"""How to resolve field conflicts during merge."""
FIRST_WINS = "first_wins"
LAST_WINS = "last_wins"
TYPE_UNION = "type_union"
MOST_SPECIFIC = "most_specific"
MOST_GENERAL = "most_general"
CUSTOM = "custom"
[docs]
class MergeContext:
"""Context for merge operations."""
def __init__(self) -> None:
"""Init .
Returns:
[TODO: Add return description]
"""
self.conflicts: list[dict[str, Any]] = []
self.warnings: list[str] = []
# field -> [schema_names]
self.field_sources: dict[str, list[str]] = {}
self.resolution_log: list[str] = []
[docs]
def add_conflict(
self,
field_name: str,
schemas: list[str],
reason: str,
resolution: str,
) -> None:
"""Log a conflict."""
self.conflicts.append(
{
"field": field_name,
"schemas": schemas,
"reason": reason,
"resolution": resolution,
}
)
self.resolution_log.append(
f"Conflict in '{field_name}': {reason}. Resolution: {resolution}"
)
[docs]
def add_warning(self, warning: str) -> None:
"""Add a warning."""
self.warnings.append(warning)
[docs]
def track_field_source(self, field_name: str, schema_name: str) -> None:
"""Track which schema a field came from."""
if field_name not in self.field_sources:
self.field_sources[field_name] = []
self.field_sources[field_name].append(schema_name)
[docs]
class MergeStrategy(ABC):
"""Abstract base for merge strategies."""
[docs]
@abstractmethod
def merge_fields(
self,
field_infos: list[tuple[str, FieldInfo]], # (schema_name, field_info)
context: MergeContext,
) -> FieldInfo | None:
"""Merge multiple field definitions."""
[docs]
@abstractmethod
def should_include_field(
self,
field_name: str,
schemas_with_field: list[str],
total_schemas: int,
) -> bool:
"""Determine if field should be included in merged schema."""
[docs]
class UnionMergeStrategy(MergeStrategy):
"""Include all fields from all schemas."""
def __init__(
self, conflict_resolution: ConflictResolution = ConflictResolution.LAST_WINS
):
"""Init .
Args:
conflict_resolution: [TODO: Add description]
"""
self.conflict_resolution = conflict_resolution
[docs]
def should_include_field(
self,
field_name: str,
schemas_with_field: list[str],
total_schemas: int,
) -> bool:
"""Include all fields."""
return True
[docs]
def merge_fields(
self,
field_infos: list[tuple[str, FieldInfo]],
context: MergeContext,
) -> FieldInfo | None:
"""Merge field definitions."""
if not field_infos:
return None
if len(field_infos) == 1:
return field_infos[0][1]
# Check for conflicts
field_name = field_infos[0][1].name
types = [(name, info.type_info.type_hint) for name, info in field_infos]
unique_types = list({t[1] for t in types})
if len(unique_types) > 1:
# Type conflict
context.add_conflict(
field_name,
[name for name, _ in field_infos],
f"Different types: {unique_types}",
self.conflict_resolution.value,
)
if self.conflict_resolution == ConflictResolution.FIRST_WINS:
return field_infos[0][1]
if self.conflict_resolution == ConflictResolution.LAST_WINS:
return field_infos[-1][1]
if self.conflict_resolution == ConflictResolution.TYPE_UNION:
# Create union type
from typing import Union
union_type = Union[tuple(unique_types)]
merged = field_infos[-1][1]
merged.type_info.type_hint = union_type
return merged
# Merge other properties
# Use the most restrictive settings
is_required = any(info.is_required for _, info in field_infos)
all(info.has_default for _, info in field_infos)
# Take the last definition as base
merged = field_infos[-1][1]
merged.is_required = is_required
# Merge metadata
merged.is_shared = any(info.is_shared for _, info in field_infos)
# Combine engine mappings
for _, info in field_infos:
merged.input_for_engines.update(info.input_for_engines)
merged.output_from_engines.update(info.output_from_engines)
return merged
[docs]
class IntersectionMergeStrategy(MergeStrategy):
"""Include only fields present in all schemas."""
[docs]
def should_include_field(
self,
field_name: str,
schemas_with_field: list[str],
total_schemas: int,
) -> bool:
"""Include only if in all schemas."""
return len(schemas_with_field) == total_schemas
[docs]
def merge_fields(
self,
field_infos: list[tuple[str, FieldInfo]],
context: MergeContext,
) -> FieldInfo | None:
"""Merge field definitions."""
if not field_infos:
return None
# For intersection, ensure types are compatible
analyzer = TypeAnalyzer()
base_type = field_infos[0][1].type_info.type_hint
for _schema_name, field_info in field_infos[1:]:
if not analyzer.is_subtype(field_info.type_info.type_hint, base_type):
context.add_warning(
f"Type incompatibility in intersection for '{field_info.name}'"
)
# Use most general type
base_type = Any
# Use first definition as base
merged = field_infos[0][1]
merged.type_info.type_hint = base_type
# Use most permissive settings
merged.is_required = all(info.is_required for _, info in field_infos)
merged.is_shared = all(info.is_shared for _, info in field_infos)
return merged
[docs]
class SchemaMerger:
"""Main schema merging engine."""
def __init__(
self,
strategy: MergeStrategy | str = "union",
analyzer: TypeAnalyzer | None = None,
compatibility_checker: CompatibilityChecker | None = None,
):
"""Init .
Args:
strategy: [TODO: Add description]
analyzer: [TODO: Add description]
compatibility_checker: [TODO: Add description]
"""
self.analyzer = analyzer or TypeAnalyzer()
self.compatibility_checker = compatibility_checker or CompatibilityChecker()
# Set strategy
if isinstance(strategy, str):
if strategy == "union":
self.strategy = UnionMergeStrategy()
elif strategy == "intersection":
self.strategy = IntersectionMergeStrategy()
else:
raise ValueError(f"Unknown strategy: {strategy}")
else:
self.strategy = strategy
[docs]
def merge_schemas(
self,
schemas: list[type[BaseModel] | SchemaInfo],
name: str | None = None,
base_class: type[BaseModel] | None = None,
) -> type[BaseModel]:
"""Merge multiple schemas into one.
Args:
schemas: List of schemas to merge
name: Name for the merged schema
base_class: Base class for the merged schema
Returns:
Merged schema class
"""
if not schemas:
raise ValueError("No schemas to merge")
# Convert all to SchemaInfo
schema_infos = []
for _i, schema in enumerate(schemas):
if isinstance(schema, type) and issubclass(schema, BaseModel):
info = self.analyzer.analyze_schema(schema)
info.name = schema.__name__
else:
info = schema
schema_infos.append(info)
# Merge
context = MergeContext()
merged_info = self._merge_schema_infos(schema_infos, context)
# Generate name
if not name:
name = f"Merged{''.join(s.name for s in schema_infos[:3])}"
if len(schema_infos) > 3:
name += f"And{len(schema_infos) - 3}More"
# Create model
return self._create_model_from_info(merged_info, name, base_class)
def _merge_schema_infos(
self,
schema_infos: list[SchemaInfo],
context: MergeContext,
) -> SchemaInfo:
"""Merge SchemaInfo objects."""
# Group fields by name
field_groups: dict[str, list[tuple[str, FieldInfo]]] = {}
for schema_info in schema_infos:
for field_name, field_info in schema_info.fields.items():
if field_name not in field_groups:
field_groups[field_name] = []
field_groups[field_name].append((schema_info.name, field_info))
context.track_field_source(field_name, schema_info.name)
# Create merged schema info
merged = SchemaInfo(
name="MergedSchema",
type_info=schema_infos[0].type_info, # Use first as template
)
# Merge fields
total_schemas = len(schema_infos)
for field_name, field_list in field_groups.items():
schemas_with_field = [name for name, _ in field_list]
if self.strategy.should_include_field(
field_name, schemas_with_field, total_schemas
):
merged_field = self.strategy.merge_fields(field_list, context)
if merged_field:
merged.fields[field_name] = merged_field
# Merge metadata
for schema_info in schema_infos:
merged.shared_fields.update(schema_info.shared_fields)
# Merge reducers with conflict detection
for field, reducer in schema_info.reducer_fields.items():
if (
field in merged.reducer_fields
and merged.reducer_fields[field] != reducer
):
context.add_conflict(
field,
[schema_info.name],
"Different reducers",
"Keeping existing",
)
else:
merged.reducer_fields[field] = reducer
# Merge engine mappings
for engine, mapping in schema_info.engine_io_mappings.items():
if engine not in merged.engine_io_mappings:
merged.engine_io_mappings[engine] = {"inputs": [], "outputs": []}
merged.engine_io_mappings[engine]["inputs"].extend(
mapping.get("inputs", [])
)
merged.engine_io_mappings[engine]["outputs"].extend(
mapping.get("outputs", [])
)
# Deduplicate engine mappings
for engine, mapping in merged.engine_io_mappings.items():
mapping["inputs"] = list(set(mapping["inputs"]))
mapping["outputs"] = list(set(mapping["outputs"]))
return merged
def _create_model_from_info(
self,
schema_info: SchemaInfo,
name: str,
base_class: type[BaseModel] | None = None,
) -> type[BaseModel]:
"""Create a Pydantic model from SchemaInfo."""
# Build field definitions
field_definitions = {}
for field_name, field_info in schema_info.fields.items():
# Create field kwargs
field_kwargs = {}
if not field_info.is_required:
if field_info.default_factory:
field_kwargs["default_factory"] = field_info.default_factory
elif field_info.default_value is not None:
field_kwargs["default"] = field_info.default_value
else:
field_kwargs["default"] = None
if field_info.description:
field_kwargs["description"] = field_info.description
if field_info.alias:
field_kwargs["alias"] = field_info.alias
# Create field
field_definitions[field_name] = (
field_info.type_info.type_hint,
Field(**field_kwargs) if field_kwargs else ...,
)
# Choose base class
if base_class is None:
# Try to detect appropriate base class
if "messages" in schema_info.fields:
try:
from haive.core.schema.prebuilt.messages_state import MessagesState
base_class = MessagesState
except ImportError:
base_class = BaseModel
else:
base_class = BaseModel
# Create model
model = create_model(
name,
__base__=base_class,
**field_definitions,
)
# Add metadata as class attributes
if schema_info.shared_fields:
model.__shared_fields__ = list(schema_info.shared_fields)
if schema_info.reducer_fields:
model.__reducer_fields__ = dict(schema_info.reducer_fields)
if schema_info.engine_io_mappings:
model.__engine_io_mappings__ = dict(schema_info.engine_io_mappings)
return model
# Convenience functions
[docs]
def merge_schemas(
schemas: list[type[BaseModel] | SchemaInfo],
strategy: str = "union",
name: str | None = None,
) -> type[BaseModel]:
"""Merge multiple schemas using specified strategy."""
merger = SchemaMerger(strategy=strategy)
return merger.merge_schemas(schemas, name=name)
[docs]
def create_union_schema(
*schemas: type[BaseModel] | SchemaInfo,
name: str | None = None,
) -> type[BaseModel]:
"""Create a union of multiple schemas."""
return merge_schemas(list(schemas), strategy="union", name=name)
[docs]
def create_intersection_schema(
*schemas: type[BaseModel] | SchemaInfo,
name: str | None = None,
) -> type[BaseModel]:
"""Create an intersection of multiple schemas."""
return merge_schemas(list(schemas), strategy="intersection", name=name)