"""Advanced field mapping with path resolution and transformations."""
from __future__ import annotations
import re
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from haive.core.schema.compatibility.types import ConversionContext
[docs]
@dataclass
class FieldMapping:
"""Represents a mapping between fields with transformation."""
source_path: str # Can be nested: "user.profile.name"
target_field: str
transformer: Callable[[Any], Any] | None = None
condition: Callable[[dict[str, Any]], bool] | None = None
default_value: Any = None
default_factory: Callable[[], Any] | None = None
validator: Callable[[Any], bool] | None = None
metadata: dict[str, Any] = field(default_factory=dict)
# Path patterns
is_computed: bool = False # No source, generated value
is_aggregate: bool = False # Multiple sources to one target
aggregator: Callable[[list[Any]], Any] | None = None
[docs]
def apply(self, source_data: dict[str, Any]) -> tuple[bool, Any]:
"""Apply mapping to source data.
Returns:
Tuple of (success, value)
"""
# Check condition first
if self.condition and not self.condition(source_data):
return False, None
# Handle computed fields
if self.is_computed:
value = (
self.default_factory() if self.default_factory else self.default_value
)
# Handle aggregate fields
elif self.is_aggregate and self.aggregator:
# Source path contains multiple paths separated by |
paths = self.source_path.split("|")
values = []
for path in paths:
path_value = self._extract_path_value(source_data, path.strip())
if path_value is not None:
values.append(path_value)
if not values and self.default_value is not None:
value = self.default_value
else:
value = self.aggregator(values)
# Normal field extraction
else:
value = self._extract_path_value(source_data, self.source_path)
# Use default if not found
if value is None:
if self.default_factory:
value = self.default_factory()
elif self.default_value is not None:
value = self.default_value
else:
return False, None
# Apply transformer
if self.transformer and value is not None:
try:
value = self.transformer(value)
except Exception as e:
return False, f"Transform error: {e}"
# Validate
if self.validator and not self.validator(value):
return False, f"Validation failed for {self.target_field}"
return True, value
def _extract_path_value(self, data: dict[str, Any], path: str) -> Any:
"""Extract value from nested path."""
# Handle array notation: messages[0].content
array_pattern = re.compile(r"()\[(\d+)\]")
# Handle filter notation: messages[?type=="human"].content
filter_pattern = re.compile(r'()\[\?()==["\']([^"\']+)["\']\]')
current = data
parts = path.split(".")
for _i, part in enumerate(parts):
if current is None:
return None
# Check for array access
array_match = array_pattern.match(part)
if array_match:
field_name = array_match.group(1)
index = int(array_match.group(2))
if isinstance(current, dict) and field_name in current:
array = current[field_name]
if isinstance(array, list) and 0 <= index < len(array):
current = array[index]
else:
return None
else:
return None
continue
# Check for filter
filter_match = filter_pattern.match(part)
if filter_match:
field_name = filter_match.group(1)
filter_field = filter_match.group(2)
filter_value = filter_match.group(3)
if isinstance(current, dict) and field_name in current:
array = current[field_name]
if isinstance(array, list):
filtered = [
item
for item in array
if isinstance(item, dict)
and item.get(filter_field) == filter_value
]
current = filtered[0] if filtered else None
else:
return None
else:
return None
continue
# Normal field access
if isinstance(current, dict):
current = current.get(part)
elif hasattr(current, part):
current = getattr(current, part)
else:
return None
return current
[docs]
class FieldMapper:
"""Manages field mappings between schemas."""
def __init__(self) -> None:
"""Init .
Returns:
[TODO: Add return description]
"""
self.mappings: dict[str, FieldMapping] = {}
self._source_index: dict[str, set[str]] = {} # source -> targets
[docs]
def add_mapping(
self,
source: str | list[str],
target: str,
transformer: Callable[[Any], Any] | None = None,
condition: Callable[[dict[str, Any]], bool] | None = None,
default: Any = None,
validator: Callable[[Any], bool] | None = None,
) -> FieldMapping:
"""Add a field mapping."""
# Handle multiple sources (aggregate)
if isinstance(source, list):
source_path = " | ".join(source)
is_aggregate = True
else:
source_path = source
is_aggregate = False
mapping = FieldMapping(
source_path=source_path,
target_field=target,
transformer=transformer,
condition=condition,
default_value=default,
validator=validator,
is_aggregate=is_aggregate,
)
self.mappings[target] = mapping
# Update index
if isinstance(source, list):
for s in source:
if s not in self._source_index:
self._source_index[s] = set()
self._source_index[s].add(target)
else:
if source not in self._source_index:
self._source_index[source] = set()
self._source_index[source].add(target)
return mapping
[docs]
def add_computed_field(
self,
target: str,
generator: Callable[[], Any],
condition: Callable[[dict[str, Any]], bool] | None = None,
) -> FieldMapping:
"""Add a computed field with no source."""
mapping = FieldMapping(
source_path="",
target_field=target,
default_factory=generator,
condition=condition,
is_computed=True,
)
self.mappings[target] = mapping
return mapping
[docs]
def add_aggregate_field(
self,
sources: list[str],
target: str,
aggregator: Callable[[list[Any]], Any],
default: Any = None,
) -> FieldMapping:
"""Add an aggregate field from multiple sources."""
mapping = FieldMapping(
source_path=" | ".join(sources),
target_field=target,
is_aggregate=True,
aggregator=aggregator,
default_value=default,
)
self.mappings[target] = mapping
# Update index
for source in sources:
if source not in self._source_index:
self._source_index[source] = set()
self._source_index[source].add(target)
return mapping
[docs]
def map_data(
self,
source_data: dict[str, Any],
target_fields: set[str] | None = None,
include_unmapped: bool = False,
context: ConversionContext | None = None,
) -> dict[str, Any]:
"""Map source data to target schema.
Args:
source_data: Source data dictionary
target_fields: Specific fields to map (None = all)
include_unmapped: Include unmapped source fields
context: Conversion context for tracking
Returns:
Mapped data dictionary
"""
result = {}
mapped_sources = set()
# Apply mappings
for target_field, mapping in self.mappings.items():
if target_fields and target_field not in target_fields:
continue
success, value = mapping.apply(source_data)
if success:
result[target_field] = value
# Track mapped source fields
if not mapping.is_computed:
if mapping.is_aggregate:
sources = mapping.source_path.split("|")
mapped_sources.update(s.strip() for s in sources)
else:
mapped_sources.add(mapping.source_path)
elif context:
context.add_warning(f"Failed to map field '{target_field}': {value}")
# Include unmapped fields if requested
if include_unmapped:
for key, value in source_data.items():
if key not in mapped_sources and key not in result:
result[key] = value
return result
[docs]
def get_mapping_for_target(self, target_field: str) -> FieldMapping | None:
"""Get mapping for a target field."""
return self.mappings.get(target_field)
[docs]
def get_targets_for_source(self, source_field: str) -> set[str]:
"""Get all target fields that use a source field."""
return self._source_index.get(source_field, set())
[docs]
def validate_mappings(
self,
source_fields: set[str],
target_fields: set[str],
) -> tuple[bool, list[str]]:
"""Validate that mappings are complete and valid.
Returns:
Tuple of (is_valid, issues)
"""
issues = []
# Check all required target fields have mappings
for target in target_fields:
if target not in self.mappings:
issues.append(f"No mapping for required field '{target}'")
# Check all source references exist
for mapping in self.mappings.values():
if not mapping.is_computed:
if mapping.is_aggregate:
sources = [s.strip() for s in mapping.source_path.split("|")]
else:
sources = [mapping.source_path]
for source in sources:
# Extract base field (before any dots or brackets)
base_field = source.split(".")[0].split("[")[0]
if base_field not in source_fields:
issues.append(
f"Mapping references non-existent source '{base_field}'"
)
return len(issues) == 0, issues
[docs]
def to_dict(self) -> dict[str, dict[str, Any]]:
"""Export mappings as dictionary."""
return {
target: {
"source": mapping.source_path,
"is_computed": mapping.is_computed,
"is_aggregate": mapping.is_aggregate,
"has_transformer": mapping.transformer is not None,
"has_condition": mapping.condition is not None,
"has_default": mapping.default_value is not None,
}
for target, mapping in self.mappings.items()
}
# Convenience function
[docs]
def create_mapping(
mappings: dict[str, str | tuple[str, Callable]],
computed_fields: dict[str, Callable] | None = None,
) -> FieldMapper:
"""Create a field mapper from simple mapping dict.
Args:
mappings: Dict of target -> source or (source, transformer)
computed_fields: Dict of target -> generator function
"""
mapper = FieldMapper()
# Add simple mappings
for target, source_spec in mappings.items():
if isinstance(source_spec, tuple):
source, transformer = source_spec
mapper.add_mapping(source, target, transformer=transformer)
else:
mapper.add_mapping(source_spec, target)
# Add computed fields
if computed_fields:
for target, generator in computed_fields.items():
mapper.add_computed_field(target, generator)
return mapper