"""Field and model validation framework with async support."""
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel
from haive.core.schema.compatibility.types import (
FieldInfo,
SchemaInfo,
ValidationError,
ValidationResult,
)
[docs]
class ValidationContext(BaseModel):
    """Context passed through a validation chain.

    Carries the location of the value currently being validated, plus
    optional references that individual validators may consult.

    Attributes:
        current_path: Stack of path segments for the value being validated,
            outermost segment first. ``ModelValidator.validate`` pushes a
            field name before validating that field and pops it afterwards.
        root_value: The top-level value under validation. ``ValidatorChain``
            sets it when it has to create a context itself.
        parent_value: Not read or written by the code visible here;
            presumably the container of the current value (confirm with
            callers).
        field_info: Optional field metadata; not used by the code visible
            here.
        schema_info: Optional schema metadata; not used by the code visible
            here.
        custom_data: Free-form extra data; not used by the code visible
            here.
    """

    # Pydantic v2 configuration. This file already relies on the v2 API
    # (``model_dump``), where a nested ``class Config`` is deprecated, so the
    # same option is declared through ``model_config`` instead.
    model_config = {"arbitrary_types_allowed": True}

    # NOTE(review): mutable literals as defaults are kept from the original;
    # pydantic is expected to copy field defaults per instance - confirm
    # against the pinned pydantic version.
    current_path: list[str] = []
    root_value: Any = None
    parent_value: Any = None
    field_info: FieldInfo | None = None
    schema_info: SchemaInfo | None = None
    custom_data: dict[str, Any] = {}

    def push_path(self, segment: str) -> None:
        """Append ``segment`` to the end of the current path."""
        self.current_path.append(segment)

    def pop_path(self) -> str | None:
        """Remove and return the last path segment.

        Returns:
            The removed segment, or ``None`` when the path is already empty.
        """
        if not self.current_path:
            return None
        return self.current_path.pop()

    @property
    def current_path_str(self) -> str:
        """Return the current path as a dot-separated string."""
        return ".".join(self.current_path)
[docs]
class Validator(ABC):
    """Abstract base class for validators.

    Subclasses must implement :meth:`validate`. They may override
    :meth:`avalidate` to provide a native asynchronous implementation;
    otherwise the default runs :meth:`validate` in an executor.
    """

    @abstractmethod
    def validate(self, value: Any, context: ValidationContext) -> ValidationResult:
        """Validate ``value`` and return the outcome.

        Args:
            value: The value to check.
            context: Validation context shared along the chain.

        Returns:
            The validation result for ``value``.
        """

    @property
    def supports_async(self) -> bool:
        """Whether this validator exposes an ``avalidate`` attribute.

        NOTE(review): this base class defines ``avalidate`` itself, so the
        check is true for every ``Validator`` subclass; it can only be false
        for objects that do not inherit from this class.
        """
        return hasattr(self, "avalidate")

    async def avalidate(
        self, value: Any, context: ValidationContext
    ) -> ValidationResult:
        """Validate asynchronously (optional override point).

        The default implementation runs the synchronous :meth:`validate` in
        the event loop's default executor.

        Args:
            value: The value to check.
            context: Validation context shared along the chain.

        Returns:
            Whatever :meth:`validate` returns for the same arguments.
        """
        # A coroutine always runs inside a loop, so get_running_loop() is the
        # precise call here; get_event_loop() is the legacy accessor.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.validate, value, context)
