Source code for haive.core.persistence.postgres_config

"""PostgreSQL-based persistence implementation for the Haive framework.

This module provides a PostgreSQL-backed checkpoint persistence implementation that
stores state data in a PostgreSQL database. This allows for durable, reliable state
persistence across application restarts and deployments.

The PostgreSQL implementation offers advanced features including connection pooling,
automatic retry with exponential backoff, comprehensive error handling, and support
for both synchronous and asynchronous operation modes. It integrates with LangGraph's
checkpoint system while adding enhanced robustness and configurability.

For production deployments, the PostgreSQL implementation is generally recommended
over in-memory or SQLite options due to its scalability, reliability, and
concurrent access capabilities.
"""

import logging
import urllib.parse
from contextlib import asynccontextmanager
from typing import Any

try:
    from psycopg_pool import AsyncConnectionPool, ConnectionPool
except ImportError:
    # Fallback classes for documentation builds
    class AsyncConnectionPool: pass
    class ConnectionPool: pass
from pydantic import Field, SecretStr

from haive.core.persistence.base import CheckpointerConfig
from haive.core.persistence.postgres_saver_override import configure_postgres_json
from haive.core.persistence.postgres_saver_with_thread_creation import (
    create_async_postgres_saver_with_thread_creation,
    create_postgres_saver_with_thread_creation,
)
from haive.core.persistence.serializers import create_encrypted_serializer_for_postgres
from haive.core.persistence.types import CheckpointerMode, CheckpointerType

logger = logging.getLogger(__name__)


