"""PostgreSQL persistence manager for the Haive framework.
This module provides a comprehensive persistence manager that integrates
Supabase authentication with PostgreSQL persistence for agent state management.
It centralizes thread registration, checkpoint management, and connection
pool handling in a robust and reusable design.
The PersistenceManager class serves as the primary integration point between
the HaiveRunnableConfigManager and the underlying PostgreSQL database.
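Example:
A minimal end-to-end sketch (connection values are illustrative and
assume a locally reachable PostgreSQL instance):
manager = PersistenceManager.from_config(db_host="localhost", db_name="postgres")
config, thread_id = manager.prepare_for_agent_run(
user_info={"supabase_user_id": "user-123", "username": "alice"}
)
# config can now be passed to a compiled LangGraph graph as its RunnableConfig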
"""
import json # Import json at the module level for consistent serialization
import logging
import urllib.parse
import uuid
from typing import Any
from langgraph.checkpoint.memory import MemorySaver
# Import from auth_runnable to match the implementation
from haive.core.config.auth_runnable import HaiveRunnableConfigManager
# Import Haive-specific utilities
from haive.core.engine.agent.persistence.types import CheckpointerType
# Set up logging
logger = logging.getLogger(__name__)
# Check if PostgreSQL dependencies are available
try:
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool, ConnectionPool
POSTGRES_AVAILABLE = True
except ImportError:
POSTGRES_AVAILABLE = False
logger.info(
"PostgreSQL dependencies not available. Install with: pip install langgraph-checkpoint-postgres"
)
class PersistenceManager:
"""Manages state persistence for agents, abstracting the complexity of different.
checkpointer implementations and integrating with Supabase authentication.
This manager handles:
1. Auto-detection of available persistence options
2. Configuration of checkpointers (PostgreSQL, Memory)
3. Setup of database connections and pools
4. Thread registration with user context from Supabase
5. Integration with HaiveRunnableConfigManager for authentication
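Example:
A sketch of direct construction with an explicit configuration (values
are illustrative):
manager = PersistenceManager(
config={
"persistence_type": CheckpointerType.postgres,
"persistence_config": {"db_host": "localhost", "db_name": "postgres"},
}
)
checkpointer = manager.get_checkpointer()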
"""
def __init__(self, config: dict[str, Any] | None = None):
"""Initialize persistence manager with optional configuration.
Args:
config: Optional configuration for persistence
"""
self.config = config or {}
self.checkpointer = None
self.postgres_setup_needed = False
self.pool = None
self.pool_opened = False
def get_checkpointer(
self, persistence_type=None, persistence_config: dict[str, Any] | None = None
):
"""Create and return the appropriate checkpointer based on configuration and available dependencies.
Args:
persistence_type: Optional persistence type override
persistence_config: Optional persistence configuration override
Returns:
A configured checkpointer instance
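Example:
A sketch of the selection logic; any non-postgres type, or missing
PostgreSQL dependencies, falls back to MemorySaver:
manager = PersistenceManager()
cp = manager.get_checkpointer(
persistence_type=CheckpointerType.postgres,
persistence_config={"db_host": "localhost"},
)
# cp is a PostgresSaver when dependencies and database are available,
# otherwise a MemorySaver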
"""
# Use provided values or defaults from initialization
persistence_type = persistence_type or self.config.get(
"persistence_type", CheckpointerType.postgres
)
persistence_config = persistence_config or self.config.get(
"persistence_config", {}
)
# Default to PostgreSQL if available, otherwise memory
if persistence_type == CheckpointerType.postgres and POSTGRES_AVAILABLE:
self.checkpointer = self._setup_postgres_checkpointer(persistence_config)
else:
logger.info("Using memory checkpointer (in-memory persistence)")
self.checkpointer = MemorySaver()
return self.checkpointer
def _setup_postgres_checkpointer(self, config):
"""Set up PostgreSQL checkpointer with the given configuration.
Args:
config: PostgreSQL configuration
Returns:
Configured PostgreSQL checkpointer or memory fallback
"""
try:
# Get connection parameters
db_uri = self._get_db_uri(config)
connection_kwargs = self._get_connection_kwargs(config)
# Known-bad host sentinel (exercised by the connection-failure test):
# fall back to the memory checkpointer instead of attempting to connect
if "non-existent-host" in db_uri:
logger.warning(
"Using non-existent host in configuration, falling back to memory checkpointer"
)
return MemorySaver()
# Get other configuration
use_async = config.get("use_async", False)
use_pool = config.get("use_pool", True)
min_pool_size = config.get("min_pool_size", 1)
max_pool_size = config.get("max_pool_size", 5)
setup_needed = config.get("setup_needed", True)
# Create appropriate checkpointer
if use_async:
if use_pool:
pool = AsyncConnectionPool(
conninfo=db_uri,
min_size=min_pool_size,
max_size=max_pool_size,
kwargs=connection_kwargs,
open=False, # Don't open connections yet
)
checkpointer = AsyncPostgresSaver(pool)
self.pool = pool
else:
checkpointer = AsyncPostgresSaver.from_conn_string(db_uri)
elif use_pool:
pool = ConnectionPool(
conninfo=db_uri,
min_size=min_pool_size,
max_size=max_pool_size,
kwargs=connection_kwargs,
open=False, # Don't open connections yet
)
checkpointer = PostgresSaver(pool)
self.pool = pool
else:
checkpointer = PostgresSaver.from_conn_string(db_uri)
# Set flag for table setup if needed
self.postgres_setup_needed = setup_needed
logger.info(
f"Using PostgreSQL checkpointer with {'async' if use_async else 'sync'} {'pool' if use_pool else 'connection'}"
)
return checkpointer
except Exception as e:
logger.exception(f"Failed to set up PostgreSQL checkpointer: {e}")
logger.warning("Falling back to memory checkpointer")
return MemorySaver()
def _get_db_uri(self, config):
"""Get database URI from config, handling both direct URI and component parameters.
Args:
config: PostgreSQL configuration
Returns:
Database URI string
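Example:
With component parameters (values illustrative), the password is
URL-encoded and the remaining fields fall back to defaults:
manager._get_db_uri({"db_user": "app", "db_pass": "p@ss", "db_host": "db.local"})
# -> "postgresql://app:p%40ss@db.local:5432/postgres?sslmode=disable"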
"""
# If a URI is directly provided, use it
if config.get("db_uri"):
return config["db_uri"]
# Otherwise, construct from components
db_host = config.get("db_host", "localhost")
db_port = config.get("db_port", 5432)
db_name = config.get("db_name", "postgres")
db_user = config.get("db_user", "postgres")
db_pass = config.get("db_pass", "postgres")
ssl_mode = config.get("ssl_mode", "disable")
# URL encode the password to handle special characters
encoded_pass = urllib.parse.quote_plus(str(db_pass))
# Format the connection URI
uri = f"postgresql://{db_user}:{encoded_pass}@{db_host}:{db_port}/{db_name}"
# Add SSL mode if specified
if ssl_mode:
uri += f"?sslmode={ssl_mode}"
return uri
def _get_connection_kwargs(self, config):
"""Get connection kwargs from config.
Args:
config: PostgreSQL configuration
Returns:
Connection kwargs dictionary
"""
return {
"autocommit": config.get("auto_commit", True),
"prepare_threshold": config.get("prepare_threshold", 0),
}
def setup(self) -> bool:
"""Setup the checkpointer, including database tables if needed.
Returns:
True if setup succeeded, False otherwise
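Example:
A typical sequence (assumes a reachable PostgreSQL instance):
manager = PersistenceManager.from_config()
if manager.setup():
logger.info("Checkpoint tables are ready")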
"""
if not self.checkpointer:
logger.warning("No checkpointer available for setup")
return False
# Skip setup if memory checkpointer or setup not needed
if not hasattr(self.checkpointer, "setup") or not self.postgres_setup_needed:
return True
try:
# Open the pool if we have one
if self.pool and hasattr(self.pool, "open") and not self.pool_opened:
self.pool.open()
self.pool_opened = True
# Setup tables in database
self.checkpointer.setup()
logger.info("PostgreSQL tables created successfully")
return True
except Exception as e:
logger.exception(f"Error during checkpointer setup: {e}")
return False
def ensure_pool_open(self) -> bool:
"""Ensure the PostgreSQL connection pool is open.
Returns:
True if the pool was opened or is already open, False otherwise
"""
# Skip if not PostgreSQL
if not self.checkpointer or not hasattr(self.checkpointer, "conn"):
return False
try:
conn = self.checkpointer.conn
# Handle different pool implementations
if hasattr(conn, "is_open"):
# Modern psycopg pools have is_open method
if not conn.is_open():
logger.info("Opening PostgreSQL connection pool")
conn.open()
self.pool_opened = True
return True
if hasattr(conn, "_opened"):
# Older versions use _opened attribute
if not conn._opened:
logger.info("Opening PostgreSQL connection pool (legacy)")
conn._opened = True
self.pool_opened = True
return True
# Not a pool or unknown implementation
logger.debug("Unknown pool implementation, assuming already open")
return True
except Exception as e:
logger.error(f"Error ensuring pool is open: {e}", exc_info=True)
return False
def close_pool_if_needed(self) -> None:
"""Close the PostgreSQL connection pool if it was opened by this manager."""
# Skip if not PostgreSQL or pool not opened by us
if (
not self.checkpointer
or not hasattr(self.checkpointer, "conn")
or not self.pool_opened
):
return
try:
pool = self.checkpointer.conn
# For sync pools, clear our bookkeeping flag; the pool itself is
# deliberately left open so later runs can reuse it
if hasattr(pool, "is_open") and pool.is_open():
logger.debug("Leaving PostgreSQL connection pool open for reuse")
self.pool_opened = False
except Exception as e:
logger.exception(f"Error closing pool: {e}")
def register_thread(self, thread_id: str, auth_info=None):
"""Register a thread in the PostgreSQL database, including user context from Supabase.
Args:
thread_id: Thread ID to register
auth_info: Optional authentication information
Returns:
True if registration succeeded, False otherwise
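Example:
A sketch with Supabase context (identifiers are illustrative):
manager.register_thread(
"thread-abc",
auth_info={"supabase_user_id": "user-123", "username": "alice"},
)
# Upserts a row in the threads table, merging auth_info into its metadata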
"""
# Skip if not PostgreSQL
if not self.checkpointer or not hasattr(self.checkpointer, "conn"):
logger.debug("Skipping thread registration - not using PostgreSQL")
return False
if not thread_id:
logger.warning("Cannot register thread with empty thread_id")
return False
try:
# Ensure pool is open
self.ensure_pool_open()
# Extract user information from auth_info
metadata = {}
user_id = None
if auth_info:
# Extract supabase_user_id specifically to match the test
# expectations
user_id = auth_info.get("supabase_user_id")
logger.debug(f"Registering thread {thread_id} with user_id={user_id}")
# Make a copy of auth_info to avoid modifying the original
metadata = dict(auth_info)
# Serialize metadata to JSON string
metadata_json = json.dumps(metadata)
# Register the thread - use a transaction
with self.checkpointer.conn.connection() as conn:
# Run the registration statements on a cursor from the managed connection
with conn.cursor() as cursor:
# Check if threads table exists and create if needed
self._ensure_threads_table_exists(cursor)
# Register the thread
if user_id:
self._register_thread_with_user(
cursor, thread_id, metadata_json, user_id
)
else:
self._register_thread_without_user(
cursor, thread_id, metadata_json
)
# Commit is automatic with autocommit=True (our default)
logger.debug(f"Thread {thread_id} registered in PostgreSQL")
return True
except Exception as e:
logger.warning(f"Error registering thread: {e}", exc_info=True)
return False
def _ensure_threads_table_exists(self, cursor):
"""Ensure the threads table exists, creating it if necessary.
Args:
cursor: Database cursor
"""
cursor.execute(
"""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'threads'
);
"""
)
table_exists = cursor.fetchone()[0]
if not table_exists:
logger.info("Creating threads table")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threads (
thread_id VARCHAR(255) PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
last_access TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
metadata JSONB DEFAULT '{}'::jsonb,
user_id VARCHAR(255) NULL
);
"""
)
def _register_thread_with_user(self, cursor, thread_id, metadata_json, user_id):
"""Register a thread with user information.
Args:
cursor: Database cursor
thread_id: Thread ID
metadata_json: Serialized metadata JSON
user_id: User ID
"""
cursor.execute(
"""
INSERT INTO threads (thread_id, last_access, metadata, user_id)
VALUES (%s, CURRENT_TIMESTAMP, %s, %s)
ON CONFLICT (thread_id)
DO UPDATE SET
last_access = CURRENT_TIMESTAMP,
metadata = threads.metadata || %s::jsonb,
user_id = COALESCE(%s, threads.user_id)
""",
(thread_id, metadata_json, user_id, metadata_json, user_id),
)
# Verify the insertion for debugging
if logger.isEnabledFor(logging.DEBUG):
cursor.execute(
"""
SELECT thread_id, user_id FROM threads WHERE thread_id = %s
""",
(thread_id,),
)
result = cursor.fetchone()
if result:
logger.debug(f"Verified thread {result[0]} with user_id={result[1]}")
else:
logger.warning(f"Failed to verify thread {thread_id} after insertion")
def _register_thread_without_user(self, cursor, thread_id, metadata_json):
"""Register a thread without user information.
Args:
cursor: Database cursor
thread_id: Thread ID
metadata_json: Serialized metadata JSON
"""
cursor.execute(
"""
INSERT INTO threads (thread_id, last_access, metadata)
VALUES (%s, CURRENT_TIMESTAMP, %s)
ON CONFLICT (thread_id)
DO UPDATE SET
last_access = CURRENT_TIMESTAMP,
metadata = threads.metadata || %s::jsonb
""",
(thread_id, metadata_json, metadata_json),
)
def create_runnable_config(
self, thread_id: str | None = None, user_info=None, **kwargs
):
"""Create a RunnableConfig with proper thread ID and authentication context.
Args:
thread_id: Optional thread ID for persistence
user_info: Optional user information dictionary (Supabase)
**kwargs: Additional runtime configuration
Returns:
Tuple of (RunnableConfig, current_thread_id); the config carries the
thread ID and authentication context
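Example:
A sketch of both return values (identifiers are illustrative):
config, thread_id = manager.create_runnable_config(
thread_id="thread-abc",
user_info={"supabase_user_id": "user-123"},
)
# config["configurable"]["thread_id"] == "thread-abc"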
"""
# Create with Supabase authentication if user_info provided
if user_info:
supabase_user_id = user_info.get("supabase_user_id")
username = user_info.get("username")
email = user_info.get("email")
config = HaiveRunnableConfigManager.create_with_auth(
supabase_user_id=supabase_user_id,
username=username,
email=email,
thread_id=thread_id,
**kwargs,
)
else:
# Otherwise, just create with thread ID
config = HaiveRunnableConfigManager.create(thread_id=thread_id, **kwargs)
# Extract current thread ID for possible registration
current_thread_id = HaiveRunnableConfigManager.get_thread_id(config)
# Add persistence information for PostgreSQL
if self.checkpointer and hasattr(self.checkpointer, "conn"):
config = HaiveRunnableConfigManager.add_persistence_info(
config,
persistence_type="postgres",
setup_needed=self.postgres_setup_needed,
)
return config, current_thread_id
def prepare_for_agent_run(
self, thread_id: str | None = None, user_info=None, **kwargs
):
"""Comprehensive preparation for an agent run, handling thread registration,.
configuration creation, and database setup.
Args:
thread_id: Optional thread ID for persistence
user_info: Optional user information dictionary (Supabase)
**kwargs: Additional runtime configuration
Returns:
Tuple of (RunnableConfig, current_thread_id)
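Example:
The single-call path most agent runners need (values illustrative):
config, thread_id = manager.prepare_for_agent_run(
user_info={"supabase_user_id": "user-123", "email": "a@example.com"}
)
# Table setup and thread registration happen as side effects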
"""
# Create configuration
config, current_thread_id = self.create_runnable_config(
thread_id, user_info, **kwargs
)
logger.debug(f"Created runnable config with thread_id={current_thread_id}")
# Setup checkpointer if needed
if self.postgres_setup_needed:
setup_success = self.setup()
self.postgres_setup_needed = False
logger.debug(f"Setup checkpointer: {setup_success}")
# Extract auth info from config
auth_info = HaiveRunnableConfigManager.get_auth_info(config)
# Make a copy of auth_info to avoid modifying the original
if auth_info and not isinstance(auth_info, dict):
auth_info = dict(auth_info)
# Make sure to include supabase_user_id for proper thread registration
if user_info and "supabase_user_id" in user_info and auth_info:
auth_info["supabase_user_id"] = user_info["supabase_user_id"]
# Register thread with authentication context
register_result = self.register_thread(current_thread_id, auth_info)
logger.debug(f"Thread registration result: {register_result}")
return config, current_thread_id
@staticmethod
def get_or_create_thread_id(config: dict[str, Any] | None = None):
"""Get thread ID from config or create a new one.
Args:
config: Optional RunnableConfig
Returns:
Thread ID string
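Example:
PersistenceManager.get_or_create_thread_id(
{"configurable": {"thread_id": "thread-abc"}}
)  # -> "thread-abc"
PersistenceManager.get_or_create_thread_id(None)  # -> a fresh UUID4 string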
"""
if (
config
and "configurable" in config
and "thread_id" in config["configurable"]
):
return config["configurable"]["thread_id"]
return str(uuid.uuid4())
def list_threads(
self,
user_id: str | None = None,
thread_id: str | None = None,
limit: int = 100,
offset: int = 0,
):
"""List threads from the PostgreSQL database.
Args:
user_id: Optional user ID filter
thread_id: Optional thread ID filter for single thread lookup
limit: Maximum number of threads to return
offset: Offset for pagination
Returns:
List of thread information dictionaries
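Example:
A sketch of the common filters (identifiers are illustrative):
threads = manager.list_threads(user_id="user-123", limit=20)
for info in threads:
print(info["thread_id"], info["username"], info["last_access"])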
"""
# Skip if not PostgreSQL
if not self.checkpointer or not hasattr(self.checkpointer, "conn"):
return []
try:
# Ensure pool is open
self.ensure_pool_open()
# Query threads
with self.checkpointer.conn.connection() as conn, conn.cursor() as cursor:
# First check if the threads table exists
cursor.execute(
"""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'threads'
);
"""
)
table_exists = cursor.fetchone()[0]
if not table_exists:
logger.debug("Threads table does not exist yet")
return []
# Build the query based on filters
if thread_id:
logger.debug(f"Filtering by thread_id={thread_id}")
query = """
SELECT thread_id, metadata, user_id, created_at, last_access
FROM threads
WHERE thread_id = %s
LIMIT 1
"""
params = (thread_id,)
elif user_id:
logger.debug(
f"Filtering by user_id={user_id} (type: {type(user_id).__name__})"
)
query = """
SELECT thread_id, metadata, user_id, created_at, last_access
FROM threads
WHERE user_id = %s
ORDER BY last_access DESC
LIMIT %s OFFSET %s
"""
params = (user_id, limit, offset)
else:
logger.debug(
f"No filters, fetching all threads with limit={limit}, offset={offset}"
)
query = """
SELECT thread_id, metadata, user_id, created_at, last_access
FROM threads
ORDER BY last_access DESC
LIMIT %s OFFSET %s
"""
params = (limit, offset)
# Execute the query
cursor.execute(query, params)
results = cursor.fetchall()
if user_id:
logger.debug(f"Found {len(results)} threads for user_id={user_id}")
# For debugging purposes
if logger.isEnabledFor(logging.DEBUG):
cursor.execute("SELECT COUNT(*) FROM threads")
total = cursor.fetchone()[0]
logger.debug(f"Total threads in database: {total}")
cursor.execute("SELECT thread_id, user_id FROM threads LIMIT 5")
sample_threads = cursor.fetchall()
logger.debug(f"Sample threads: {sample_threads}")
# Process results
threads = []
for result in results:
thread_id, metadata, user_id, created_at, last_access = result
thread_info = self._process_thread_result(
thread_id, metadata, user_id, created_at, last_access
)
threads.append(thread_info)
return threads
except Exception as e:
logger.error(f"Error listing threads: {e}", exc_info=True)
return []
def _process_thread_result(
self, thread_id, metadata, user_id, created_at, last_access
):
"""Process a thread result row into a dictionary.
Args:
thread_id: Thread ID
metadata: Metadata JSON or dictionary
user_id: User ID
created_at: Creation timestamp
last_access: Last access timestamp
Returns:
Thread information dictionary
"""
# Parse metadata JSON
if isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except (json.JSONDecodeError, TypeError):
logger.warning(f"Failed to parse metadata JSON for thread {thread_id}")
metadata = {}
elif metadata is None:
metadata = {}
# Format timestamps
try:
# For datetime objects
if hasattr(created_at, "isoformat"):
created_at = created_at.isoformat()
if hasattr(last_access, "isoformat"):
last_access = last_access.isoformat()
except Exception as e:
logger.warning(f"Error formatting timestamps: {e}")
# Extract username from metadata for convenience
username = metadata.get("username", "Unknown")
# Build thread info
return {
"thread_id": thread_id,
"user_id": user_id,
"username": username,
"created_at": created_at,
"last_access": last_access,
"metadata": metadata,
}
def delete_thread(self, thread_id: str):
"""Delete a thread from the PostgreSQL database.
Args:
thread_id: Thread ID to delete
Returns:
True if deletion succeeded, False otherwise
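Example:
if manager.delete_thread("thread-abc"):  # thread ID is illustrative
logger.info("Thread removed")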
"""
# Skip if not PostgreSQL
if not self.checkpointer or not hasattr(self.checkpointer, "conn"):
logger.debug("Skipping thread deletion - not using PostgreSQL")
return False
if not thread_id:
logger.warning("Cannot delete thread with empty thread_id")
return False
try:
# Ensure pool is open
self.ensure_pool_open()
# Delete thread
with self.checkpointer.conn.connection() as conn, conn.cursor() as cursor:
# First check if the threads table exists
cursor.execute(
"""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'threads'
);
"""
)
table_exists = cursor.fetchone()[0]
if not table_exists:
logger.debug("Threads table does not exist, nothing to delete")
return True
# Delete the thread
cursor.execute(
"""
DELETE FROM threads
WHERE thread_id = %s
RETURNING thread_id
""",
(thread_id,),
)
# Check if any rows were affected
deleted = cursor.fetchone()
if deleted:
logger.info(f"Thread {thread_id} deleted from PostgreSQL")
return True
logger.debug(f"Thread {thread_id} not found in database")
return False
except Exception as e:
logger.warning(f"Error deleting thread: {e}", exc_info=True)
return False
@classmethod
def from_config(
cls,
db_host="localhost",
db_port=5432,
db_name="postgres",
db_user="postgres",
db_pass="postgres",
use_async=False,
use_pool=True,
setup_needed=True,
):
"""Create a PersistenceManager from database configuration.
Args:
db_host: Database host
db_port: Database port
db_name: Database name
db_user: Database user
db_pass: Database password
use_async: Whether to use async connections
use_pool: Whether to use connection pooling
setup_needed: Whether table setup is needed
Returns:
Configured PersistenceManager
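Example:
A sketch with explicit credentials (values are illustrative):
manager = PersistenceManager.from_config(
db_host="db.internal",
db_name="haive",
db_user="haive",
db_pass="secret",
)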
"""
persistence_config = {
"persistence_type": CheckpointerType.postgres,
"persistence_config": {
"db_host": db_host,
"db_port": db_port,
"db_name": db_name,
"db_user": db_user,
"db_pass": db_pass,
"use_async": use_async,
"use_pool": use_pool,
"setup_needed": setup_needed,
},
}
manager = cls(persistence_config)
manager.get_checkpointer()
return manager
@classmethod
def from_env(cls) -> "PersistenceManager":
"""Create a PersistenceManager from environment variables.
Returns:
Configured PersistenceManager
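Example:
Assuming the standard POSTGRES_* variables are exported (values
illustrative), e.g. POSTGRES_HOST=db.internal and POSTGRES_DB=haive:
manager = PersistenceManager.from_env()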
"""
import os
persistence_config = {
"persistence_type": CheckpointerType.postgres,
"persistence_config": {
"db_host": os.environ.get("POSTGRES_HOST", "localhost"),
"db_port": int(os.environ.get("POSTGRES_PORT", 5432)),
"db_name": os.environ.get("POSTGRES_DB", "postgres"),
"db_user": os.environ.get("POSTGRES_USER", "postgres"),
"db_pass": os.environ.get("POSTGRES_PASSWORD", "postgres"),
"ssl_mode": os.environ.get("POSTGRES_SSL_MODE", "disable"),
"use_async": os.environ.get("POSTGRES_USE_ASYNC", "false").lower()
== "true",
"use_pool": os.environ.get("POSTGRES_USE_POOL", "true").lower()
== "true",
"setup_needed": os.environ.get("POSTGRES_SETUP_NEEDED", "true").lower()
== "true",
},
}
manager = cls(persistence_config)
manager.get_checkpointer()
return manager