# Source code for haive.core.schema.prebuilt.messages.compatibility
"""Compatibility layer for MessagesState implementations.

This module provides adapter classes and utilities that keep existing
MessagesState implementations backward compatible while exposing the new
features of the enhanced MessagesState implementation. It serves as a bridge
between the old and new architectures.
"""
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langgraph.graph import END
from langgraph.types import Send
from haive.core.schema.prebuilt.messages.utils import (
MessageRound,
ToolCallInfo,
extract_tool_calls,
inject_state_into_tool_calls,
is_real_human_message,
is_tool_error,
)
class MessagesStateAdapter:
    """Adapter that lets existing MessagesState instances use newer features.

    The adapter wraps an existing MessagesState-like object and provides
    methods implementing the enhanced functionality of the newer
    MessagesState architecture, with minimal changes to the wrapped
    object's API. Only the wrapped object's ``messages`` attribute (and
    ``model_dump`` when it exists) is used.

    Note:
        ``deduplicate_tool_calls`` and ``transform_ai_to_human`` mutate the
        wrapped state in place.
    """

    def __init__(self, messages_state) -> None:
        """Initialize the adapter with an existing MessagesState instance.

        Args:
            messages_state: The MessagesState instance to adapt. It must
                expose a ``messages`` list.
        """
        self.state = messages_state

    @staticmethod
    def _tool_call_id(tool_call):
        """Return the ID of a tool call given as a dict or an object, else None."""
        if isinstance(tool_call, dict):
            return tool_call.get("id")
        return getattr(tool_call, "id", None)

    def get_conversation_rounds(self) -> list[MessageRound]:
        """Get detailed information about each conversation round.

        A round starts at a real human message (per ``is_real_human_message``)
        and collects the AI messages, their tool calls, and the tool
        responses that follow it until the next real human message.
        Messages before the first real human message belong to no round.

        Returns:
            List of MessageRound objects, numbered from 1, each with
            ``is_complete`` set and ``has_errors`` set when any of its tool
            responses is an error.
        """
        rounds: list[MessageRound] = []
        current_round: MessageRound | None = None
        round_number = 0

        for msg in self.state.messages:
            if isinstance(msg, HumanMessage) and is_real_human_message(msg):
                # A real human message closes the previous round and opens a new one.
                if current_round is not None:
                    current_round.is_complete = self._is_round_complete(current_round)
                    rounds.append(current_round)
                round_number += 1
                current_round = MessageRound(
                    round_number=round_number, human_message=msg
                )
            elif current_round is None:
                # No round has started yet; nothing to attach this message to.
                continue
            elif isinstance(msg, AIMessage):
                current_round.ai_responses.append(msg)
                tool_calls = extract_tool_calls(msg)
                if tool_calls:
                    current_round.tool_calls.extend(tool_calls)
            elif isinstance(msg, ToolMessage):
                current_round.tool_responses.append(msg)
                if is_tool_error(msg):
                    current_round.has_errors = True

        # The final round is never closed by a following human message.
        if current_round is not None:
            current_round.is_complete = self._is_round_complete(current_round)
            rounds.append(current_round)
        return rounds

    def _is_round_complete(self, round_info: MessageRound) -> bool:
        """Check if a conversation round is complete.

        A round is complete if:
        1. There is at least one AI response, and
        2. Every tool call that carries an ID has a tool response with a
           matching ``tool_call_id``.

        Tool calls without a usable ID (missing, None, or empty) cannot be
        matched and are ignored. This applies to both dict and object tool
        calls.

        Args:
            round_info: The round to check.

        Returns:
            True if the round is complete, False otherwise.
        """
        if not round_info.ai_responses:
            return False

        tool_call_ids = set()
        for tool_call in round_info.tool_calls:
            call_id = self._tool_call_id(tool_call)
            if call_id:
                tool_call_ids.add(call_id)

        tool_response_ids = set()
        for tool_response in round_info.tool_responses:
            response_id = getattr(tool_response, "tool_call_id", None)
            if response_id:
                tool_response_ids.add(response_id)

        return tool_call_ids <= tool_response_ids

    def deduplicate_tool_calls(self) -> int:
        """Remove duplicate tool calls based on tool call ID.

        Scans the AI messages in order and keeps only the first occurrence
        of each tool call ID. Later occurrences are dropped, including
        repeats within the same message. Tool calls without an ID are
        always kept, because they cannot be compared.

        Only messages that actually lost a tool call are modified.

        Returns:
            Number of duplicates removed.
        """
        seen_tool_call_ids = set()
        duplicates_removed = 0

        for msg in self.state.messages:
            if not isinstance(msg, AIMessage):
                continue
            tool_calls = extract_tool_calls(msg)
            if not tool_calls:
                continue

            unique_tool_calls = []
            removed_here = 0
            for tool_call in tool_calls:
                tool_call_id = self._tool_call_id(tool_call)
                if not tool_call_id:
                    unique_tool_calls.append(tool_call)
                elif tool_call_id in seen_tool_call_ids:
                    removed_here += 1
                else:
                    seen_tool_call_ids.add(tool_call_id)
                    unique_tool_calls.append(tool_call)

            if not removed_here:
                # Nothing dropped: leave the message untouched rather than
                # overwriting it with whatever extract_tool_calls returned.
                # NOTE(review): that format may differ from the field's
                # native one; confirm against extract_tool_calls.
                continue

            duplicates_removed += removed_here
            if hasattr(msg, "tool_calls"):
                msg.tool_calls = unique_tool_calls
            elif "tool_calls" in (getattr(msg, "additional_kwargs", None) or {}):
                msg.additional_kwargs["tool_calls"] = unique_tool_calls

        return duplicates_removed

    def get_completed_tool_calls(self) -> list[ToolCallInfo]:
        """Get all completed tool calls with their responses.

        Matches each tool call in an AI message with the tool message whose
        ``tool_call_id`` equals the call's ID. If several tool messages
        share an ID, the last one wins.

        Returns:
            List of ToolCallInfo objects, in message order. Each has
            ``is_successful`` set to False when the matching tool message is
            an error.
        """
        messages = self.state.messages

        # tool_call_id -> (tool message, whether it is an error)
        tool_messages = {}
        for msg in messages:
            if isinstance(msg, ToolMessage):
                tool_call_id = getattr(msg, "tool_call_id", None)
                if tool_call_id:
                    tool_messages[tool_call_id] = (msg, is_tool_error(msg))

        completed = []
        for msg in messages:
            if not isinstance(msg, AIMessage):
                continue
            tool_calls = extract_tool_calls(msg)
            if not tool_calls:
                continue
            for tool_call in tool_calls:
                tool_call_id = self._tool_call_id(tool_call)
                if not tool_call_id or tool_call_id not in tool_messages:
                    continue
                tool_msg, is_error = tool_messages[tool_call_id]
                completed.append(
                    ToolCallInfo(
                        tool_call_id=tool_call_id,
                        tool_call=tool_call,
                        tool_message=tool_msg,
                        ai_message=msg,
                        is_successful=not is_error,
                    )
                )
        return completed

    def send_tool_calls(self, node_name: str = "tools") -> str | list[Send]:
        """Convert the last AI message's tool calls into LangGraph Send objects.

        Args:
            node_name: The name of the node to send tool calls to.

        Returns:
            ``END`` in either of these cases:

            - There is no AI message.
            - The last AI message has no tool calls.

            Otherwise, a list with one ``Send(node_name, call)`` per tool
            call, after state has been injected into the calls.
        """
        last_ai = next(
            (msg for msg in reversed(self.state.messages) if isinstance(msg, AIMessage)),
            None,
        )
        if last_ai is None:
            return END

        tool_calls = extract_tool_calls(last_ai)
        if not tool_calls:
            return END

        # Prefer the full serialized state when the wrapped object offers
        # model_dump (presumably a pydantic model). Otherwise fall back to
        # the messages only.
        if hasattr(self.state, "model_dump"):
            state_data = self.state.model_dump()
        else:
            state_data = {"messages": self.state.messages}

        injected_calls = inject_state_into_tool_calls(tool_calls, state_data)
        return [Send(node_name, tool_call) for tool_call in injected_calls]

    def transform_ai_to_human(
        self,
        preserve_metadata: bool = True,
        engine_id: str | None = None,
        engine_name: str | None = None,
    ) -> None:
        """Transform AI messages to Human messages in place.

        Every AIMessage in the wrapped state is replaced by a HumanMessage
        with the same content. Other messages are kept as-is. This is useful
        for agent-to-agent communication or for creating synthetic
        conversations.

        Args:
            preserve_metadata: Whether to copy the AI message's
                ``additional_kwargs`` (shallow copy) and ``name``.
            engine_id: Optional engine ID, stored as
                ``additional_kwargs["engine_id"]`` on transformed messages.
            engine_name: Optional engine name, stored as
                ``additional_kwargs["engine_name"]`` on transformed messages.
        """
        transformed_messages = []
        for msg in self.state.messages:
            if not isinstance(msg, AIMessage):
                transformed_messages.append(msg)
                continue

            kwargs = {"content": msg.content}
            if preserve_metadata:
                if getattr(msg, "additional_kwargs", None):
                    kwargs["additional_kwargs"] = msg.additional_kwargs.copy()
                if getattr(msg, "name", None):
                    kwargs["name"] = msg.name

            # Engine tags are explicit caller requests, so they are applied
            # even when metadata is not preserved.
            if engine_id:
                kwargs.setdefault("additional_kwargs", {})["engine_id"] = engine_id
            if engine_name:
                kwargs.setdefault("additional_kwargs", {})["engine_name"] = engine_name

            transformed_messages.append(HumanMessage(**kwargs))

        self.state.messages = transformed_messages