agents.rag.db_rag.sql_rag.agent¶
SQL RAG Agent for natural language database querying.
This module implements a sophisticated SQL Retrieval-Augmented Generation (RAG) agent that enables natural language querying of SQL databases. The agent uses a multi-step workflow to analyze queries, generate SQL, validate results, and produce accurate answers.
Example
Basic usage of the SQL RAG Agent:
>>> from haive.agents.rag.db_rag.sql_rag import SQLRAGAgent, SQLRAGConfig
>>>
>>> # Create configuration
>>> config = SQLRAGConfig(
... domain_name="sales",
... db_config=SQLDatabaseConfig(
... db_type="postgresql",
... db_name="sales_db"
... )
... )
>>>
>>> # Initialize agent
>>> agent = SQLRAGAgent(config)
>>>
>>> # Query the database
>>> result = agent.invoke({
... "question": "What were the total sales last quarter?"
... })
>>> print(result["answer"])
'Total sales last quarter were $1.2M across 450 transactions...'
- agents.rag.db_rag.sql_rag.agent.logger¶
Module-level logger for debugging and monitoring.
- Type:
Note
This agent requires proper database credentials and connection details to be configured either through environment variables or explicit configuration.
Classes¶
SQL RAG Agent for querying SQL databases with natural language. |
Module Contents¶
- class agents.rag.db_rag.sql_rag.agent.SQLRAGAgent(config)¶
Bases:
haive.core.engine.agent.agent.Agent
[haive.agents.rag.db_rag.sql_rag.config.SQLRAGConfig
]SQL RAG Agent for querying SQL databases with natural language.
This agent implements a sophisticated workflow for converting natural language questions into SQL queries, executing them safely, and generating accurate natural language responses. The workflow includes:
Domain Relevance Check: Validates if the query is database-related
Schema Retrieval: Gets relevant database schema information
Query Analysis: Determines required tables, columns, and operations
SQL Generation: Creates syntactically correct SQL from natural language
Validation & Correction: Ensures SQL is safe and correct
Execution: Runs the query safely with proper error handling
Answer Generation: Produces natural language answers from results
- sql_db¶
Connected database instance
- Type:
SQLDatabase
- toolkit¶
LangChain SQL toolkit instance
- Type:
SQLDatabaseToolkit
- tools¶
Available database tools
- Type:
List[BaseTool]
Example
Creating and using a SQL RAG Agent:
>>> # Configure for a PostgreSQL database >>> config = SQLRAGConfig( ... domain_name="e-commerce", ... db_config=SQLDatabaseConfig( ... db_type="postgresql", ... db_host="localhost", ... db_name="shop_db", ... include_tables=["orders", "products", "customers"] ... ), ... hallucination_check=True, ... max_iterations=3 ... ) >>> >>> # Create agent >>> agent = SQLRAGAgent(config) >>> >>> # Ask questions in natural language >>> response = agent.invoke({ ... "question": "Which products had the highest sales last month?" ... }) >>> >>> # Access results >>> print(f"Answer: {response['answer']}") >>> print(f"SQL Used: {response['sql_statement']}")
Note
The agent includes safety features like SQL validation, hallucination detection, and query result verification to ensure accurate responses.
Initialize the SQL RAG Agent with the given configuration.
- Parameters:
config (SQLRAGConfig) – Configuration object containing database connection details, LLM settings, and workflow parameters.
- Raises:
ValueError – If database connection fails or required engines are missing from the configuration.
Example
>>> config = SQLRAGConfig( ... db_config=SQLDatabaseConfig(db_uri="sqlite:///sales.db") ... ) >>> agent = SQLRAGAgent(config)
- analyze_query(state)¶
Analyze the query to determine relevant tables and fields.
This method uses an LLM to understand the user’s natural language question and identify which database tables, columns, joins, and aggregations will be needed to answer it.
- Parameters:
state (OverallState) – Current state containing the question and schema information.
- Returns:
- Update command with query analysis including:
relevant_tables: Tables needed for the query
needed_columns: Specific columns to select
constraints: WHERE clause conditions
aggregations: GROUP BY/aggregate functions needed
joins_needed: Required table joins
- Return type:
Command
Example
Analyzing a complex query:
>>> state = OverallState( ... question="Show me top 5 customers by total order value" ... ) >>> command = agent.analyze_query(state) >>> analysis = command.update["query_analysis"] >>> print(analysis.relevant_tables) ['customers', 'orders', 'order_items'] >>> print(analysis.aggregations) ['SUM(order_items.price * order_items.quantity)']
Note
The analysis helps the SQL generation step create more accurate queries by understanding the semantic intent.
- check_domain_relevance(state)¶
Determine if the query is relevant to databases.
This method uses a guardrails LLM to check if the user’s question is about querying the database or if it’s an unrelated question that should be rejected.
- Parameters:
state (OverallState) – Current state containing the question.
- Returns:
- Update command with next_action set to either
’end’ (if irrelevant) or continue to schema retrieval.
- Return type:
Command
Example
>>> state = OverallState(question="What tables are in the database?") >>> command = agent.check_domain_relevance(state) >>> print(command.update["next_action"]) 'retrieve_schema'
Note
Questions like “What’s the weather?” or “Tell me a joke” will be rejected with an appropriate message.
- correct_query(state)¶
Correct the SQL query based on validation errors.
This method takes a query with validation errors and attempts to fix them, learning from the specific issues identified.
- Parameters:
state (OverallState) – Current state containing the invalid query and the list of errors.
- Returns:
- Update command with corrected SQL query and instruction
to re-validate.
- Return type:
Command
Example
Correcting a query with a missing GROUP BY:
>>> state = OverallState( ... sql_query="SELECT category, SUM(amount) FROM sales", ... sql_errors=["Missing GROUP BY clause for 'category'"] ... ) >>> command = agent.correct_query(state) >>> print(command.update["sql_query"]) 'SELECT category, SUM(amount) FROM sales GROUP BY category'
Note
The correction process may iterate multiple times based on the max_iterations configuration setting.
- domain_router(state)¶
Route based on domain relevance check.
- Parameters:
state (OverallState) – Current state with next_action field.
- Returns:
Next node name - either END or “retrieve_schema”.
- Return type:
- execute_query(state)¶
Execute the SQL query against the database.
This method safely executes the validated SQL query and captures the results. It includes error handling and timeout protection.
- Parameters:
state (OverallState) – Current state containing the validated SQL query to execute.
- Returns:
Update command with query results or error message.
- Return type:
Command
Example
Executing a query and handling results:
>>> state = OverallState( ... sql_query="SELECT COUNT(*) as total FROM orders" ... ) >>> command = agent.execute_query(state) >>> print(command.update["query_result"]) 'total\n-----\n1234'
Warning
Only SELECT queries are executed. Any DML operations (INSERT, UPDATE, DELETE) are rejected during validation.
- generate_answer(state)¶
Generate the final answer based on the query results.
This method converts the raw SQL query results into a natural language answer that directly addresses the user’s question. It includes optional hallucination checking to ensure accuracy.
- Parameters:
state (OverallState) – Current state containing the question, SQL query, and query results.
- Returns:
- Update command with the final answer and optional
hallucination check results.
- Return type:
Command
Example
Generating an answer with hallucination checking:
>>> state = OverallState( ... question="Who is our top customer?", ... query_result="customer_name | total_spent\\n-----------\\nAcme Corp | 50000" ... ) >>> command = agent.generate_answer(state) >>> print(command.update["answer"]) 'Your top customer is Acme Corp with total spending of $50,000.' >>> print(command.update["hallucination_check"]) 'no' # No hallucinations detected
Note
If hallucination checking is enabled and detects issues, a warning is appended to the answer.
- generate_query(state)¶
Generate an SQL query from the natural language question.
This method converts the user’s natural language question into a syntactically correct SQL query, using the analysis results and database schema information.
- Parameters:
state (OverallState) – Current state containing question, schema, and query analysis.
- Returns:
- Update command with the generated SQL query,
formatted for readability.
- Return type:
Command
Example
SQL generation from natural language:
>>> state = OverallState( ... question="What were the total sales by category last month?" ... ) >>> command = agent.generate_query(state) >>> print(command.update["sql_query"]) ''' SELECT c.category_name, SUM(oi.quantity * oi.price) as total_sales FROM categories c JOIN products p ON c.category_id = p.category_id JOIN order_items oi ON p.product_id = oi.product_id JOIN orders o ON oi.order_id = o.order_id WHERE o.order_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH) GROUP BY c.category_name ORDER BY total_sales DESC '''
Note
The method includes special handling for metadata queries like “what tables are in the database” that don’t require complex SQL.
- retrieve_schema(state)¶
Retrieve database schema information for context.
This method gathers detailed schema information about all relevant tables in the database, including column names, types, and relationships.
- Parameters:
state (OverallState) – Current state of the workflow.
- Returns:
- Update command with schema information and instruction
to proceed to query analysis.
- Return type:
Command
Example
Schema retrieval for a sales database:
>>> state = OverallState(question="Show me all orders") >>> command = agent.retrieve_schema(state) >>> schema_info = command.update["schema_info"] >>> print(f"Found {len(schema_info)} tables") Found 5 tables
Note
For large databases, this method limits schema retrieval to the first 10 tables to avoid overwhelming the LLM context.
- setup_workflow()¶
Set up the SQL RAG workflow graph.
This method constructs the workflow graph with all necessary nodes and edges, including conditional routing based on validation results.
The workflow follows this path: 1. START -> check_domain_relevance 2. Domain routing (end if irrelevant) 3. retrieve_schema -> analyze_query -> generate_query 4. validate_query with correction loop if needed 5. execute_query -> generate_answer -> END
Note
This method is called automatically during agent initialization and should not be called directly.
- Return type:
None
- validate_query(state)¶
Validate the SQL query for syntax and schema correctness.
This method performs comprehensive validation of the generated SQL including syntax checking, table/column name verification, join validation, and security checks.
- Parameters:
state (OverallState) – Current state containing the SQL query to validate.
- Returns:
- Update command with either:
next_action=”execute_query” if valid
next_action=”correct_query” if errors found
sql_errors list containing any validation errors
- Return type:
Command
Example
Validating a query with an error:
>>> state = OverallState( ... sql_query="SELECT * FROM non_existent_table" ... ) >>> command = agent.validate_query(state) >>> print(command.update["sql_errors"]) ["Table 'non_existent_table' does not exist in the schema"] >>> print(command.update["next_action"]) 'correct_query'
Note
The validation prevents SQL injection and ensures only safe SELECT queries are executed.
- validation_router(state)¶
Route based on query validation results.
- Parameters:
state (OverallState) – Current state with next_action field.
- Returns:
Next node name - END, “correct_query”, or “execute_query”.
- Return type: