# Source code for haive.core.schema.compatibility.protocols

"""Protocol definitions for extending the schema compatibility system.

Also provides ``PluginManager``, plugin-registration decorators, and example
plugin implementations.
"""

from __future__ import annotations

from typing import Any, Protocol, TypeVar, runtime_checkable

from pydantic import BaseModel

from haive.core.schema.compatibility.types import (
    CompatibilityLevel,
    ConversionContext,
    FieldInfo,
    SchemaInfo,
)

# Generic type variables for the protocol signatures below. ``T`` types the
# ``SchemaConvertible.from_schema`` alternate constructor; ``U`` is not
# referenced in the code shown here (presumably kept for importers — verify).
T, U = TypeVar("T"), TypeVar("U")


@runtime_checkable
class SchemaConvertible(Protocol):
    """Structural interface for objects that round-trip through a Pydantic schema.

    Implementers expose an instance-level ``to_schema`` and a class-level
    ``from_schema`` alternate constructor.
    """

    def to_schema(self) -> type[BaseModel]:
        """Return a Pydantic model class representing this object."""

    @classmethod
    def from_schema(cls: type[T], schema: type[BaseModel]) -> T:
        """Build a new instance from the Pydantic model class ``schema``."""
@runtime_checkable
class FieldTransformer(Protocol):
    """Structural interface for callables that transform a single field value."""

    def __call__(self, value: Any, context: dict[str, Any] | None = None) -> Any:
        """Return the transformed ``value``; ``context`` carries optional extra data."""
@runtime_checkable
class SchemaValidator(Protocol):
    """Structural interface for objects that validate schemas.

    Both methods report problems as a list of issue strings.
    """

    def validate_schema(self, schema: SchemaInfo) -> list[str]:
        """Check a single schema and return the issues found."""

    def validate_compatibility(
        self,
        source: SchemaInfo,
        target: SchemaInfo,
    ) -> list[str]:
        """Check whether ``source`` and ``target`` are compatible; return the issues found."""
@runtime_checkable
class ConversionStrategy(Protocol):
    """Structural interface for a named strategy that converts values between types."""

    @property
    def name(self) -> str:
        """Identifier of the strategy."""

    def can_convert(self, source: type, target: type) -> bool:
        """Report whether this strategy handles ``source`` -> ``target``."""

    def convert(
        self,
        value: Any,
        source_type: type,
        target_type: type,
        context: ConversionContext,
    ) -> Any:
        """Convert ``value`` from ``source_type`` to ``target_type`` and return it."""
@runtime_checkable
class FieldResolver(Protocol):
    """Structural interface for deciding which source field feeds a target field."""

    def resolve_field(
        self,
        source_fields: dict[str, FieldInfo],
        target_field: FieldInfo,
    ) -> str | None:
        """Return the name of the source field for ``target_field``, or ``None``."""

    def suggest_mapping(
        self,
        source_schema: SchemaInfo,
        target_schema: SchemaInfo,
    ) -> dict[str, str]:
        """Return proposed field-name mappings between the two schemas."""
@runtime_checkable
class TypeInspector(Protocol):
    """Structural interface for pluggable inspection of custom type hints."""

    def can_inspect(self, type_hint: type) -> bool:
        """Report whether this inspector understands ``type_hint``."""

    def inspect(self, type_hint: type) -> dict[str, Any]:
        """Return a metadata dict describing ``type_hint``."""

    def extract_constraints(self, type_hint: type) -> dict[str, Any]:
        """Return validation constraints carried by ``type_hint``."""
@runtime_checkable
class SchemaEvolution(Protocol):
    """Structural interface for migrating data between versions of a schema."""

    @property
    def version(self) -> str:
        """Version string of the schema."""

    def can_migrate(self, from_version: str, to_version: str) -> bool:
        """Report whether data can be migrated from ``from_version`` to ``to_version``."""

    def migrate(
        self,
        data: dict[str, Any],
        from_version: str,
        to_version: str,
    ) -> dict[str, Any]:
        """Return ``data`` migrated from ``from_version`` to ``to_version``."""
@runtime_checkable
class CompatibilityPlugin(Protocol):
    """Structural interface for plugins that extend compatibility checking."""

    @property
    def name(self) -> str:
        """Identifier of the plugin."""

    @property
    def priority(self) -> int:
        """Ordering key; plugins with a higher priority run first."""

    def check_compatibility(
        self,
        source_type: type,
        target_type: type,
    ) -> CompatibilityLevel | None:
        """Return the compatibility level of the two types, or ``None``."""

    def enhance_report(
        self,
        report: Any,  # CompatibilityReport
        source: SchemaInfo,
        target: SchemaInfo,
    ) -> None:
        """Add plugin-specific information to a compatibility report."""
