# Source code for haive.core.common.mixins.checkpointer_mixin

"""Checkpointer mixin for stateful graphs and execution persistence.

This module provides a mixin that adds checkpointing capabilities to any class
that uses LangGraph or LangChain for stateful execution. It handles both
synchronous and asynchronous checkpointing patterns, state restoration, and
runtime configuration management.

Usage:
            from typing import Optional

            from pydantic import BaseModel, Field
            from langgraph.graph import StateGraph
            from haive.core.common.mixins import CheckpointerMixin
            from haive.core.persistence.config import CheckpointerConfig

            class MyAgent(CheckpointerMixin, BaseModel):
                # Define the required fields
                persistence: Optional[CheckpointerConfig] = Field(default=None)
                checkpoint_mode: str = Field(default="sync")

                def __init__(self, **data):
                    super().__init__(**data)
                    # Create graph
                    builder = StateGraph(...)
                    self.app = builder.compile()

                # Use run with automatic checkpointing
                def process(self, input_data, thread_id=None):
                    return self.run(input_data, thread_id=thread_id)

                # Use streaming with automatic checkpointing
                def process_stream(self, input_data, thread_id=None):
                    return self.stream(input_data, thread_id=thread_id)
"""

import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator, Generator
from typing import Any

# RunnableConfig is referenced in the method annotations below. When
# langchain_core cannot be imported, an empty placeholder class is defined
# under the same name so those annotations still resolve.
try:
    from langchain_core.runnables import RunnableConfig
except ImportError:
    # Fallback for documentation builds
    class RunnableConfig: pass
from pydantic import BaseModel, PrivateAttr