[docs] class PostgresCheckpointerConfig(CheckpointerConfig[dict[str, Any]]): """Configuration for PostgreSQL-based checkpoint persistence. This implementation provides a robust, production-ready persistence solution using PostgreSQL as the storage backend. It offers comprehensive configuration options for database connections, connection pooling, security, and performance tuning. PostgreSQL persistence is recommended for production deployments where durability, reliability, and concurrent access are important. It supports both full history tracking and space-efficient shallow mode that only retains the most recent state. Key features include: - Connection pooling for optimal performance under load - Automatic retry with exponential backoff for resilience - Comprehensive security options including SSL/TLS support - Support for both synchronous and asynchronous operation - Transaction management and prepared statement optimization - Thread registration for tracking active sessions - Support for both full history and shallow (latest-only) storage modes The implementation maintains connection pools separately for synchronous and asynchronous usage, ensuring optimal performance in both contexts. It also includes table setup and validation to ensure the database schema is properly configured. Examples: from haive.core.persistence import PostgresCheckpointerConfig from haive.core.persistence.types import CheckpointerMode, CheckpointStorageMode # Create a basic PostgreSQL checkpointer config = PostgresCheckpointerConfig( db_host="localhost", db_port=5432, db_name="haive", db_user="postgres", db_pass="secure_password", ssl_mode="require", mode=CheckpointerMode.ASYNC, storage_mode=CheckpointStorageMode.SHALLOW ) # For async usage async def setup(): async_checkpointer = await config.create_async_checkpointer() # Use the checkpointer... Notes: - Requires the psycopg and psycopg_pool packages to be installed - For best performance, use connection pooling with appropriate sizing - Consider shallow mode for applications that don't need full history """ type: CheckpointerType = CheckpointerType.POSTGRES # Database connection parameters db_host: str = Field(default="localhost", description="PostgreSQL server hostname") db_port: int = Field(default=5432, description="PostgreSQL server port") db_name: str = Field(default="postgres", description="Database name") db_user: str = Field(default="postgres", description="Database username") db_pass: SecretStr = Field( default_factory=lambda: SecretStr("postgres"), description="Database password" ) ssl_mode: str = Field( default="disable", description="SSL mode for database connection" ) # Connection pool configuration min_pool_size: int = Field( default=1, description="Minimum number of connections in the pool" ) max_pool_size: int = Field( default=5, description="Maximum number of connections in the pool" ) # Additional connection options auto_commit: bool = Field(default=True, description="Auto-commit transactions") prepare_threshold: int | None = Field( default=None, description="Prepared statement threshold (None to disable)" ) connection_kwargs: dict[str, Any] = Field( default_factory=lambda: { "keepalives": 1, "keepalives_idle": 30, "keepalives_interval": 10, "keepalives_count": 5, "connect_timeout": 30, }, description="Additional connection keyword arguments", ) # Optional direct connection string connection_string: str | None = Field( default=None, description="Direct connection string (overrides individual parameters)", ) # Pipeline mode use_pipeline: bool = Field( default=False, description="Whether to use pipeline mode for better performance" ) class Config: arbitrary_types_allowed = True
[docs] def is_async_mode(self) -> bool: """Check if this configuration is set to operate in asynchronous mode. This method determines whether the PostgreSQL checkpointer should use asynchronous operations based on the configured mode. It affects which connection pools and checkpointer implementations are used. For PostgreSQL, this is an important distinction as it determines whether synchronous or asynchronous database drivers are used, which have different connection management patterns and performance characteristics. Returns: bool: True if configured for async operations, False for synchronous """ return self.mode == CheckpointerMode.ASYNC
[docs] def get_connection_uri(self) -> str: """Generate a formatted connection URI for PostgreSQL. This method constructs a properly formatted PostgreSQL connection string based on the configured connection parameters. It handles proper escaping of special characters in passwords and formatting according to PostgreSQL standards. The method prioritizes using a direct connection string if one is provided, otherwise it builds the string from individual connection parameters. Returns: str: Formatted PostgreSQL connection string ready for use Examples: config = PostgresCheckpointerConfig( db_host="db.example.com", db_port=5432, db_name="haive", db_user="app_user", db_pass="secret_password", ssl_mode="require" uri = config.get_connection_uri() # uri = "postgresql://app_user:secret_password@db.example.com:5432/haive?sslmode=require" """ # Use direct connection string if provided if self.connection_string: return self.connection_string # Generate from individual parameters encoded_pass = urllib.parse.quote_plus(self.db_pass.get_secret_value()) db_uri = f"postgresql://{self.db_user}:{encoded_pass}@{self.db_host}:{self.db_port}/{self.db_name}" if self.ssl_mode: db_uri += f"?sslmode={self.ssl_mode}" return db_uri
[docs] def get_connection_kwargs(self) -> dict[str, Any]: """Get connection keyword arguments for PostgreSQL connections. This method constructs a dictionary of connection options to be passed to the PostgreSQL connection pool and individual connections. It combines the standard configuration parameters with any additional custom parameters specified in connection_kwargs. The options include settings for transaction management, prepared statement handling, and timeout configuration, which can significantly impact performance and reliability. Returns: Dict[str, Any]: Dictionary of connection options ready to use with PostgreSQL connections or connection pools Examples: config = PostgresCheckpointerConfig( auto_commit=True, prepare_threshold=5, connection_kwargs={"application_name": "haive_app"} kwargs = config.get_connection_kwargs() # kwargs = { # "autocommit": True, # "prepare_threshold": 5, # "application_name": "haive_app" # } """ kwargs = { "autocommit": self.auto_commit, "prepare_threshold": self.prepare_threshold, } # Add any additional kwargs kwargs.update(self.connection_kwargs) return kwargs
[docs] def create_checkpointer(self) -> Any: """Create a synchronous PostgreSQL checkpointer. This method creates and configures a synchronous PostgreSQL checkpointer that matches the settings in this configuration. It handles connection pool creation, checkpointer initialization, and database table setup. The method automatically selects the appropriate implementation based on the storage_mode setting (full or shallow), and performs error checking to ensure the requested configuration is valid. Returns: Any: A configured PostgresSaver or ShallowPostgresSaver instance Raises: RuntimeError: If async mode is requested (use create_async_checkpointer instead) RuntimeError: If the PostgreSQL dependencies are missing or connection fails Examples: config = PostgresCheckpointerConfig( db_host="localhost", db_port=5432, storage_mode=CheckpointStorageMode.SHALLOW try: # Creates a ShallowPostgresSaver instance checkpointer = config.create_checkpointer() # Use with a graph graph = Graph(checkpointer=checkpointer) except RuntimeError as e: print(f"Failed to create PostgreSQL checkpointer: {e}") # Handle error - perhaps fall back to memory checkpointer """ try: # Handle async mode request if self.is_async_mode(): raise RuntimeError( "Cannot use create_checkpointer for async mode, use create_async_checkpointer instead" ) # Import our enhanced PostgresSaver with thread creation # Create connection pool with forced parameters to avoid SSL issues connection_kwargs = self.get_connection_kwargs() # Force disable prepared statements and ensure proper SSL handling connection_kwargs.update( { "prepare_threshold": None, # Force disable prepared statements "autocommit": True, # Ensure autocommit } ) # Configure function for JSON handling pool = ConnectionPool( conninfo=self.get_connection_uri(), min_size=self.min_pool_size, max_size=self.max_pool_size, kwargs=connection_kwargs, configure=configure_postgres_json, # Configure each connection check=ConnectionPool.check_connection, # Add connection health checking max_lifetime=1800, # 30 minutes max connection lifetime open=False, # Don't open in constructor to avoid early failures ) # Explicitly open the pool with error handling try: pool.open() logger.info("PostgreSQL connection pool opened successfully") except Exception as e: logger.exception(f"Failed to open PostgreSQL connection pool: {e}") raise # Import our production serializer factory # Create production-grade encrypted serializer for PostgreSQL production_serializer = create_encrypted_serializer_for_postgres( connection_string=self.get_connection_uri() ) # Create PostgresSaver with automatic thread creation checkpointer = create_postgres_saver_with_thread_creation( pool, serde=production_serializer ) # Setup tables if needed if self.setup_needed: try: checkpointer.setup() logger.info("PostgreSQL tables set up successfully") except Exception as e: logger.warning(f"Error during PostgreSQL setup: {e}") return checkpointer except Exception as e: logger.exception(f"Failed to create PostgreSQL checkpointer: {e}") raise RuntimeError(f"Failed to create PostgreSQL checkpointer: {e}")
[docs] async def create_async_checkpointer(self) -> Any: """Create an asynchronous PostgreSQL checkpointer. This method creates and configures an asynchronous PostgreSQL checkpointer that matches the settings in this configuration. It handles async connection pool creation, checkpointer initialization, and database table setup. The method automatically selects the appropriate implementation based on the storage_mode setting (full or shallow). It uses the asynchronous PostgreSQL driver and connection pool for non-blocking database operations. Returns: Any: A configured AsyncPostgresSaver or AsyncShallowPostgresSaver instance Raises: RuntimeError: If the asynchronous PostgreSQL dependencies are missing or connection fails Examples: config = PostgresCheckpointerConfig( db_host="localhost", db_port=5432, mode=CheckpointerMode.ASYNC, storage_mode=CheckpointStorageMode.FULL ) async def setup_graph(): try: # Creates an AsyncPostgresSaver instance async_checkpointer = await config.create_async_checkpointer() # Use with an async graph graph = AsyncGraph(checkpointer=async_checkpointer) return graph except RuntimeError as e: print(f"Failed to create async PostgreSQL checkpointer: {e}") # Handle error Note: This method automatically forces the mode to ASYNC for consistency, ensuring that the configuration accurately reflects the type of checkpointer being created. """ try: # Force async mode self.mode = CheckpointerMode.ASYNC # Import our enhanced AsyncPostgresSaver with thread creation # Create connection pool with forced parameters to avoid SSL issues connection_kwargs = self.get_connection_kwargs() # Force disable prepared statements and ensure proper SSL handling connection_kwargs.update( { "prepare_threshold": None, # Force disable prepared statements "autocommit": True, # Ensure autocommit } ) # Configure function for JSON handling pool = AsyncConnectionPool( conninfo=self.get_connection_uri(), min_size=self.min_pool_size, max_size=self.max_pool_size, kwargs=connection_kwargs, configure=configure_postgres_json, # Configure each connection check=AsyncConnectionPool.check_connection, # Add connection health checking max_lifetime=1800, # 30 minutes max connection lifetime open=False, # Don't open in constructor to avoid deprecation warning ) # Explicitly open the pool with error handling try: await pool.open() logger.info("Async PostgreSQL connection pool opened successfully") except Exception as e: logger.exception( f"Failed to open async PostgreSQL connection pool: {e}" ) raise # Import our production serializer factory # Create production-grade encrypted serializer for PostgreSQL production_serializer = create_encrypted_serializer_for_postgres( connection_string=self.get_connection_uri() ) # Create AsyncPostgresSaver with automatic thread creation checkpointer = await create_async_postgres_saver_with_thread_creation( pool, serde=production_serializer ) # Setup tables if needed if self.setup_needed: try: await checkpointer.setup() logger.info("PostgreSQL tables set up successfully (async)") except Exception as e: logger.warning(f"Error during PostgreSQL async setup: {e}") return checkpointer except Exception as e: logger.exception(f"Failed to create async PostgreSQL checkpointer: {e}") logger.exception(f"Connection URI: {self.get_connection_uri()}") logger.exception(f"Storage mode: {self.storage_mode}") logger.exception(f"Mode: {self.mode}") # Try to provide more specific error information if "AsyncPostgresSaver" in str(e): logger.exception("AsyncPostgresSaver import or creation failed") if "pool" in str(e).lower(): logger.exception("Connection pool creation failed") if "connection" in str(e).lower(): logger.exception("Database connection failed") raise RuntimeError(f"Failed to create async PostgreSQL checkpointer: {e}")
[docs] async def initialize_async_checkpointer(self) -> Any: """Initialize an async checkpointer with proper resource management. This method creates and initializes an asynchronous PostgreSQL checkpointer with proper resource lifecycle management using an async context manager. This ensures that database connections are properly closed when they're no longer needed, preventing connection leaks and other resource issues. Unlike create_async_checkpointer, which returns a raw checkpointer instance, this method returns an async context manager that automatically handles resource cleanup when the context is exited, making it ideal for use in production environments. Returns: Any: An async context manager that yields a configured checkpointer and automatically cleans up resources on exit Raises: RuntimeError: If the asynchronous PostgreSQL dependencies are missing or connection fails Examples: config = PostgresCheckpointerConfig( db_host="localhost", db_port=5432, mode=CheckpointerMode.ASYNC ) async def run_with_managed_resources(): # Resources will be properly initialized and cleaned up async with await config.initialize_async_checkpointer() as checkpointer: # Use checkpointer with async code graph = AsyncGraph(checkpointer=checkpointer) # Run operations with graph... # Connection pool is automatically closed here Note: This is the recommended method for asynchronous usage in production environments, as it ensures proper resource cleanup even if errors occur. """ @asynccontextmanager async def async_checkpointer_context(): """Context manager for async checkpointer with proper resource management.""" checkpointer = None try: # Create the checkpointer checkpointer = await self.create_async_checkpointer() # Yield it for use yield checkpointer finally: # Clean up resources if checkpointer and hasattr(checkpointer, "conn") and checkpointer.conn: # Close pool if available try: if hasattr(checkpointer.conn, "close"): await checkpointer.conn.close() logger.debug("Async PostgreSQL pool closed successfully") except Exception as e: logger.warning(f"Error closing async PostgreSQL pool: {e}") # Return the context manager return async_checkpointer_context()