[docs]
class FieldValidator(Validator):
    """Validator that applies a list of predicate callables to one field."""

    def __init__(
        self,
        field_name: str,
        validators: list[Callable[[Any], bool]] | None = None,
        error_messages: dict[str, str] | None = None,
    ):
        """Initialise the field validator.

        Args:
            field_name: Name passed as ``field`` on every error this
                validator produces, and used in the default failure message.
            validators: Predicates called with the field value; a falsy
                return value marks the value as invalid.
            error_messages: Custom failure messages keyed by
                ``str(validator)`` (the key ``add_validator`` uses).
        """
        self.field_name = field_name
        # Copy the containers: add_validator() mutates them, and the caller's
        # own list/dict must not change behind its back. (The previous
        # ``x or []`` form aliased non-empty arguments but not empty ones.)
        self.validators = list(validators) if validators else []
        self.error_messages = dict(error_messages) if error_messages else {}

    def add_validator(
        self,
        validator: Callable[[Any], bool],
        error_message: str | None = None,
    ) -> None:
        """Register a predicate, optionally with a custom failure message.

        Args:
            validator: Callable taking the field value; a falsy return value
                marks the value as invalid.
            error_message: Message reported when ``validator`` fails. When
                omitted or empty, a generic message naming the field is used.
        """
        self.validators.append(validator)
        if error_message:
            # NOTE(review): messages are keyed by str(validator). For plain
            # functions that string embeds the object's address; callables
            # with a custom __str__/__repr__ could share a key.
            self.error_messages[str(validator)] = error_message

    def validate(self, value: Any, context: ValidationContext) -> ValidationResult:
        """Run every registered predicate against ``value``.

        All predicates are run; failures are accumulated rather than
        short-circuited. A predicate that raises is reported as a
        ``validator_exception`` error instead of propagating.

        Args:
            value: The field value to check.
            context: Validation context; its current path is recorded on
                each error.

        Returns:
            A result holding one error per failed or raising predicate.
        """
        result = ValidationResult(is_valid=True)
        for validator in self.validators:
            try:
                if not validator(value):
                    error_msg = self.error_messages.get(
                        str(validator),
                        f"Validation failed for {self.field_name}",
                    )
                    result.add_error(
                        ValidationError(
                            field=self.field_name,
                            message=error_msg,
                            error_type="field_validation",
                            context={
                                "value": value,
                                "path": context.current_path_str,
                            },
                        )
                    )
            except Exception as e:
                # Deliberately broad: one misbehaving predicate must not
                # abort validation of the remaining ones.
                result.add_error(
                    ValidationError(
                        field=self.field_name,
                        message=f"Validator error: {e!s}",
                        error_type="validator_exception",
                        # Include the path here too, so both error kinds can
                        # be located the same way.
                        context={
                            "exception": str(e),
                            "path": context.current_path_str,
                        },
                    )
                )
        return result
[docs]
class ModelValidator(Validator):
    """Validator for entire models/schemas.

    Combines per-field validators with cross-field validators that see the
    whole data mapping.
    """

    def __init__(
        self,
        schema_info: SchemaInfo | None = None,
        field_validators: dict[str, FieldValidator] | None = None,
        cross_field_validators: list[Callable] | None = None,
    ):
        """Initialise the model validator.

        Args:
            schema_info: Optional schema metadata; stored on the instance and
                not consulted by ``validate``.
            field_validators: Mapping of field name to the validator applied
                to that field's value.
            cross_field_validators: Callables invoked with the full data
                dict. See ``validate`` for the accepted return values.
        """
        self.schema_info = schema_info
        # Copy the containers: the add_* methods mutate them, and the
        # caller's own dict/list must not change behind its back.
        self.field_validators = dict(field_validators) if field_validators else {}
        self.cross_field_validators = (
            list(cross_field_validators) if cross_field_validators else []
        )

    def add_field_validator(self, field_name: str, validator: FieldValidator) -> None:
        """Register ``validator`` for ``field_name``, replacing any existing one."""
        self.field_validators[field_name] = validator

    def add_cross_field_validator(self, validator: Callable) -> None:
        """Append a cross-field validator; it runs after all field validators."""
        self.cross_field_validators.append(validator)

    def validate(self, value: Any, context: ValidationContext) -> ValidationResult:
        """Validate an entire model.

        ``BaseModel`` instances are converted with ``model_dump()``; dicts
        are used as-is. Any other type yields a single ``type_error``.

        Field validators run only for fields present in the data; a field
        that is absent is skipped, not reported. Cross-field validators then
        receive the data dict and may return:

        * ``False`` - recorded as a generic cross-field error;
        * a ``(is_valid, message)`` tuple - recorded with ``message`` when
          ``is_valid`` is falsy;
        * anything else - treated as success.

        A cross-field validator that raises is recorded as a
        ``validator_exception`` error.

        Args:
            value: A ``BaseModel`` instance or a dict.
            context: Validation context; each field name is pushed onto its
                path while that field is validated.

        Returns:
            The merged result of all field and cross-field validators.
        """
        result = ValidationResult(is_valid=True)

        # Normalise the input to a plain mapping.
        if isinstance(value, BaseModel):
            data = value.model_dump()
        elif isinstance(value, dict):
            data = value
        else:
            result.add_error(
                ValidationError(
                    field=None,
                    message="Value must be a dict or BaseModel",
                    error_type="type_error",
                )
            )
            return result

        # Validate individual fields.
        for field_name, field_validator in self.field_validators.items():
            if field_name not in data:
                continue
            context.push_path(field_name)
            try:
                field_result = field_validator.validate(data[field_name], context)
            finally:
                # Always restore the path, even when a custom field validator
                # raises; otherwise the shared context stays corrupted for
                # every later validator.
                context.pop_path()
            for error in field_result.errors:
                result.add_error(error)
            for warning in field_result.warnings:
                result.add_warning(warning)

        # Cross-field validation.
        for validator in self.cross_field_validators:
            try:
                validation = validator(data)
            except Exception as e:
                # Deliberately broad: one misbehaving validator must not
                # abort the remaining ones.
                result.add_error(
                    ValidationError(
                        field=None,
                        message=f"Cross-field validator error: {e!s}",
                        error_type="validator_exception",
                    )
                )
                continue

            if validation is False:
                result.add_error(
                    ValidationError(
                        field=None,
                        message="Cross-field validation failed",
                        error_type="cross_field_validation",
                    )
                )
            elif isinstance(validation, tuple) and len(validation) == 2:
                is_valid, message = validation
                if not is_valid:
                    result.add_error(
                        ValidationError(
                            field=None,
                            message=message,
                            error_type="cross_field_validation",
                        )
                    )
        return result
