"""Utility functions for the Haive persistence system.
This module provides helper functions for working with checkpointers and their
associated resources. It includes utilities for connection pool management,
serialization/deserialization of metadata, and other common operations needed
across different persistence implementations.
The utilities are designed to be used by the persistence system internals and
generally aren't intended to be used directly by application code. They provide
consistent behavior across different checkpointer implementations and handle
edge cases and error conditions gracefully.
"""
import json
import logging
from typing import Any
from psycopg_pool import ConnectionPool
logger = logging.getLogger(__name__)
[docs]
def ensure_pool_open(checkpointer: Any) -> Any | None:
"""Ensure that a PostgreSQL connection pool is properly opened.
This function checks if a checkpointer has an associated connection pool
and ensures that it's properly opened. It handles different pool
implementations and versions, checking appropriate attributes and calling
the open method if needed.
The function is used to ensure that connection pools are ready for use
before attempting database operations, preventing errors from closed
or unopened pools.
Args:
checkpointer: The checkpointer instance to check for a connection pool
Returns:
Optional[Any]: The opened connection pool if one was found and opened,
None otherwise
Note:
This function gracefully handles the case where psycopg_pool is not
available, making it safe to call even if the PostgreSQL dependencies
are not installed.
"""
if not hasattr(checkpointer, "conn"):
return None
conn = checkpointer.conn
try:
if isinstance(conn, ConnectionPool):
# Check if opened using the _opened attribute
is_open = getattr(conn, "_opened", False)
if not is_open:
logger.info("Opening PostgreSQL connection pool")
conn.open()
logger.info("Pool opened successfully")
return conn
except ImportError:
logger.debug("psycopg_pool not available")
# Make sure setup is called
if hasattr(checkpointer, "setup"):
try:
checkpointer.setup()
except Exception as e:
logger.exception(f"Error setting up checkpointer: {e}")
return None
[docs]
async def ensure_async_pool_open(checkpointer: Any) -> Any | None:
"""Ensure that an async PostgreSQL connection pool is properly opened.
This asynchronous function checks if an async checkpointer has an associated
connection pool and ensures that it's properly opened. It handles different
async pool implementations and versions, checking appropriate attributes and
calling the async open method if needed.
The function is particularly important for async contexts, where proper
connection management is critical for maintaining good performance and
resource utilization. It prevents errors from closed or unopened pools
in async code.
Args:
checkpointer: The async checkpointer instance to check for a connection pool
Returns:
Optional[Any]: The opened async connection pool if one was found and
opened, None otherwise
Note:
This function gracefully handles the case where the async PostgreSQL
dependencies are not available, making it safe to call even if the
async database modules are not installed.
Examples:
async def prepare_checkpointer(checkpointer):
# Ensure the pool is open before using it
pool = await ensure_async_pool_open(checkpointer)
if pool:
print("Pool is ready for use")
# Continue with checkpointer operations...
"""
opened_pool = None
try:
# Check for connection pools in the checkpointer
if hasattr(checkpointer, "conn"):
conn = checkpointer.conn
# Import here to avoid dependency issues
try:
# Check if it's an async pool
if isinstance(conn, AsyncPool):
# Check if the pool is already open
try:
if hasattr(conn, "is_open"):
is_open = conn.is_open()
else:
# Older versions might not have is_open()
is_open = getattr(conn, "_opened", False)
# Open the pool if needed
if not is_open:
logger.info("Opening async PostgreSQL connection pool")
try:
await conn.open()
opened_pool = conn
logger.info("Successfully opened async pool")
except Exception as e:
logger.exception(f"Error opening async pool: {e}")
except Exception as e:
logger.exception(f"Error checking if async pool is open: {e}")
except ImportError:
logger.debug("psycopg_pool AsyncPool not available")
# Additional check for other types of pools or connections
if not opened_pool and hasattr(checkpointer, "setup"):
# If the checkpointer has a setup method but no connection was found,
# just make sure tables are set up
logger.debug("No async pool found but checkpointer has setup method")
try:
await checkpointer.setup()
except Exception as e:
logger.exception(f"Error setting up async checkpointer: {e}")
except Exception as e:
logger.exception(f"Error ensuring async pool is open: {e}")
return opened_pool
# Thread registration helpers (sync and async variants).
[docs]
def register_thread(
checkpointer: Any, thread_id: str, metadata: dict[str, Any] | None = None
) -> bool:
"""Register a thread in the PostgreSQL database if needed."""
try:
if hasattr(checkpointer, "conn"):
pool = checkpointer.conn
if pool:
# Ensure connection pool is usable
ensure_pool_open(checkpointer)
# Register the thread
with pool.connection() as conn, conn.cursor() as cursor:
# Check if threads table exists
cursor.execute(
"""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'threads'
);
"""
)
table_exists = cursor.fetchone()[0]
if not table_exists:
logger.debug("Creating threads table")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threads (
thread_id VARCHAR(255) PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
last_access TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
metadata JSONB DEFAULT '{}'::jsonb,
user_id VARCHAR(255)
);
"""
)
# Convert metadata to JSON string first
metadata_json = json.dumps(metadata) if metadata else "{}"
# Insert the thread if not exists, or update last_access
cursor.execute(
"""
INSERT INTO threads (thread_id, last_access, metadata)
VALUES (%s, CURRENT_TIMESTAMP, %s::jsonb)
ON CONFLICT (thread_id)
DO UPDATE SET last_access = CURRENT_TIMESTAMP, metadata = %s::jsonb
""",
(thread_id, metadata_json, metadata_json),
)
logger.info(f"Thread {thread_id} registered/updated in PostgreSQL")
return True
except Exception as e:
logger.warning(f"Error registering thread: {e}")
return False
[docs]
async def register_thread_async(
checkpointer: Any, thread_id: str, metadata: dict[str, Any] | None = None
) -> bool:
"""Register a thread in the PostgreSQL database asynchronously."""
if not hasattr(checkpointer, "conn"):
return False
try:
metadata_json = json.dumps(metadata) if metadata else "{}"
async with checkpointer.conn.connection() as conn, conn.cursor() as cursor:
# Check if table exists
await cursor.execute(
"""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'threads'
);
"""
)
table_exists = (await cursor.fetchone())[0]
if not table_exists:
await cursor.execute(
"""
CREATE TABLE IF NOT EXISTS threads (
thread_id VARCHAR(255) PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
last_access TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
metadata JSONB DEFAULT '{}'::jsonb,
user_id VARCHAR(255)
);
"""
)
# Use parameterized query with proper JSONB handling
await cursor.execute(
"""
INSERT INTO threads (thread_id, last_access, metadata)
VALUES (%s, CURRENT_TIMESTAMP, %s::jsonb)
ON CONFLICT (thread_id)
DO UPDATE SET last_access = CURRENT_TIMESTAMP, metadata = %s::jsonb
""",
(thread_id, metadata_json, metadata_json),
)
logger.info(f"Thread {thread_id} registered/updated in PostgreSQL (async)")
return True
except Exception as e:
logger.warning(f"Error registering thread asynchronously: {e}")
return False