@runtime_checkable
class AsyncConverter(Protocol):
    """Structural interface for converters that work asynchronously."""

    async def aconvert(
        self,
        value: Any,
        context: ConversionContext,
    ) -> Any:
        """Asynchronously convert ``value`` and return the result."""

    @property
    def supports_sync(self) -> bool:
        """Whether a synchronous conversion path is available too."""
@runtime_checkable
class SchemaRegistry(Protocol):
    """Structural interface for a name-keyed store of Pydantic schemas."""

    def register(self, name: str, schema: type[BaseModel]) -> None:
        """Store ``schema`` under ``name``."""

    def get(self, name: str) -> type[BaseModel] | None:
        """Return the schema stored under ``name``, or ``None``."""

    def list_schemas(self) -> list[str]:
        """Return the names of every stored schema."""

    def find_compatible(
        self,
        target: type[BaseModel],
        min_score: float = 0.7,
    ) -> list[tuple[str, float]]:
        """Return ``(name, score)`` pairs for schemas compatible with ``target``."""
class PluginManager:
    """Holds registered plugins for the compatibility system, grouped by category.

    Categories are converters, validators, inspectors, compatibility plugins
    and field resolvers. Each ``get_*`` accessor returns a shallow copy, so
    mutating the returned list does not affect the manager.
    """

    def __init__(self) -> None:
        """Create an empty manager with one list per plugin category."""
        categories = (
            "converters",
            "validators",
            "inspectors",
            "compatibility",
            "resolvers",
        )
        self._plugins: dict[str, list[Any]] = {kind: [] for kind in categories}

    def _add(self, kind: str, plugin: Any) -> None:
        """Append ``plugin`` to the ``kind`` category."""
        self._plugins[kind].append(plugin)

    def _snapshot(self, kind: str) -> list[Any]:
        """Return a shallow copy of the ``kind`` category."""
        return list(self._plugins[kind])

    def register_converter(self, converter: ConversionStrategy) -> None:
        """Add a conversion strategy."""
        self._add("converters", converter)

    def register_validator(self, validator: SchemaValidator) -> None:
        """Add a schema validator."""
        self._add("validators", validator)

    def register_inspector(self, inspector: TypeInspector) -> None:
        """Add a type inspector."""
        self._add("inspectors", inspector)

    def register_compatibility_plugin(self, plugin: CompatibilityPlugin) -> None:
        """Add a compatibility plugin and keep the list ordered by priority.

        The highest ``priority`` comes first. ``list.sort`` is stable, so
        plugins with equal priority keep their registration order.
        """
        self._add("compatibility", plugin)
        self._plugins["compatibility"].sort(
            key=lambda registered: registered.priority,
            reverse=True,
        )

    def register_resolver(self, resolver: FieldResolver) -> None:
        """Add a field resolver."""
        self._add("resolvers", resolver)

    def get_converters(self) -> list[ConversionStrategy]:
        """Return a copy of the registered converters."""
        return self._snapshot("converters")

    def get_validators(self) -> list[SchemaValidator]:
        """Return a copy of the registered validators."""
        return self._snapshot("validators")

    def get_inspectors(self) -> list[TypeInspector]:
        """Return a copy of the registered inspectors."""
        return self._snapshot("inspectors")

    def get_compatibility_plugins(self) -> list[CompatibilityPlugin]:
        """Return a copy of the compatibility plugins, highest priority first."""
        return self._snapshot("compatibility")

    def get_resolvers(self) -> list[FieldResolver]:
        """Return a copy of the registered field resolvers."""
        return self._snapshot("resolvers")
# Global plugin manager _plugin_manager = PluginManager() # Decorator for registering plugins
def converter_plugin(cls) -> Any:
    """Class decorator that registers an instance of ``cls`` as a converter.

    ``cls`` is instantiated with no arguments. The instance is added to the
    module's global plugin manager, and ``cls`` itself is returned unchanged
    so the decorated name still refers to the class.
    """
    plugin_instance = cls()
    _plugin_manager.register_converter(plugin_instance)
    return cls
def validator_plugin(cls) -> Any:
    """Class decorator that registers an instance of ``cls`` as a schema validator.

    ``cls`` is instantiated with no arguments. The instance is added to the
    module's global plugin manager, and ``cls`` itself is returned unchanged.
    """
    plugin_instance = cls()
    _plugin_manager.register_validator(plugin_instance)
    return cls
