Source code for haive.core.schema.compatibility.analyzer

"""Advanced type analysis engine for deep type introspection."""

from __future__ import annotations

import inspect
import sys
from functools import lru_cache

# types.UnionType (PEP 604 ``X | Y`` unions) requires Python 3.10+
from types import UnionType
from typing import Any, ForwardRef, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo as PydanticFieldInfo
from pydantic_core import PydanticUndefined

from haive.core.schema.compatibility.types import FieldInfo, SchemaInfo, TypeInfo


class TypeAnalyzer:
    """Advanced type analysis with caching and deep introspection."""

    def __init__(self, cache_size: int = 1000):
        """Initialize analyzer with cache."""
        self._cache_size = cache_size
        self._clear_cache()

    def _clear_cache(self):
        """Clear all caches."""
        # Create cached versions of methods
        self.analyze_type = lru_cache(maxsize=self._cache_size)(
            self._analyze_type_impl
        )
        self.get_type_info = lru_cache(maxsize=self._cache_size)(
            self._get_type_info_impl
        )

    def analyze_schema(self, schema_type: type[BaseModel]) -> SchemaInfo:
        """Analyze a Pydantic BaseModel schema."""
        if not inspect.isclass(schema_type) or not issubclass(schema_type, BaseModel):
            raise TypeError(f"{schema_type} is not a BaseModel subclass")

        # Get type info
        type_info = self.get_type_info(schema_type)

        # Create schema info
        schema_info = SchemaInfo(
            name=schema_type.__name__,
            type_info=type_info,
            base_classes=list(schema_type.__bases__),
        )

        # Analyze fields
        for field_name, field_info in schema_type.model_fields.items():
            schema_info.fields[field_name] = self._analyze_field(field_name, field_info)

        # Extract Haive-specific metadata
        if hasattr(schema_type, "__shared_fields__"):
            schema_info.shared_fields = set(schema_type.__shared_fields__)
        if hasattr(schema_type, "__reducer_fields__"):
            schema_info.reducer_fields = dict(schema_type.__reducer_fields__)
        if hasattr(schema_type, "__engine_io_mappings__"):
            schema_info.engine_io_mappings = dict(schema_type.__engine_io_mappings__)

        # Extract public methods (isroutine catches plain functions, which is how
        # regular methods appear on the class object, as well as bound classmethods)
        for name, method in inspect.getmembers(schema_type, inspect.isroutine):
            if not name.startswith("_"):
                schema_info.methods[name] = method

        return schema_info

    def _analyze_field(self, name: str, pydantic_field: PydanticFieldInfo) -> FieldInfo:
        """Analyze a single field."""
        # Get type information
        type_info = self.get_type_info(pydantic_field.annotation)

        # Determine if required
        is_required = pydantic_field.is_required()

        # Extract default information
        has_default = not is_required
        default_value = None
        default_factory = None

        # Pydantic v2 uses PydanticUndefined (not Ellipsis) as the "no default" sentinel
        if (
            hasattr(pydantic_field, "default")
            and pydantic_field.default is not PydanticUndefined
        ):
            has_default = True
            default_value = pydantic_field.default
        elif (
            hasattr(pydantic_field, "default_factory")
            and pydantic_field.default_factory
        ):
            has_default = True
            default_factory = pydantic_field.default_factory

        # Create field info
        field_info = FieldInfo(
            name=name,
            type_info=type_info,
            is_required=is_required,
            has_default=has_default,
            default_value=default_value,
            default_factory=default_factory,
            description=pydantic_field.description,
            alias=pydantic_field.alias,
        )

        # Extract validators if available
        if hasattr(pydantic_field, "validators"):
            field_info.validators = list(pydantic_field.validators)

        return field_info

    def _analyze_type_impl(self, type_hint: type[Any]) -> dict[str, Any]:
        """Implementation of type analysis (cached)."""
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        analysis = {
            "type": type_hint,
            "origin": origin,
            "args": args,
            "is_generic": origin is not None,
            "is_union": self._is_union(type_hint),
            "is_optional": self._is_optional(type_hint),
            "is_protocol": self._is_protocol(type_hint),
            "is_typeddict": self._is_typeddict(type_hint),
            "is_literal": origin is not None
            and hasattr(origin, "__name__")
            and origin.__name__ == "Literal",
            "is_forward_ref": isinstance(type_hint, ForwardRef | str),
            "is_basemodel": self._is_basemodel(type_hint),
            "is_callable": callable(type_hint),
        }

        # Get module and qualname
        if hasattr(type_hint, "__module__"):
            analysis["module"] = type_hint.__module__
        if hasattr(type_hint, "__qualname__"):
            analysis["qualname"] = type_hint.__qualname__
        elif hasattr(type_hint, "__name__"):
            analysis["qualname"] = type_hint.__name__

        return analysis

    def _get_type_info_impl(self, type_hint: type[Any]) -> TypeInfo:
        """Implementation of get_type_info (cached)."""
        analysis = self.analyze_type(type_hint)

        return TypeInfo(
            type_hint=type_hint,
            origin=analysis.get("origin"),
            args=analysis.get("args", ()),
            is_generic=analysis.get("is_generic", False),
            is_union=analysis.get("is_union", False),
            is_optional=analysis.get("is_optional", False),
            is_protocol=analysis.get("is_protocol", False),
            is_typeddict=analysis.get("is_typeddict", False),
            is_literal=analysis.get("is_literal", False),
            is_forward_ref=analysis.get("is_forward_ref", False),
            is_basemodel=analysis.get("is_basemodel", False),
            module=analysis.get("module"),
            qualname=analysis.get("qualname"),
        )

    def _is_union(self, type_hint: type[Any]) -> bool:
        """Check if type is a Union."""
        origin = get_origin(type_hint)
        return origin is Union or (sys.version_info >= (3, 10) and origin is UnionType)

    def _is_optional(self, type_hint: type[Any]) -> bool:
        """Check if type is Optional (Union[X, None])."""
        if not self._is_union(type_hint):
            return False
        args = get_args(type_hint)
        return type(None) in args

    def _is_protocol(self, type_hint: type[Any]) -> bool:
        """Check if type is a Protocol."""
        if not inspect.isclass(type_hint):
            return False
        # Check if it's a Protocol
        return any(
            base.__name__ == "Protocol"
            for base in inspect.getmro(type_hint)
            if hasattr(base, "__module__") and "typing" in base.__module__
        )

    def _is_typeddict(self, type_hint: type[Any]) -> bool:
        """Check if type is a TypedDict."""
        if not hasattr(type_hint, "__annotations__"):
            return False
        # Check for TypedDict in MRO
        mro = getattr(type_hint, "__mro__", [])
        return any(
            base.__name__ == "TypedDict"
            for base in mro
            if hasattr(base, "__module__") and "typing" in base.__module__
        )

    def _is_basemodel(self, type_hint: type[Any]) -> bool:
        """Check if type is a Pydantic BaseModel."""
        if not inspect.isclass(type_hint):
            return False
        try:
            return issubclass(type_hint, BaseModel)
        except TypeError:
            return False

    def resolve_forward_refs(
        self,
        type_hint: type[Any],
        globalns: dict[str, Any] | None = None,
        localns: dict[str, Any] | None = None,
    ) -> type[Any]:
        """Resolve forward references in a type hint."""
        if isinstance(type_hint, str):
            # It's a string forward reference
            try:
                return eval(type_hint, globalns or {}, localns or {})
            except Exception:
                return type_hint
        elif isinstance(type_hint, ForwardRef):
            # It's a ForwardRef object
            try:
                return type_hint._evaluate(globalns or {}, localns or {})
            except Exception:
                return type_hint

        # Check if it's a generic with forward refs
        origin = get_origin(type_hint)
        if origin is not None:
            args = get_args(type_hint)
            if args:
                # Recursively resolve args
                resolved_args = tuple(
                    self.resolve_forward_refs(arg, globalns, localns) for arg in args
                )
                # Reconstruct the generic
                return origin[resolved_args]

        return type_hint

    def get_generic_parameters(self, type_hint: type[Any]) -> dict[str, type[Any]]:
        """Extract generic type parameters."""
        params = {}

        # Get type parameters if it's a generic class
        if hasattr(type_hint, "__parameters__"):
            for i, param in enumerate(type_hint.__parameters__):
                params[f"T{i}"] = param

        # Get actual type arguments if it's a generic instance
        origin = get_origin(type_hint)
        if origin is not None:
            args = get_args(type_hint)
            if hasattr(origin, "__parameters__"):
                for param, arg in zip(origin.__parameters__, args, strict=False):
                    param_name = getattr(param, "__name__", f"T{id(param)}")
                    params[param_name] = arg

        return params

    def is_subtype(self, subtype: type[Any], supertype: type[Any]) -> bool:
        """Check if subtype is a subtype of supertype."""
        # Handle None
        if subtype is type(None):
            return self._is_optional(supertype)

        # Direct subclass check
        if inspect.isclass(subtype) and inspect.isclass(supertype):
            try:
                return issubclass(subtype, supertype)
            except TypeError:
                pass

        # Handle generics
        sub_origin = get_origin(subtype)
        super_origin = get_origin(supertype)

        if sub_origin and super_origin:
            # Check origins match
            if not self.is_subtype(sub_origin, super_origin):
                return False

            # Check args
            sub_args = get_args(subtype)
            super_args = get_args(supertype)

            if len(sub_args) != len(super_args):
                return False

            # Check each argument
            for sub_arg, super_arg in zip(sub_args, super_args, strict=False):
                if not self.is_subtype(sub_arg, super_arg):
                    return False

            return True

        # Handle Union types
        if self._is_union(supertype):
            super_args = get_args(supertype)
            return any(self.is_subtype(subtype, arg) for arg in super_args)

        return False


# Module-level convenience functions
_default_analyzer = TypeAnalyzer()

analyze_type = _default_analyzer.analyze_type
get_type_info = _default_analyzer.get_type_info
analyze_schema = _default_analyzer.analyze_schema
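

# Illustrative usage sketch only: `ExampleState` is a hypothetical model invented
# for this demo (it is not part of haive.core); any BaseModel subclass works the
# same way with the module-level convenience functions above.
if __name__ == "__main__":

    class ExampleState(BaseModel):
        """Hypothetical schema used only to exercise the analyzer."""

        name: str
        tags: list[str] | None = None

    # Deep schema introspection: per-field records plus any Haive metadata present.
    schema_info = analyze_schema(ExampleState)
    print(schema_info.name, sorted(schema_info.fields))

    # Cached, type-level analysis of an arbitrary annotation.
    print(get_type_info(list[str] | None))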