from haive.core.config.runnable import RunnableConfigManager
from haive.core.persistence.handlers import (
    prepare_merged_input,
    register_async_thread_if_needed,
    register_thread_if_needed,
)

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class CheckpointerMixin(BaseModel):
    """Mixin that provides checkpointing capabilities for stateful graph execution.

    This mixin adds methods for running stateful graph executions with
    checkpointing support, including automatic state restoration, thread
    management, and proper configuration handling for both synchronous and
    asynchronous execution patterns.

    The mixin expects the host class to provide:

    - persistence: Optional[CheckpointerConfig] - Configuration for the
      checkpointer
    - checkpoint_mode: str - Mode of checkpointing ("sync", "async", or "none")
    - runnable_config: RunnableConfig - Base configuration for runnables
    - input_schema, state_schema (optional) - Schemas for input and state
      validation
    - app or compile() method - The LangGraph compiled application

    Attributes:
        None publicly, but requires the above attributes from the host class.
    """

    model_config = {"arbitrary_types_allowed": True}

    # Private attributes for runtime checkpointer instances (not serialized)
    _sync_checkpointer: Any = PrivateAttr(default=None)
    _async_checkpointer: Any = PrivateAttr(default=None)
    _checkpointer_initialized: bool = PrivateAttr(default=False)
    _async_setup_pending: bool = PrivateAttr(default=False)

    # ------------------------------------------------------------------
    # Checkpointer lifecycle
    # ------------------------------------------------------------------

    def _ensure_checkpointer_initialized(self) -> None:
        """Initialize checkpointers if not already done.

        Creates the sync checkpointer from ``self.persistence`` unless there
        is no persistence config or ``checkpoint_mode`` is ``"none"``. When the
        mode is ``"async"``, the async checkpointer is only flagged as pending
        here; it is created by ``_ensure_async_checkpointer_initialized``.
        """
        if self._checkpointer_initialized:
            return

        persistence = getattr(self, "persistence", None)
        checkpoint_mode = getattr(self, "checkpoint_mode", "sync")

        if persistence is None or checkpoint_mode == "none":
            self._sync_checkpointer = None
            self._async_checkpointer = None
        else:
            self._sync_checkpointer = persistence.create_checkpointer()
            if checkpoint_mode == "async":
                # Creation needs an await, so it is deferred to the async path.
                self._async_setup_pending = True

        self._checkpointer_initialized = True

    async def _ensure_async_checkpointer_initialized(self) -> None:
        """Initialize the async checkpointer if setup is pending.

        A failure to create the async checkpointer is logged and leaves
        ``_async_checkpointer`` as ``None``; callers then fall back to the
        sync checkpointer. The pending flag is cleared either way, so creation
        is attempted at most once.
        """
        if not self._async_setup_pending:
            return

        persistence = getattr(self, "persistence", None)
        if persistence and hasattr(persistence, "create_async_checkpointer"):
            try:
                self._async_checkpointer = (
                    await persistence.create_async_checkpointer()
                )
                logger.debug(
                    "Initialized async checkpointer: %s",
                    type(self._async_checkpointer).__name__,
                )
            except Exception as e:  # best-effort: degrade to sync checkpointing
                logger.warning("Failed to create async checkpointer: %s", e)
                self._async_checkpointer = None

        self._async_setup_pending = False

    def get_checkpointer(self, async_mode: bool = False) -> Any:
        """Get the appropriate checkpointer.

        Args:
            async_mode: Whether to return the async checkpointer.

        Returns:
            The appropriate checkpointer instance or None if not available.
        """
        self._ensure_checkpointer_initialized()
        if async_mode:
            return self._async_checkpointer
        return self._sync_checkpointer

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def _prepare_runnable_config(
        self,
        thread_id: str | None = None,
        config: RunnableConfig | None = None,
        **kwargs,
    ) -> RunnableConfig:
        """Prepare runnable config with thread management.

        The host's ``runnable_config`` (if any), a config built from
        ``thread_id`` (if given) and the explicit ``config`` (if given) are
        merged in that order. The result always has a ``configurable`` dict
        containing a ``thread_id`` (a fresh UUID4 string when none was
        supplied) and a ``checkpoint_mode``.

        The returned config is a copy: neither ``config`` nor the host's
        ``runnable_config`` is modified.

        Args:
            thread_id: Optional thread ID for the execution.
            config: Optional base configuration to extend.
            **kwargs: Additional configuration parameters:

                - ``user_id``: forwarded to ``RunnableConfigManager.create``
                  when ``thread_id`` is given.
                - ``checkpoint_mode``: overrides the host's ``checkpoint_mode``.
                - ``configurable_<name>``: stored as ``configurable[<name>]``.
                - ``configurable`` (dict): merged into ``configurable``.
                - anything else: stored as a top-level config key.

        Returns:
            The prepared runnable configuration.
        """
        base_config = getattr(self, "runnable_config", None)

        if thread_id:
            runtime_config = RunnableConfigManager.create(
                thread_id=thread_id, user_id=kwargs.pop("user_id", None)
            )
            if base_config:
                runtime_config = RunnableConfigManager.merge(
                    base_config, runtime_config
                )
            if config:
                runtime_config = RunnableConfigManager.merge(runtime_config, config)
        elif config:
            runtime_config = (
                RunnableConfigManager.merge(base_config, config)
                if base_config
                else config
            )
        else:
            runtime_config = base_config or RunnableConfigManager.create()

        # At this point runtime_config may BE the caller's ``config`` or the
        # host's shared ``runnable_config``. Copy before writing so a generated
        # thread_id or per-call option never leaks into later calls.
        runtime_config = dict(runtime_config)
        runtime_config["configurable"] = dict(runtime_config.get("configurable") or {})

        if "thread_id" not in runtime_config["configurable"]:
            runtime_config["configurable"]["thread_id"] = str(uuid.uuid4())

        runtime_config["configurable"]["checkpoint_mode"] = kwargs.pop(
            "checkpoint_mode", getattr(self, "checkpoint_mode", "sync")
        )

        for key, value in kwargs.items():
            if key.startswith("configurable_"):
                # removeprefix strips only the leading marker; str.replace
                # would also mangle later occurrences inside the name.
                param_name = key.removeprefix("configurable_")
                runtime_config["configurable"][param_name] = value
            elif key == "configurable" and isinstance(value, dict):
                runtime_config["configurable"].update(value)
            else:
                runtime_config[key] = value

        return runtime_config

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _resolve_app(self) -> Any:
        """Return the compiled app from ``compile()`` or the ``app`` attribute.

        Raises:
            ValueError: If neither source yields a usable app.
        """
        compile_fn = getattr(self, "compile", None)
        app = compile_fn() if callable(compile_fn) else getattr(self, "app", None)
        if not app:
            raise ValueError(
                "No app available to invoke - compile() method or app attribute required"
            )
        return app

    def _merge_previous_state(
        self, input_data: Any, previous_state: Any, runtime_config: Any
    ) -> Any:
        """Merge ``input_data`` with ``previous_state`` via ``prepare_merged_input``.

        Returns ``input_data`` unchanged when there is no previous state or
        when merging fails (the failure is logged).
        """
        if not previous_state:
            return input_data
        try:
            return prepare_merged_input(
                input_data,
                previous_state,
                dict(runtime_config) if runtime_config else None,
                getattr(self, "input_schema", None),
                getattr(self, "state_schema", None),
            )
        except Exception as e:  # best-effort: run with the raw input instead
            logger.warning("Error merging with previous state: %s", e)
            return input_data

    def _restore_sync_input(
        self, app: Any, input_data: Any, runtime_config: Any
    ) -> Any:
        """Register the thread, load prior state and merge it (sync checkpointer).

        Returns the input to execute with: merged when a previous state could
        be loaded, otherwise ``input_data`` as given.
        """
        thread_id = runtime_config.get("configurable", {}).get("thread_id")
        checkpointer = self.get_checkpointer(async_mode=False)

        previous_state = None
        if checkpointer and thread_id:
            register_thread_if_needed(checkpointer, thread_id)
            try:
                previous_state = app.get_state(runtime_config)
            except Exception as e:  # best-effort: continue without history
                logger.warning("Error retrieving previous state: %s", e)

        return self._merge_previous_state(input_data, previous_state, runtime_config)

    async def _build_async_app(self, app: Any, thread_id: str | None) -> Any:
        """Register the thread and recompile ``app`` with the async checkpointer.

        The store is taken from ``app.store`` when the app has that attribute,
        otherwise from ``self.store`` (``None`` if absent).
        """
        if thread_id:
            await register_async_thread_if_needed(self._async_checkpointer, thread_id)

        store = (
            getattr(app, "store", None)
            if hasattr(app, "store")
            else getattr(self, "store", None)
        )
        # NOTE(review): relies on the compiled app exposing its uncompiled
        # graph as ``.graph`` - confirm against the LangGraph version in use.
        return app.graph.compile(checkpointer=self._async_checkpointer, store=store)

    async def _restore_async_input(
        self, async_app: Any, input_data: Any, runtime_config: Any
    ) -> Any:
        """Load prior state with ``aget_state`` and merge it into ``input_data``."""
        previous_state = None
        try:
            previous_state = await async_app.aget_state(runtime_config)
        except Exception as e:  # best-effort: continue without history
            logger.warning("Error retrieving previous state: %s", e)
        return self._merge_previous_state(input_data, previous_state, runtime_config)

    async def _iterate_stream(
        self, app: Any, input_data: Any, runtime_config: Any, stream_mode: str
    ) -> AsyncGenerator[dict[str, Any], None]:
        """Yield chunks from ``app.astream``, or from ``app.stream`` if absent."""
        if hasattr(app, "astream"):
            async for chunk in app.astream(
                input_data, runtime_config, stream_mode=stream_mode
            ):
                yield chunk
        else:
            # Sync fallback: iterated inline, so each step runs on the
            # calling coroutine.
            for chunk in app.stream(
                input_data, runtime_config, stream_mode=stream_mode
            ):
                yield chunk

    # ------------------------------------------------------------------
    # Public execution API
    # ------------------------------------------------------------------

    def run(
        self,
        input_data: Any,
        thread_id: str | None = None,
        config: RunnableConfig | None = None,
        **kwargs,
    ) -> Any:
        """Run with checkpointer support.

        This method runs a graph execution with checkpointing support,
        automatically handling state restoration and persistence.

        Args:
            input_data: The input data for the execution.
            thread_id: Optional thread ID for state tracking.
            config: Optional configuration override.
            **kwargs: Additional configuration parameters.

        Returns:
            The result of the graph execution.

        Raises:
            ValueError: If no compiled app is available.
        """
        self._ensure_checkpointer_initialized()
        app = self._resolve_app()

        runtime_config = self._prepare_runnable_config(
            thread_id=thread_id, config=config, **kwargs
        )
        input_data = self._restore_sync_input(app, input_data, runtime_config)
        return app.invoke(input_data, runtime_config)

    async def arun(
        self,
        input_data: Any,
        thread_id: str | None = None,
        config: RunnableConfig | None = None,
        **kwargs,
    ) -> Any:
        """Async run with checkpointer support.

        This method runs a graph execution asynchronously with checkpointing
        support, automatically handling state restoration and persistence.
        The async checkpointer is used when the effective ``checkpoint_mode``
        is ``"async"`` and an async checkpointer was created; otherwise the
        sync checkpointer is used.

        Args:
            input_data: The input data for the execution.
            thread_id: Optional thread ID for state tracking.
            config: Optional configuration override.
            **kwargs: Additional configuration parameters. ``checkpoint_mode``
                overrides the host's mode for this call.

        Returns:
            The result of the graph execution.

        Raises:
            ValueError: If no compiled app is available.
        """
        self._ensure_checkpointer_initialized()
        await self._ensure_async_checkpointer_initialized()
        app = self._resolve_app()

        # pop (not get): the value is passed on explicitly below, and leaving
        # it in kwargs as well would raise "multiple values for keyword".
        checkpoint_mode = kwargs.pop(
            "checkpoint_mode", getattr(self, "checkpoint_mode", "sync")
        )
        use_async = checkpoint_mode == "async" and self._async_checkpointer

        runtime_config = self._prepare_runnable_config(
            thread_id=thread_id,
            config=config,
            checkpoint_mode=checkpoint_mode,
            **kwargs,
        )
        thread_id = runtime_config.get("configurable", {}).get("thread_id")

        if use_async:
            async_app = await self._build_async_app(app, thread_id)
            input_data = await self._restore_async_input(
                async_app, input_data, runtime_config
            )
            return await async_app.ainvoke(input_data, runtime_config)

        input_data = self._restore_sync_input(app, input_data, runtime_config)

        # Use ainvoke if available, otherwise run invoke in the default executor
        if hasattr(app, "ainvoke"):
            return await app.ainvoke(input_data, runtime_config)

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: app.invoke(input_data, runtime_config)
        )

    def stream(
        self,
        input_data: Any,
        thread_id: str | None = None,
        stream_mode: str = "values",
        config: RunnableConfig | None = None,
        **kwargs,
    ) -> Generator[dict[str, Any], None, None]:
        """Stream with checkpointer support.

        This method streams graph execution results with checkpointing
        support, automatically handling state restoration and persistence.

        Args:
            input_data: The input data for the execution.
            thread_id: Optional thread ID for state tracking.
            stream_mode: The streaming mode, forwarded to the app's
                ``stream`` call (e.g. "values", "updates").
            config: Optional configuration override.
            **kwargs: Additional configuration parameters.

        Yields:
            Execution chunks as they become available.

        Raises:
            ValueError: If no compiled app is available.
        """
        self._ensure_checkpointer_initialized()
        app = self._resolve_app()

        # stream_mode is still recorded in the config, as before, and is now
        # also forwarded to the stream call so it actually takes effect.
        runtime_config = self._prepare_runnable_config(
            thread_id=thread_id, config=config, stream_mode=stream_mode, **kwargs
        )
        input_data = self._restore_sync_input(app, input_data, runtime_config)

        yield from app.stream(input_data, runtime_config, stream_mode=stream_mode)

    async def astream(
        self,
        input_data: Any,
        thread_id: str | None = None,
        stream_mode: str = "values",
        config: RunnableConfig | None = None,
        **kwargs,
    ) -> AsyncGenerator[dict[str, Any], None]:
        """Async stream with checkpointer support.

        This method streams graph execution results asynchronously with
        checkpointing support, automatically handling state restoration and
        persistence. The async checkpointer is used when the effective
        ``checkpoint_mode`` is ``"async"`` and an async checkpointer was
        created; otherwise the sync checkpointer is used.

        Args:
            input_data: The input data for the execution.
            thread_id: Optional thread ID for state tracking.
            stream_mode: The streaming mode, forwarded to the app's
                ``astream``/``stream`` call (e.g. "values", "updates").
            config: Optional configuration override.
            **kwargs: Additional configuration parameters. ``checkpoint_mode``
                overrides the host's mode for this call.

        Yields:
            Execution chunks as they become available.

        Raises:
            ValueError: If no compiled app is available.
        """
        self._ensure_checkpointer_initialized()
        await self._ensure_async_checkpointer_initialized()
        app = self._resolve_app()

        # pop (not get): see arun - avoids a duplicate keyword argument.
        checkpoint_mode = kwargs.pop(
            "checkpoint_mode", getattr(self, "checkpoint_mode", "sync")
        )
        use_async = checkpoint_mode == "async" and self._async_checkpointer

        runtime_config = self._prepare_runnable_config(
            thread_id=thread_id,
            config=config,
            stream_mode=stream_mode,
            checkpoint_mode=checkpoint_mode,
            **kwargs,
        )
        thread_id = runtime_config.get("configurable", {}).get("thread_id")

        if use_async:
            target_app = await self._build_async_app(app, thread_id)
            input_data = await self._restore_async_input(
                target_app, input_data, runtime_config
            )
        else:
            target_app = app
            input_data = self._restore_sync_input(app, input_data, runtime_config)

        async for chunk in self._iterate_stream(
            target_app, input_data, runtime_config, stream_mode
        ):
            yield chunk