def compatibility_plugin(priority: int = 0) -> Any:
    """Decorator to register a compatibility plugin.

    Supported forms are ``@compatibility_plugin(priority=10)``,
    ``@compatibility_plugin()`` and the bare ``@compatibility_plugin``. The
    decorated class is instantiated with no arguments, and the instance is
    registered with the module's global plugin manager. The class is
    returned unchanged.

    Args:
        priority: Used only when the instance has no ``priority`` attribute
            of its own. Higher priorities run first.

    Returns:
        A class decorator. With the bare form, the decorated class itself is
        returned.
    """

    def decorator(cls: type) -> Any:
        """Instantiate ``cls``, register the instance, and return ``cls``."""
        instance = cls()
        # Fill in a default only; a priority defined by the class wins.
        if not hasattr(instance, "priority"):
            instance.priority = priority
        _plugin_manager.register_compatibility_plugin(instance)
        return cls

    # Bare ``@compatibility_plugin`` passes the class itself as ``priority``.
    # Without this check the class name would be rebound to ``decorator``
    # and nothing would be registered.
    if isinstance(priority, type):
        plugin_cls = priority
        priority = 0
        return decorator(plugin_cls)
    return decorator
# Example plugin implementations
class ExampleFieldResolver:
    """Example field resolver that matches fields by name similarity.

    Matching is tried from strictest to loosest: exact name, then
    case-insensitive name, then substring containment in either direction.
    """

    def resolve_field(
        self,
        source_fields: dict[str, FieldInfo],
        target_field: FieldInfo,
    ) -> str | None:
        """Resolve the source field for ``target_field`` by name similarity.

        Args:
            source_fields: Source fields keyed by name. Only the keys are used.
            target_field: Field to resolve. Only its ``name`` is read.

        Returns:
            The first matching source field name, or ``None`` if none matches.
        """
        target_name = target_field.name

        # Exact match
        if target_name in source_fields:
            return target_name

        # Case-insensitive match
        target_lower = target_name.lower()
        for source_name in source_fields:
            if source_name.lower() == target_lower:
                return source_name

        # Partial match. An empty name is a substring of every string, so it
        # would "match" arbitrarily and shadow real matches; skip it on both
        # sides.
        if not target_lower:
            return None
        for source_name in source_fields:
            source_lower = source_name.lower()
            if source_lower and (
                target_lower in source_lower or source_lower in target_lower
            ):
                return source_name
        return None

    def suggest_mapping(
        self,
        source_schema: SchemaInfo,
        target_schema: SchemaInfo,
    ) -> dict[str, str]:
        """Suggest a source field for every target field that can be resolved.

        Returns:
            Mapping of target field name to source field name. Unresolved
            target fields are omitted.
        """
        suggestions: dict[str, str] = {}
        for target_name, target_field in target_schema.fields.items():
            source_name = self.resolve_field(source_schema.fields, target_field)
            # resolve_field signals "no match" only with None.
            if source_name is not None:
                suggestions[target_name] = source_name
        return suggestions
class ExampleTypeInspector:
    """Example type inspector for types that carry a ``__metadata__`` attribute.

    ``typing.Annotated`` aliases expose their extra arguments as a tuple in
    ``__metadata__``. Other custom types may put a mapping there directly.
    Both forms are accepted.
    """

    # (metadata key, constraint name) pairs, in output order.
    _CONSTRAINT_KEYS = (
        ("min", "minimum"),
        ("max", "maximum"),
        ("pattern", "pattern"),
    )

    def can_inspect(self, type_hint: type) -> bool:
        """Return True if ``type_hint`` has a ``__metadata__`` attribute."""
        return hasattr(type_hint, "__metadata__")

    def inspect(self, type_hint: type) -> dict[str, Any]:
        """Return the raw ``__metadata__`` of ``type_hint``, or ``{}`` if absent."""
        return {
            "has_metadata": True,
            "metadata": getattr(type_hint, "__metadata__", {}),
        }

    @staticmethod
    def _metadata_as_mapping(metadata: Any) -> dict[str, Any]:
        """Normalize a ``__metadata__`` value into a single dict.

        A mapping is copied. A tuple or list (the ``typing.Annotated`` form)
        contributes every mapping element it contains, and later elements
        override earlier ones. Any other value yields ``{}``.
        """
        from collections.abc import Mapping

        if isinstance(metadata, Mapping):
            return dict(metadata)
        merged: dict[str, Any] = {}
        if isinstance(metadata, (tuple, list)):
            for item in metadata:
                if isinstance(item, Mapping):
                    merged.update(item)
        return merged

    def extract_constraints(self, type_hint: type) -> dict[str, Any]:
        """Extract ``minimum``/``maximum``/``pattern`` constraints from metadata.

        The ``min``, ``max`` and ``pattern`` metadata keys map to
        ``minimum``, ``maximum`` and ``pattern``. Keys that are absent are
        omitted. Non-mapping metadata entries (for example plain strings in
        ``Annotated``) are ignored instead of raising ``TypeError``.
        """
        metadata = self._metadata_as_mapping(getattr(type_hint, "__metadata__", {}))
        return {
            constraint: metadata[key]
            for key, constraint in self._CONSTRAINT_KEYS
            if key in metadata
        }