[docs]
@dataclass
class ValidatorChain:
    """Run several validators against one value and merge their results.

    Attributes:
        validators: Validators to run, in order.
        stop_on_first_error: When true, stop after the first validator whose
            result is not valid.
    """

    validators: list[Validator]
    stop_on_first_error: bool = False

    @staticmethod
    def _absorb(combined: ValidationResult, outcome: ValidationResult) -> None:
        """Copy every error and warning from ``outcome`` into ``combined``."""
        for err in outcome.errors:
            combined.add_error(err)
        for warn in outcome.warnings:
            combined.add_warning(warn)

    def validate(
        self, value: Any, context: ValidationContext | None = None
    ) -> ValidationResult:
        """Run every validator in the chain synchronously.

        Args:
            value: The value to validate.
            context: Shared validation context; a new one rooted at
                ``value`` is created when omitted.

        Returns:
            A result containing the errors and warnings of every validator
            that ran.
        """
        ctx = ValidationContext(root_value=value) if context is None else context
        combined = ValidationResult(is_valid=True)
        for link in self.validators:
            outcome = link.validate(value, ctx)
            self._absorb(combined, outcome)
            if self.stop_on_first_error and not outcome.is_valid:
                break
        return combined

    async def avalidate(
        self, value: Any, context: ValidationContext | None = None
    ) -> ValidationResult:
        """Run every validator in the chain, awaiting async-capable ones.

        Validators are awaited one after another, not concurrently.

        Args:
            value: The value to validate.
            context: Shared validation context; a new one rooted at
                ``value`` is created when omitted.

        Returns:
            A result containing the errors and warnings of every validator
            that ran.
        """
        ctx = ValidationContext(root_value=value) if context is None else context
        combined = ValidationResult(is_valid=True)
        for link in self.validators:
            if link.supports_async:
                outcome = await link.avalidate(value, ctx)
            else:
                outcome = link.validate(value, ctx)
            self._absorb(combined, outcome)
            if self.stop_on_first_error and not outcome.is_valid:
                break
        return combined
[docs]
class ValidatorBuilder:
    """Factory helpers that assemble ready-made validators."""

    @staticmethod
    def for_type(type_hint: type) -> FieldValidator:
        """Build a validator that requires ``isinstance(value, type_hint)``.

        The validator's field name is ``"<type name>_validator"``.
        """
        type_name = type_hint.__name__
        built = FieldValidator(f"{type_name}_validator")
        built.add_validator(
            validator=lambda x: isinstance(x, type_hint),
            error_message=f"Value must be of type {type_name}",
        )
        return built

    @staticmethod
    def for_range(
        min_value: float | None = None,
        max_value: float | None = None,
        field_name: str = "value",
    ) -> FieldValidator:
        """Build a validator enforcing inclusive lower and/or upper bounds.

        A bound left as ``None`` is not checked.
        """
        built = FieldValidator(field_name)
        if min_value is not None:
            built.add_validator(
                validator=lambda x: x >= min_value,
                error_message=f"Value must be >= {min_value}",
            )
        if max_value is not None:
            built.add_validator(
                validator=lambda x: x <= max_value,
                error_message=f"Value must be <= {max_value}",
            )
        return built

    @staticmethod
    def for_length(
        min_length: int | None = None,
        max_length: int | None = None,
        field_name: str = "value",
    ) -> FieldValidator:
        """Build a validator enforcing inclusive bounds on ``len(value)``.

        A bound left as ``None`` is not checked.
        """
        built = FieldValidator(field_name)
        if min_length is not None:
            built.add_validator(
                validator=lambda x: len(x) >= min_length,
                error_message=f"Length must be >= {min_length}",
            )
        if max_length is not None:
            built.add_validator(
                validator=lambda x: len(x) <= max_length,
                error_message=f"Length must be <= {max_length}",
            )
        return built

    @staticmethod
    def for_pattern(pattern: str, field_name: str = "value") -> FieldValidator:
        """Build a validator matching ``str(value)`` against a regex.

        The pattern is compiled once and applied with ``match``, so it is
        anchored at the start of the string only.
        """
        import re

        compiled = re.compile(pattern)
        built = FieldValidator(field_name)
        built.add_validator(
            validator=lambda x: bool(compiled.match(str(x))),
            error_message=f"Value must match pattern: {pattern}",
        )
        return built

    @staticmethod
    def combine(*validators: Validator) -> ValidatorChain:
        """Wrap the given validators, in order, in a ``ValidatorChain``."""
        return ValidatorChain(validators=[*validators])
