Source code for langgraph.prebuilt.tool_node
import asyncio
import inspect
import json
from copy import copy, deepcopy
from dataclasses import replace
from typing import (
    Any,
    Callable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
    get_type_hints,
)
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    ToolCall,
    ToolMessage,
    convert_to_messages,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
    get_config_list,
    get_executor_for_config,
)
from langchain_core.tools import BaseTool, InjectedToolArg
from langchain_core.tools import tool as create_tool
from langchain_core.tools.base import get_all_basemodel_annotations
from pydantic import BaseModel
from typing_extensions import Annotated, get_args, get_origin
from langgraph.errors import GraphBubbleUp
from langgraph.store.base import BaseStore
from langgraph.types import Command, Send
from langgraph.utils.runnable import RunnableCallable
INVALID_TOOL_NAME_ERROR_TEMPLATE = (
"Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
)
TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
def msg_content_output(output: Any) -> Union[str, list[dict]]:
    recognized_content_block_types = ("image", "image_url", "text", "json")
    if isinstance(output, str):
        return output
    elif isinstance(output, list) and all(
        isinstance(x, dict) and x.get("type") in recognized_content_block_types
        for x in output
    ):
        return output
    # Technically a list of strings is also valid message content, but it's not
    # currently well tested that all chat models support this. And for backwards
    # compatibility we want to make sure we don't break any existing ToolNode usage.
    else:
        try:
            return json.dumps(output, ensure_ascii=False)
        except Exception:
            return str(output)
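# Example (editor's sketch, illustrative): strings and lists of recognized
# content blocks pass through unchanged; anything else is JSON-serialized,
# with str() as a last resort:
#
#     >>> msg_content_output("hi")
#     'hi'
#     >>> msg_content_output([{"type": "text", "text": "hi"}])
#     [{'type': 'text', 'text': 'hi'}]
#     >>> msg_content_output({"temperature": 21.5})
#     '{"temperature": 21.5}'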
def _handle_tool_error(
    e: Exception,
    *,
    flag: Union[
        bool,
        str,
        Callable[..., str],
        tuple[type[Exception], ...],
    ],
) -> str:
    if isinstance(flag, (bool, tuple)):
        content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
    elif isinstance(flag, str):
        content = flag
    elif callable(flag):
        content = flag(e)
    else:
        raise ValueError(
            f"Got unexpected type of `handle_tool_error`. Expected bool, str, "
            f"callable or tuple of Exception types. Received: {flag}"
        )
    return content
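# Example (editor's sketch, illustrative): the three handled flag shapes. A
# bool or tuple flag yields the default template, a str flag is returned
# verbatim, and a callable flag is invoked with the exception:
#
#     >>> _handle_tool_error(ValueError("bad"), flag=True)
#     "Error: ValueError('bad')\n Please fix your mistakes."
#     >>> _handle_tool_error(ValueError("bad"), flag="Tool failed, try again.")
#     'Tool failed, try again.'
#     >>> _handle_tool_error(ValueError("bad"), flag=lambda e: f"Caught: {e}")
#     'Caught: bad'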
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
    sig = inspect.signature(handler)
    params = list(sig.parameters.values())
    if params:
        # If it's a method, the first argument is typically 'self' or 'cls'
        if params[0].name in ["self", "cls"] and len(params) == 2:
            first_param = params[1]
        else:
            first_param = params[0]

        type_hints = get_type_hints(handler)
        if first_param.name in type_hints:
            origin = get_origin(first_param.annotation)
            if origin is Union:
                args = get_args(first_param.annotation)
                if all(issubclass(arg, Exception) for arg in args):
                    return tuple(args)
                else:
                    raise ValueError(
                        "All types in the error handler error annotation must be Exception types. "
                        "For example, `def custom_handler(e: Union[ValueError, TypeError])`. "
                        f"Got '{first_param.annotation}' instead."
                    )

            exception_type = type_hints[first_param.name]
            if Exception in exception_type.__mro__:
                return (exception_type,)
            else:
                raise ValueError(
                    "Arbitrary types are not supported in the error handler signature. "
                    "Please annotate the error with either a specific Exception type or a union of Exception types. "
                    "For example, `def custom_handler(e: ValueError)` or `def custom_handler(e: Union[ValueError, TypeError])`. "
                    f"Got '{exception_type}' instead."
                )

    # If no type information is available, return (Exception,) for backwards compatibility.
    return (Exception,)
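# Example (editor's sketch, illustrative): handled exception types are read
# from the annotation of the handler's first non-self/cls parameter:
#
#     >>> def on_value_error(e: ValueError) -> str:
#     ...     return "handled"
#     >>> _infer_handled_types(on_value_error)
#     (<class 'ValueError'>,)
#     >>> def on_anything(e) -> str:  # no annotation -> catch everything
#     ...     return "handled"
#     >>> _infer_handled_types(on_anything)
#     (<class 'Exception'>,)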
class ToolNode(RunnableCallable):
    """A node that runs the tools called in the last AIMessage.

    It can be used either in a StateGraph with a "messages" state key (or a custom
    key passed via ToolNode's 'messages_key') or in a graph whose state is a plain
    list of messages. If multiple tool calls are requested, they will be run in
    parallel. The output will be a list of ToolMessages, one for each tool call.
    Tool calls can also be passed directly as a list of `ToolCall` dicts.

    Args:
        tools: A sequence of tools that can be invoked by the ToolNode.
        name: The name of the ToolNode in the graph. Defaults to "tools".
        tags: Optional tags to associate with the node. Defaults to None.
        handle_tool_errors: How to handle tool errors raised by tools inside the node.
            Defaults to True. Must be one of the following:

            - True: all errors will be caught and a ToolMessage with a default
              error message (TOOL_CALL_ERROR_TEMPLATE) will be returned.
            - str: all errors will be caught and a ToolMessage with the string value
              of 'handle_tool_errors' will be returned.
            - tuple[type[Exception], ...]: exceptions in the tuple will be caught and
              a ToolMessage with a default error message (TOOL_CALL_ERROR_TEMPLATE)
              will be returned.
            - Callable[..., str]: exceptions whose types appear in the callable's
              signature will be caught and a ToolMessage with the string value of the
              result of the 'handle_tool_errors' callable will be returned.
            - False: none of the errors raised by the tools will be caught.
        messages_key: The state key in the input that contains the list of messages.
            The same key will be used for the output from the ToolNode.
            Defaults to "messages".

    The `ToolNode` is roughly analogous to:

    ```python
    tools_by_name = {tool.name: tool for tool in tools}


    def tool_node(state: dict):
        result = []
        for tool_call in state["messages"][-1].tool_calls:
            tool = tools_by_name[tool_call["name"]]
            observation = tool.invoke(tool_call["args"])
            result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
        return {"messages": result}
    ```

    Tool calls can also be passed directly to a ToolNode. This can be useful when
    using the Send API, e.g., in a conditional edge:

    ```python
    def example_conditional_edge(state: dict) -> list[Send]:
        last_message = state["messages"][-1]
        # If tools rely on state or store variables (whose values are not generated
        # directly by a model), you can inject them into the tool calls.
        tool_calls = [
            tool_node.inject_tool_args(call, state, store)
            for call in last_message.tool_calls
        ]
        return [Send("tools", [tool_call]) for tool_call in tool_calls]
    ```

    Important:
        - The input state can be one of the following:
            - A dict with a messages key containing a list of messages.
            - A list of messages.
            - A list of tool calls.
        - If operating on a message list, the last message must be an `AIMessage` with
          `tool_calls` populated.
    """

    name: str = "ToolNode"
    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        handle_tool_errors: Union[
            bool, str, Callable[..., str], tuple[type[Exception], ...]
        ] = True,
        messages_key: str = "messages",
    ) -> None:
        super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
        self.tools_by_name: dict[str, BaseTool] = {}
        self.tool_to_state_args: dict[str, dict[str, Optional[str]]] = {}
        self.tool_to_store_arg: dict[str, Optional[str]] = {}
        self.handle_tool_errors = handle_tool_errors
        self.messages_key = messages_key
        for tool_ in tools:
            if not isinstance(tool_, BaseTool):
                tool_ = create_tool(tool_)
            self.tools_by_name[tool_.name] = tool_
            self.tool_to_state_args[tool_.name] = _get_state_args(tool_)
            self.tool_to_store_arg[tool_.name] = _get_store_arg(tool_)
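    # Example (editor's sketch, illustrative): plain functions are coerced to
    # tools via `create_tool`, so both spellings below build the same node:
    #
    #     >>> def add(a: int, b: int) -> int:
    #     ...     """Add two numbers."""
    #     ...     return a + b
    #     >>> node = ToolNode([add])  # equivalent to ToolNode([create_tool(add)])
    #     >>> sorted(node.tools_by_name)
    #     ['add']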
    def _func(
        self,
        input: Union[
            list[AnyMessage],
            dict[str, Any],
            BaseModel,
        ],
        config: RunnableConfig,
        *,
        store: Optional[BaseStore],
    ) -> Any:
        tool_calls, input_type = self._parse_input(input, store)
        config_list = get_config_list(config, len(tool_calls))
        input_types = [input_type] * len(tool_calls)
        with get_executor_for_config(config) as executor:
            outputs = [
                *executor.map(self._run_one, tool_calls, input_types, config_list)
            ]
        return self._combine_tool_outputs(outputs, input_type)
    async def _afunc(
        self,
        input: Union[
            list[AnyMessage],
            dict[str, Any],
            BaseModel,
        ],
        config: RunnableConfig,
        *,
        store: Optional[BaseStore],
    ) -> Any:
        tool_calls, input_type = self._parse_input(input, store)
        outputs = await asyncio.gather(
            *(self._arun_one(call, input_type, config) for call in tool_calls)
        )
        return self._combine_tool_outputs(outputs, input_type)
    def _combine_tool_outputs(
        self,
        outputs: list[ToolMessage],
        input_type: Literal["list", "dict", "tool_calls"],
    ) -> list[Union[Command, list[ToolMessage], dict[str, list[ToolMessage]]]]:
        # preserve existing behavior for non-command tool outputs for backwards
        # compatibility
        if not any(isinstance(output, Command) for output in outputs):
            # TypedDict, pydantic, dataclass, etc. should all be able to load from dict
            return outputs if input_type == "list" else {self.messages_key: outputs}

        # LangGraph will automatically handle a list of Command and non-command node
        # updates
        combined_outputs: list[
            Command | list[ToolMessage] | dict[str, list[ToolMessage]]
        ] = []

        # combine all parent commands with goto into a single parent command
        parent_command: Optional[Command] = None
        for output in outputs:
            if isinstance(output, Command):
                if (
                    output.graph is Command.PARENT
                    and isinstance(output.goto, list)
                    and all(isinstance(send, Send) for send in output.goto)
                ):
                    if parent_command:
                        parent_command = replace(
                            parent_command,
                            goto=cast(list[Send], parent_command.goto) + output.goto,
                        )
                    else:
                        parent_command = Command(graph=Command.PARENT, goto=output.goto)
                else:
                    combined_outputs.append(output)
            else:
                combined_outputs.append(
                    [output] if input_type == "list" else {self.messages_key: [output]}
                )
        if parent_command:
            combined_outputs.append(parent_command)
        return combined_outputs
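    # Example (editor's sketch, shapes only): with dict input, each plain
    # ToolMessage becomes its own {"messages": [...]} update, other Commands
    # pass through, and all Command.PARENT goto-Sends are merged into a single
    # trailing parent Command:
    #
    #     [ToolMessage(a), Command(update=u), Command(graph=PARENT, goto=[s1]), Command(graph=PARENT, goto=[s2])]
    #     -> [{"messages": [ToolMessage(a)]}, Command(update=u), Command(graph=PARENT, goto=[s1, s2])]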
    def _run_one(
        self,
        call: ToolCall,
        input_type: Literal["list", "dict", "tool_calls"],
        config: RunnableConfig,
    ) -> ToolMessage:
        if invalid_tool_message := self._validate_tool_call(call):
            return invalid_tool_message

        try:
            input = {**call, **{"type": "tool_call"}}
            response = self.tools_by_name[call["name"]].invoke(input, config)

        # GraphInterrupt is a special exception that will always be raised.
        # It can be triggered in the following scenarios:
        # (1) a NodeInterrupt is raised inside a tool
        # (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
        # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
        except GraphBubbleUp as e:
            raise e
        except Exception as e:
            if isinstance(self.handle_tool_errors, tuple):
                handled_types: tuple = self.handle_tool_errors
            elif callable(self.handle_tool_errors):
                handled_types = _infer_handled_types(self.handle_tool_errors)
            else:
                # default behavior is catching all exceptions
                handled_types = (Exception,)

            # Unhandled
            if not self.handle_tool_errors or not isinstance(e, handled_types):
                raise e
            # Handled
            else:
                content = _handle_tool_error(e, flag=self.handle_tool_errors)

            return ToolMessage(
                content=content,
                name=call["name"],
                tool_call_id=call["id"],
                status="error",
            )

        if isinstance(response, Command):
            return self._validate_tool_command(response, call, input_type)
        elif isinstance(response, ToolMessage):
            response.content = cast(
                Union[str, list], msg_content_output(response.content)
            )
            return response
        else:
            raise TypeError(
                f"Tool {call['name']} returned unexpected type: {type(response)}"
            )
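    # Example (editor's sketch, illustrative): with the default
    # handle_tool_errors=True, a raising tool produces an error ToolMessage
    # instead of propagating the exception:
    #
    #     >>> @create_tool
    #     ... def boom(x: int) -> int:
    #     ...     """Always fails."""
    #     ...     raise ValueError("nope")
    #     >>> node = ToolNode([boom])
    #     >>> call = {"name": "boom", "args": {"x": 1}, "id": "1", "type": "tool_call"}
    #     >>> msg = node.invoke([AIMessage("", tool_calls=[call])])[0]
    #     >>> msg.status
    #     'error'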
    async def _arun_one(
        self,
        call: ToolCall,
        input_type: Literal["list", "dict", "tool_calls"],
        config: RunnableConfig,
    ) -> ToolMessage:
        if invalid_tool_message := self._validate_tool_call(call):
            return invalid_tool_message

        try:
            input = {**call, **{"type": "tool_call"}}
            response = await self.tools_by_name[call["name"]].ainvoke(input, config)

        # GraphInterrupt is a special exception that will always be raised.
        # It can be triggered in the following scenarios:
        # (1) a NodeInterrupt is raised inside a tool
        # (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
        # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
        except GraphBubbleUp as e:
            raise e
        except Exception as e:
            if isinstance(self.handle_tool_errors, tuple):
                handled_types: tuple = self.handle_tool_errors
            elif callable(self.handle_tool_errors):
                handled_types = _infer_handled_types(self.handle_tool_errors)
            else:
                # default behavior is catching all exceptions
                handled_types = (Exception,)

            # Unhandled
            if not self.handle_tool_errors or not isinstance(e, handled_types):
                raise e
            # Handled
            else:
                content = _handle_tool_error(e, flag=self.handle_tool_errors)

            return ToolMessage(
                content=content,
                name=call["name"],
                tool_call_id=call["id"],
                status="error",
            )

        if isinstance(response, Command):
            return self._validate_tool_command(response, call, input_type)
        elif isinstance(response, ToolMessage):
            response.content = cast(
                Union[str, list], msg_content_output(response.content)
            )
            return response
        else:
            raise TypeError(
                f"Tool {call['name']} returned unexpected type: {type(response)}"
            )
    def _parse_input(
        self,
        input: Union[
            list[AnyMessage],
            dict[str, Any],
            BaseModel,
        ],
        store: Optional[BaseStore],
    ) -> Tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
        if isinstance(input, list):
            if isinstance(input[-1], dict) and input[-1].get("type") == "tool_call":
                input_type = "tool_calls"
                tool_calls = input
                return tool_calls, input_type
            else:
                input_type = "list"
                message: AnyMessage = input[-1]
        elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])):
            input_type = "dict"
            message = messages[-1]
        elif messages := getattr(input, self.messages_key, None):
            # Assume dataclass-like state that can coerce from dict
            input_type = "dict"
            message = messages[-1]
        else:
            raise ValueError("No message found in input")

        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        tool_calls = [
            self.inject_tool_args(call, input, store) for call in message.tool_calls
        ]
        return tool_calls, input_type
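    # Example (editor's sketch, illustrative): the three accepted input shapes
    # and the input_type each parses to:
    #
    #     {"messages": [AIMessage(..., tool_calls=[...])]}              -> "dict"
    #     [AIMessage(..., tool_calls=[...])]                            -> "list"
    #     [{"name": ..., "args": ..., "id": ..., "type": "tool_call"}]  -> "tool_calls"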
    def _validate_tool_call(self, call: ToolCall) -> Optional[ToolMessage]:
        if (requested_tool := call["name"]) not in self.tools_by_name:
            content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
                requested_tool=requested_tool,
                available_tools=", ".join(self.tools_by_name.keys()),
            )
            return ToolMessage(
                content, name=requested_tool, tool_call_id=call["id"], status="error"
            )
        else:
            return None
    def _inject_state(
        self,
        tool_call: ToolCall,
        input: Union[
            list[AnyMessage],
            dict[str, Any],
            BaseModel,
        ],
    ) -> ToolCall:
        state_args = self.tool_to_state_args[tool_call["name"]]
        if state_args and isinstance(input, list):
            required_fields = list(state_args.values())
            if (
                len(required_fields) == 1
                and required_fields[0] == self.messages_key
            ) or required_fields[0] is None:
                input = {self.messages_key: input}
            else:
                err_msg = (
                    f"Invalid input to ToolNode. Tool {tool_call['name']} requires "
                    f"graph state dict as input."
                )
                if any(state_field for state_field in state_args.values()):
                    required_fields_str = ", ".join(f for f in required_fields if f)
                    err_msg += f" State should contain fields {required_fields_str}."
                raise ValueError(err_msg)

        if isinstance(input, dict):
            tool_state_args = {
                tool_arg: input[state_field] if state_field else input
                for tool_arg, state_field in state_args.items()
            }
        else:
            tool_state_args = {
                tool_arg: getattr(input, state_field) if state_field else input
                for tool_arg, state_field in state_args.items()
            }

        tool_call["args"] = {
            **tool_call["args"],
            **tool_state_args,
        }
        return tool_call
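    # Example (editor's sketch, illustrative): for a tool arg annotated
    # `state: Annotated[dict, InjectedState]`, state_args is {"state": None} and
    # the entire input is injected; for `foo: Annotated[str, InjectedState("foo")]`,
    # state_args is {"foo": "foo"} and only the state's "foo" field is injected.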
    def _inject_store(self, tool_call: ToolCall, store: Optional[BaseStore]) -> ToolCall:
        store_arg = self.tool_to_store_arg[tool_call["name"]]
        if not store_arg:
            return tool_call
        if store is None:
            raise ValueError(
                "Cannot inject store into tools with InjectedStore annotations - "
                "please compile your graph with a store."
            )
        tool_call["args"] = {
            **tool_call["args"],
            store_arg: store,
        }
        return tool_call
    def inject_tool_args(
        self,
        tool_call: ToolCall,
        input: Union[
            list[AnyMessage],
            dict[str, Any],
            BaseModel,
        ],
        store: Optional[BaseStore],
    ) -> ToolCall:
        """Injects the state and store into the tool call.

        Tool arguments with types annotated as `InjectedState` and `InjectedStore` are
        ignored in tool schemas for generation purposes. This method injects them into
        tool calls for tool invocation.

        Args:
            tool_call (ToolCall): The tool call to inject state and store into.
            input (Union[list[AnyMessage], dict[str, Any], BaseModel]): The input state
                to inject.
            store (Optional[BaseStore]): The store to inject.

        Returns:
            ToolCall: The tool call with injected state and store.
        """
        if tool_call["name"] not in self.tools_by_name:
            return tool_call

        tool_call_copy: ToolCall = copy(tool_call)
        tool_call_with_state = self._inject_state(tool_call_copy, input)
        tool_call_with_store = self._inject_store(tool_call_with_state, store)
        return tool_call_with_store
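    # Example (editor's sketch; `node` and `state_tool` refer to the
    # InjectedState docstring example later in this module): injection returns
    # a copy, leaving the original tool call's args untouched:
    #
    #     >>> call = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
    #     >>> injected = node.inject_tool_args(call, {"messages": [], "foo": "bar"}, None)
    #     >>> call["args"]
    #     {'x': 1}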
    def _validate_tool_command(
        self,
        command: Command,
        call: ToolCall,
        input_type: Literal["list", "dict", "tool_calls"],
    ) -> Command:
        if isinstance(command.update, dict):
            # input type is dict when ToolNode is invoked with a dict input
            # (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
            if input_type not in ("dict", "tool_calls"):
                raise ValueError(
                    f"Tools can provide a dict in Command.update only when using dict with '{self.messages_key}' key as ToolNode input, "
                    f"got: {command.update} for tool '{call['name']}'"
                )
            updated_command = deepcopy(command)
            state_update = cast(dict[str, Any], updated_command.update) or {}
            messages_update = state_update.get(self.messages_key, [])
        elif isinstance(command.update, list):
            # input type is list when ToolNode is invoked with a list input
            # (e.g. [AIMessage(..., tool_calls=[...])])
            if input_type != "list":
                raise ValueError(
                    f"Tools can provide a list of messages in Command.update only when using list of messages as ToolNode input, "
                    f"got: {command.update} for tool '{call['name']}'"
                )
            updated_command = deepcopy(command)
            messages_update = updated_command.update
        else:
            return command

        # convert to message objects if updates are in a dict format
        messages_update = convert_to_messages(messages_update)
        has_matching_tool_message = False
        for message in messages_update:
            if not isinstance(message, ToolMessage):
                continue
            if message.tool_call_id == call["id"]:
                message.name = call["name"]
                has_matching_tool_message = True

        # validate that we always have a ToolMessage matching the tool call in
        # Command.update if the command is sent to the CURRENT graph
        if updated_command.graph is None and not has_matching_tool_message:
            example_update = (
                '`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
                if input_type == "dict"
                else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
            )
            raise ValueError(
                f"Expected to have a matching ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}. "
                "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. "
                f"You can fix it by modifying the tool to return {example_update}."
            )
        return updated_command
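    # Example (editor's sketch, illustrative): a state-updating tool must include
    # a ToolMessage matching its own tool_call_id in Command.update, e.g. for
    # dict input:
    #
    #     return Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id)]})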
def tools_condition(
    state: Union[list[AnyMessage], dict[str, Any], BaseModel],
    messages_key: str = "messages",
) -> Literal["tools", "__end__"]:
    """Use in the conditional_edge to route to the ToolNode if the last message
    has tool calls. Otherwise, route to the end.

    Args:
        state (Union[list[AnyMessage], dict[str, Any], BaseModel]): The state to
            check for tool calls. Must have a list of messages (MessageGraph) or
            have the "messages" key (StateGraph).
        messages_key (str): The state key that contains the list of messages.
            Defaults to "messages".

    Returns:
        The next node to route to.

    Examples:
        Create a custom ReAct-style agent with tools.

        ```pycon
        >>> from langchain_anthropic import ChatAnthropic
        >>> from langchain_core.tools import tool
        ...
        >>> from langgraph.graph import StateGraph
        >>> from langgraph.prebuilt import ToolNode, tools_condition
        >>> from langgraph.graph.message import add_messages
        ...
        >>> from typing import Annotated
        >>> from typing_extensions import TypedDict
        ...
        >>> @tool
        ... def divide(a: float, b: float) -> float:
        ...     \"\"\"Return a / b.\"\"\"
        ...     return a / b
        ...
        >>> llm = ChatAnthropic(model="claude-3-haiku-20240307")
        >>> tools = [divide]
        ...
        >>> class State(TypedDict):
        ...     messages: Annotated[list, add_messages]
        ...
        >>> graph_builder = StateGraph(State)
        >>> graph_builder.add_node("tools", ToolNode(tools))
        >>> graph_builder.add_node(
        ...     "chatbot",
        ...     lambda state: {"messages": llm.bind_tools(tools).invoke(state["messages"])},
        ... )
        >>> graph_builder.add_edge("tools", "chatbot")
        >>> graph_builder.add_conditional_edges(
        ...     "chatbot", tools_condition
        ... )
        >>> graph_builder.set_entry_point("chatbot")
        >>> graph = graph_builder.compile()
        >>> graph.invoke({"messages": [{"role": "user", "content": "What's 329993 divided by 13662?"}]})
        ```
    """
    if isinstance(state, list):
        ai_message = state[-1]
    elif isinstance(state, dict) and (messages := state.get(messages_key, [])):
        ai_message = messages[-1]
    elif messages := getattr(state, messages_key, []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return "__end__"
class InjectedState(InjectedToolArg):
    """Annotation for a Tool arg that is meant to be populated with the graph state.

    Any Tool argument annotated with InjectedState will be hidden from a tool-calling
    model, so that the model doesn't attempt to generate the argument. If using
    ToolNode, the appropriate graph state field will be automatically injected into
    the model-generated tool args.

    Args:
        field: The key from state to insert. If None, the entire state is expected to
            be passed in.

    Example:
        ```python
        from typing import List
        from typing_extensions import Annotated, TypedDict

        from langchain_core.messages import BaseMessage, AIMessage
        from langchain_core.tools import tool

        from langgraph.prebuilt import InjectedState, ToolNode


        class AgentState(TypedDict):
            messages: List[BaseMessage]
            foo: str


        @tool
        def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
            '''Do something with state.'''
            if len(state["messages"]) > 2:
                return state["foo"] + str(x)
            else:
                return "not enough messages"


        @tool
        def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
            '''Do something else with state.'''
            return foo + str(x + 1)


        node = ToolNode([state_tool, foo_tool])

        tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
        tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"}
        state = {
            "messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])],
            "foo": "bar",
        }

        node.invoke(state)
        ```

        ```pycon
        {
            "messages": [
                ToolMessage(content='not enough messages', name='state_tool', tool_call_id='1'),
                ToolMessage(content='bar2', name='foo_tool', tool_call_id='2'),
            ]
        }
        ```
    """  # noqa: E501

    def __init__(self, field: Optional[str] = None) -> None:
        self.field = field
class InjectedStore(InjectedToolArg):
    """Annotation for a Tool arg that is meant to be populated with the LangGraph store.

    Any Tool argument annotated with InjectedStore will be hidden from a tool-calling
    model, so that the model doesn't attempt to generate the argument. If using
    ToolNode, the appropriate store field will be automatically injected into
    the model-generated tool args. Note: if a graph is compiled with a store object,
    the store will be automatically propagated to the tools with InjectedStore args
    when using ToolNode.

    !!! Warning
        `InjectedStore` annotation requires `langchain-core >= 0.3.8`

    Example:
        ```python
        from typing import Any
        from typing_extensions import Annotated

        from langchain_core.messages import AIMessage
        from langchain_core.tools import tool

        from langgraph.store.memory import InMemoryStore
        from langgraph.prebuilt import InjectedStore, ToolNode

        store = InMemoryStore()
        store.put(("values",), "foo", {"bar": 2})


        @tool
        def store_tool(x: int, my_store: Annotated[Any, InjectedStore()]) -> str:
            '''Do something with store.'''
            stored_value = my_store.get(("values",), "foo").value["bar"]
            return stored_value + x


        node = ToolNode([store_tool])

        tool_call = {"name": "store_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
        state = {
            "messages": [AIMessage("", tool_calls=[tool_call])],
        }

        node.invoke(state, store=store)
        ```

        ```pycon
        {
            "messages": [
                ToolMessage(content='3', name='store_tool', tool_call_id='1'),
            ]
        }
        ```
    """  # noqa: E501
def _is_injection(
    type_arg: Any, injection_type: Union[Type[InjectedState], Type[InjectedStore]]
) -> bool:
    if isinstance(type_arg, injection_type) or (
        isinstance(type_arg, type) and issubclass(type_arg, injection_type)
    ):
        return True
    origin_ = get_origin(type_arg)
    if origin_ is Union or origin_ is Annotated:
        return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))
    return False
def _get_state_args(tool: BaseTool) -> dict[str, Optional[str]]:
    full_schema = tool.get_input_schema()
    tool_args_to_state_fields: dict = {}
    for name, type_ in get_all_basemodel_annotations(full_schema).items():
        injections = [
            type_arg
            for type_arg in get_args(type_)
            if _is_injection(type_arg, InjectedState)
        ]
        if len(injections) > 1:
            raise ValueError(
                "A tool argument should not be annotated with InjectedState more than "
                f"once. Received arg {name} with annotations {injections}."
            )
        elif len(injections) == 1:
            injection = injections[0]
            if isinstance(injection, InjectedState) and injection.field:
                tool_args_to_state_fields[name] = injection.field
            else:
                tool_args_to_state_fields[name] = None
    return tool_args_to_state_fields
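# Example (editor's sketch; `state_tool` and `foo_tool` refer to the
# InjectedState docstring example above): mapping from tool arg to state field,
# with None meaning "inject the whole state":
#
#     >>> _get_state_args(state_tool)
#     {'state': None}
#     >>> _get_state_args(foo_tool)
#     {'foo': 'foo'}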
def _get_store_arg(tool: BaseTool) -> Optional[str]:
    full_schema = tool.get_input_schema()
    for name, type_ in get_all_basemodel_annotations(full_schema).items():
        injections = [
            type_arg
            for type_arg in get_args(type_)
            if _is_injection(type_arg, InjectedStore)
        ]
        if len(injections) > 1:
            raise ValueError(
                "A tool argument should not be annotated with InjectedStore more than "
                f"once. Received arg {name} with annotations {injections}."
            )
        elif len(injections) == 1:
            return name
    return None
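# Example (editor's sketch; `store_tool` and `state_tool` refer to the
# docstring examples above): returns the name of the InjectedStore-annotated
# argument, or None when the tool takes no store:
#
#     >>> _get_store_arg(store_tool)
#     'my_store'
#     >>> _get_store_arg(state_tool) is None
#     True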