"""Reference classes for serializing callables and types."""
import importlib
import inspect
import logging
import uuid
from collections.abc import Callable
from functools import partial
from typing import Any, Optional, TypeVar, get_args, get_origin
from pydantic import BaseModel, ConfigDict, Field
logger = logging.getLogger(__name__)
T = TypeVar("T")
[docs]
class CallableReference(BaseModel):
"""Serializable reference to a callable.
This class can store and resolve references to:
- Module-level functions
- Lambda functions
- Methods
- Partial functions
- Directly passed callables
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
module_path: str | None = None
name: str | None = None
callable_type: str = "function"
source_code: str | None = None
args: tuple | None = None
kwargs: dict[str, Any] | None = None
# Store function object in a non-serialized field
# Using proper naming that Pydantic allows
runtime_func: Callable | None = Field(default=None, exclude=True)
model_config = ConfigDict(arbitrary_types_allowed=True)
[docs]
@classmethod
def from_callable(cls, callable_obj: Callable) -> Optional["CallableReference"]:
"""Create reference from a callable."""
if callable_obj is None:
return None
ref = cls()
# Store the function directly for runtime use
ref.runtime_func = callable_obj
ref.module_path = getattr(callable_obj, "__module__", None)
ref.name = getattr(callable_obj, "__name__", None)
# Handle different callable types
if inspect.isfunction(callable_obj):
ref.callable_type = "function"
# Special handling for lambda functions
if callable_obj.__name__ == "<lambda>":
ref.callable_type = "lambda"
try:
source_code = inspect.getsource(callable_obj).strip()
# Clean source code to prevent tuple evaluation issues
source_code = source_code.rstrip(",;")
ref.source_code = source_code
except (OSError, TypeError):
# If we can't get source, we'll rely on the direct
# reference
logger.debug(f"Could not get source for lambda: {callable_obj}")
elif inspect.ismethod(callable_obj):
ref.callable_type = "method"
# Try to get instance info
if hasattr(callable_obj, "__self__"):
ref.kwargs = {"instance_id": id(callable_obj.__self__)}
elif callable_obj.__class__.__name__ == "partial":
ref.callable_type = "partial"
# Try to get info about the wrapped function
wrapped_func = callable_obj.func
ref.name = getattr(wrapped_func, "__name__", "partial_func")
ref.module_path = getattr(wrapped_func, "__module__", None)
# Store args and kwargs
ref.args = callable_obj.args
ref.kwargs = callable_obj.keywords
else:
ref.callable_type = callable_obj.__class__.__name__
return ref
[docs]
def resolve(self) -> Callable | None:
"""Resolve reference to a callable.
Priority:
1. Direct function reference if available
2. Module/name resolution for regular functions
3. Source code evaluation for lambdas
4. Partial function reconstruction
"""
# First priority: use directly stored function if available
if self.runtime_func is not None:
return self.runtime_func
# Second priority: resolve by module and name
if self.module_path and self.name and self.callable_type != "lambda":
try:
module = importlib.import_module(self.module_path)
func = getattr(module, self.name)
# Handle partial functions
if self.callable_type == "partial":
args = self.args or ()
kwargs = self.kwargs or {}
return partial(func, *args, **kwargs)
return func
except Exception as e:
logger.exception(
f"Failed to resolve function {self.module_path}.{self.name}: {e}"
)
# Third priority: evaluate lambda source code
if self.source_code and self.callable_type == "lambda":
try:
# Clean source code to prevent tuple issues
source = self.source_code
# Remove trailing comma/semicolon if present
source = source.rstrip(",;")
# SECURITY NOTE: Only use with trusted data
# Using a local scope to evaluate lambda
local_vars = {}
exec("func = " + source, globals(), local_vars)
return local_vars["func"]
except Exception as e:
logger.exception(f"Failed to resolve lambda from source: {e}")
# Last resort: dynamic import for specific cases
if self.callable_type == "function" and self.module_path and self.name:
try:
# Try to dynamically import the module
spec = importlib.util.find_spec(self.module_path)
if spec:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if hasattr(module, self.name):
return getattr(module, self.name)
except Exception as e:
logger.debug(
f"Failed dynamic import for {self.module_path}.{self.name}: {e}"
)
# Could not resolve
logger.warning(
f"Could not resolve callable reference: {self.module_path}.{self.name}"
)
return None
[docs]
class TypeReference(BaseModel):
"""Reference to a type that can be serialized."""
module_path: str | None = None
name: str
is_generic: bool = False
generic_args: list["TypeReference"] | None = None
generic_origin: str | None = None
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
[docs]
@classmethod
def from_type(cls, type_obj: type | None) -> Optional["TypeReference"]:
"""Create a TypeReference from a type object."""
if type_obj is None:
return None
# Get basic type info
ref = cls(
name=getattr(type_obj, "__name__", str(type_obj)),
module_path=getattr(type_obj, "__module__", None),
)
# Handle generic types (like List[str], Dict[str, int], etc.)
origin = get_origin(type_obj)
if origin is not None:
ref.is_generic = True
ref.generic_origin = getattr(origin, "__name__", str(origin))
# Get generic arguments
args = get_args(type_obj)
if args:
ref.generic_args = [cls.from_type(arg) for arg in args]
return ref
[docs]
def resolve(self) -> type | None:
"""Resolve the reference back to a type."""
if not self.module_path or not self.name:
return None
try:
# Import the base module
module = importlib.import_module(self.module_path)
base_type = getattr(module, self.name)
# If not generic, return the base type
if not self.is_generic:
return base_type
# If generic, reconstruct with arguments
if self.generic_origin and self.generic_args:
# Import typing module for generic types
import typing
origin = getattr(typing, self.generic_origin, None)
if origin:
# Resolve generic arguments
args = [arg.resolve() for arg in self.generic_args]
# Filter out None values
args = [arg for arg in args if arg is not None]
# Create generic type
if len(args) == 1:
return origin[args[0]]
if len(args) > 1:
return origin[tuple(args)]
# Fallback to base type
return base_type
except (ImportError, AttributeError) as e:
logger.exception(f"Error resolving type: {e}")
return None