# Common validators
[docs]
class CommonValidators:
    """Collection of common predicate validators.

    The string predicates return ``False`` for non-string input instead of
    raising, and require the whole string to match.
    """

    @staticmethod
    def not_empty(value: Any) -> bool:
        """Check that ``value`` is not empty.

        Returns:
            ``False`` for ``None`` and for sized objects of length 0;
            ``True`` for everything else (including ``0`` and ``False``).
        """
        if value is None:
            return False
        if hasattr(value, "__len__"):
            return len(value) > 0
        return True

    @staticmethod
    def is_email(value: str) -> bool:
        """Basic e-mail syntax check (``local@domain.tld``).

        Returns:
            ``True`` when the entire string has the expected shape;
            ``False`` otherwise, including for non-string input.
        """
        import re

        if not isinstance(value, str):
            return False
        # fullmatch, not match + "$": "$" also matches just before a trailing
        # newline, which would let "user@example.com\n" through.
        pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
        return re.fullmatch(pattern, value) is not None

    @staticmethod
    def is_url(value: str) -> bool:
        """Basic ``http``/``https`` URL syntax check.

        Returns:
            ``True`` when the entire string looks like an http(s) URL
            without whitespace; ``False`` otherwise, including for
            non-string input.
        """
        import re

        if not isinstance(value, str):
            return False
        # fullmatch rejects a trailing newline that "$" would tolerate.
        pattern = r"https?://[^\s/$.?#].[^\s]*"
        return re.fullmatch(pattern, value) is not None

    @staticmethod
    def is_uuid(value: str) -> bool:
        """Check for the canonical 8-4-4-4-12 hexadecimal UUID form.

        The comparison is case-insensitive.

        Returns:
            ``True`` when the entire string is a hyphenated UUID; ``False``
            otherwise, including for non-string input.
        """
        import re

        if not isinstance(value, str):
            return False
        pattern = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
        return re.fullmatch(pattern, value.lower()) is not None
# Convenience function
[docs]
def create_validator(
    schema_info: SchemaInfo,
    custom_validators: dict[str, list[Callable]] | None = None,
) -> ModelValidator:
    """Build a ``ModelValidator`` from schema information.

    For every field in ``schema_info.fields`` a ``FieldValidator`` is
    registered with, in this order: an ``isinstance`` check when the field
    has a truthy type hint, a non-empty check when the field is required,
    and any custom validators supplied for that field name.

    Args:
        schema_info: Schema whose ``fields`` mapping drives the validators.
        custom_validators: Optional extra predicates per field name.

    Returns:
        A model validator with one field validator per schema field.
    """
    built = ModelValidator(schema_info=schema_info)
    extras_by_field = custom_validators or {}

    for name, info in schema_info.fields.items():
        per_field = FieldValidator(name)

        # Type check; the hint is bound as a lambda default so each lambda
        # keeps the hint of its own field.
        hint = info.type_info.type_hint
        if hint:
            per_field.add_validator(
                lambda x, t=hint: isinstance(x, t),
                f"Must be of type {hint}",
            )

        # Required check.
        if info.is_required:
            per_field.add_validator(CommonValidators.not_empty, "Field is required")

        # Caller-supplied checks, registered without custom messages.
        for extra in extras_by_field.get(name, ()):
            per_field.add_validator(extra)

        built.add_field_validator(name, per_field)

    return built