Source code for langgraph.prebuilt.tool_validator

"""This module provides a ValidationNode class that can be used to validate tool calls
in a LangGraph graph. It applies a pydantic schema to the tool_calls in the model's
output and returns a ToolMessage with the validated content. If a tool call fails
validation, it returns a ToolMessage with the error message instead. The ValidationNode
can be used in a StateGraph with a "messages" key or in a MessageGraph. If multiple
tool calls are requested, they are validated in parallel.
"""

from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.runnables import (
    RunnableConfig,
)
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.tools import BaseTool, create_schema_from_function
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1

from langgraph.utils.runnable import RunnableCallable


def _default_format_error(
    error: BaseException,
    call: ToolCall,
    schema: Union[Type[BaseModel], Type[BaseModelV1]],
) -> str:
    """Build the default feedback text for a failed tool call.

    Args:
        error: The exception to report; its ``repr`` is embedded in the text.
        call: The tool call the error belongs to (not used by this formatter).
        schema: The schema associated with the call (not used by this formatter).

    Returns:
        The exception ``repr``, a blank line, and an instruction to respond
        after fixing all validation errors.
    """
    instruction = "Respond after fixing all validation errors."
    return "\n\n".join((repr(error), instruction))


class ValidationNode(RunnableCallable):
    """A node that validates all tools requests from the last AIMessage.

    It can be used either in StateGraph with a "messages" key or in MessageGraph.

    !!! note

        This node does not actually **run** the tools, it only validates the tool calls,
        which is useful for extraction and other use cases where you need to generate
        structured output that conforms to a complex schema without losing the original
        messages and tool IDs (for use in multi-turn conversations).

    Args:
        schemas: A list of schemas to validate the tool calls with. These can be
            any of the following:

            - A pydantic BaseModel class
            - A BaseTool instance (the args_schema will be used)
            - A function (a schema will be created from the function signature)
        format_error: A function that takes an exception, a ToolCall, and a schema
            and returns a formatted error string. By default, it returns the
            exception repr and a message to respond after fixing validation errors.
        name: The name of the node.
        tags: A list of tags to add to the node.

    Returns:
        (Union[Dict[str, List[ToolMessage]], Sequence[ToolMessage]]): A list of
            ToolMessages with the validated content or error messages.

    Examples:
        Example usage for re-prompting the model to generate a valid response:

        >>> from typing import Literal, Annotated
        >>> from typing_extensions import TypedDict
        ...
        >>> from langchain_anthropic import ChatAnthropic
        >>> from pydantic import BaseModel, field_validator
        ...
        >>> from langgraph.graph import END, START, StateGraph
        >>> from langgraph.prebuilt import ValidationNode
        >>> from langgraph.graph.message import add_messages
        ...
        ...
        >>> class SelectNumber(BaseModel):
        ...     a: int
        ...
        ...     @field_validator("a")
        ...     def a_must_be_meaningful(cls, v):
        ...         if v != 37:
        ...             raise ValueError("Only 37 is allowed")
        ...         return v
        ...
        ...
        >>> builder = StateGraph(Annotated[list, add_messages])
        >>> llm = ChatAnthropic(model="claude-3-5-haiku-latest").bind_tools([SelectNumber])
        >>> builder.add_node("model", llm)
        >>> builder.add_node("validation", ValidationNode([SelectNumber]))
        >>> builder.add_edge(START, "model")
        ...
        ...
        >>> def should_validate(state: list) -> Literal["validation", "__end__"]:
        ...     if state[-1].tool_calls:
        ...         return "validation"
        ...     return END
        ...
        ...
        >>> builder.add_conditional_edges("model", should_validate)
        ...
        ...
        >>> def should_reprompt(state: list) -> Literal["model", "__end__"]:
        ...     for msg in state[::-1]:
        ...         # None of the tool calls were errors
        ...         if msg.type == "ai":
        ...             return END
        ...         if msg.additional_kwargs.get("is_error"):
        ...             return "model"
        ...     return END
        ...
        ...
        >>> builder.add_conditional_edges("validation", should_reprompt)
        ...
        ...
        >>> graph = builder.compile()
        >>> res = graph.invoke(("user", "Select a number, any number"))
        >>> # Show the retry logic
        >>> for msg in res:
        ...     msg.pretty_print()
        ================================ Human Message =================================
        Select a number, any number
        ================================== Ai Message ==================================
        [{'id': 'toolu_01JSjT9Pq8hGmTgmMPc6KnvM', 'input': {'a': 42}, 'name': 'SelectNumber', 'type': 'tool_use'}]
        Tool Calls:
        SelectNumber (toolu_01JSjT9Pq8hGmTgmMPc6KnvM)
        Call ID: toolu_01JSjT9Pq8hGmTgmMPc6KnvM
        Args:
        a: 42
        ================================= Tool Message =================================
        Name: SelectNumber
        ValidationError(model='SelectNumber', errors=[{'loc': ('a',), 'msg': 'Only 37 is allowed', 'type': 'value_error'}])
        Respond after fixing all validation errors.
        ================================== Ai Message ==================================
        [{'id': 'toolu_01PkxSVxNxc5wqwCPW1FiSmV', 'input': {'a': 37}, 'name': 'SelectNumber', 'type': 'tool_use'}]
        Tool Calls:
        SelectNumber (toolu_01PkxSVxNxc5wqwCPW1FiSmV)
        Call ID: toolu_01PkxSVxNxc5wqwCPW1FiSmV
        Args:
        a: 37
        ================================= Tool Message =================================
        Name: SelectNumber
        {"a": 37}
    """

    def __init__(
        self,
        schemas: Sequence[Union[BaseTool, Type[BaseModel], Callable]],
        *,
        format_error: Optional[
            Callable[[BaseException, ToolCall, Type[BaseModel]], str]
        ] = None,
        name: str = "validation",
        tags: Optional[list[str]] = None,
    ) -> None:
        """Register the schemas that incoming tool calls are validated against.

        Each entry of ``schemas`` is stored in ``self.schemas_by_name`` under the
        tool's ``name`` (for a ``BaseTool``) or the object's ``__name__`` (for a
        model class or a plain callable).

        Raises:
            ValueError: If a tool has no ``args_schema``, if its ``args_schema``
                is not a pydantic model class, or if an entry is neither a tool,
                a model class, nor a callable.
        """
        super().__init__(self._func, None, name=name, tags=tags, trace=False)
        self._format_error = format_error or _default_format_error
        self.schemas_by_name: Dict[str, Type[BaseModel]] = {}
        for schema in schemas:
            if isinstance(schema, BaseTool):
                if schema.args_schema is None:
                    raise ValueError(
                        f"Tool {schema.name} does not have an args_schema defined."
                    )
                elif not isinstance(
                    schema.args_schema, type
                ) or not is_basemodel_subclass(schema.args_schema):
                    raise ValueError(
                        "Validation node only works with tools that have a pydantic BaseModel args_schema. "
                        f"Got {schema.name} with args_schema: {schema.args_schema}."
                    )
                self.schemas_by_name[schema.name] = schema.args_schema
            elif isinstance(schema, type) and issubclass(
                schema, (BaseModel, BaseModelV1)
            ):
                self.schemas_by_name[schema.__name__] = cast(Type[BaseModel], schema)
            elif callable(schema):
                # Derive a pydantic model from the callable's signature.
                base_model = create_schema_from_function("Validation", schema)
                self.schemas_by_name[schema.__name__] = base_model
            else:
                raise ValueError(
                    f"Unsupported input to ValidationNode. Expected BaseModel, tool or function. Got: {type(schema)}."
                )

    def _get_message(
        self, input: Union[list[AnyMessage], dict[str, Any]]
    ) -> Tuple[str, AIMessage]:
        """Extract the last AIMessage from the input.

        Args:
            input: Either a list of messages or a dict holding them under the
                ``"messages"`` key.

        Returns:
            A ``(output_type, message)`` pair where ``output_type`` is ``"list"``
            or ``"dict"`` (mirroring the input shape) and ``message`` is the last
            message of the input.

        Raises:
            ValueError: If the input holds no messages, or if the last message
                is not an ``AIMessage``.
        """
        if isinstance(input, list):
            output_type = "list"
            messages: list = input
        else:
            output_type = "dict"
            messages = input.get("messages", [])
        # Treat an empty list exactly like an empty/missing "messages" entry so
        # both input shapes fail with the same ValueError instead of a bare
        # IndexError from ``messages[-1]``.
        if not messages:
            raise ValueError("No message found in input")
        message: AnyMessage = messages[-1]
        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")
        return output_type, message

    def _func(
        self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig
    ) -> Any:
        """Validate the tool calls of the last AIMessage (tools are not run).

        Args:
            input: A list of messages, or a dict with a ``"messages"`` key.
            config: Runnable config used to obtain the executor that maps the
                validation over all tool calls.

        Returns:
            One ToolMessage per tool call: a plain list for list input, or
            ``{"messages": [...]}`` for dict input. Messages for failed
            validations carry ``additional_kwargs={"is_error": True}``.

        Raises:
            ValueError: If no trailing AIMessage can be extracted from ``input``.
            KeyError: If a tool call names a tool with no registered schema.
        """
        output_type, message = self._get_message(input)

        def run_one(call: ToolCall) -> ToolMessage:
            # Looked up outside the ``try`` on purpose: an unregistered tool
            # name surfaces as KeyError rather than as a validation error.
            schema = self.schemas_by_name[call["name"]]
            try:
                if issubclass(schema, BaseModel):
                    output = schema.model_validate(call["args"])
                    content = output.model_dump_json()
                elif issubclass(schema, BaseModelV1):
                    output = schema.validate(call["args"])
                    content = output.json()
                else:
                    raise ValueError(
                        f"Unsupported schema type: {type(schema)}. Expected BaseModel or BaseModelV1."
                    )
                return ToolMessage(
                    content=content,
                    name=call["name"],
                    tool_call_id=cast(str, call["id"]),
                )
            except (ValidationError, ValidationErrorV1) as e:
                # Report the failure back as a ToolMessage flagged with
                # ``is_error`` so downstream routing can detect it.
                return ToolMessage(
                    content=self._format_error(e, call, schema),
                    name=call["name"],
                    tool_call_id=cast(str, call["id"]),
                    additional_kwargs={"is_error": True},
                )

        with get_executor_for_config(config) as executor:
            outputs = [*executor.map(run_one, message.tool_calls)]
        if output_type == "list":
            return outputs
        return {"messages": outputs}