
Commit

langgraph: add 'messages_key' param to ToolNode / tools_condition (#2049)
vbarda authored Oct 22, 2024
1 parent 0042889 commit 916affa
Showing 2 changed files with 63 additions and 13 deletions.
23 changes: 13 additions & 10 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
@@ -70,8 +70,8 @@ def msg_content_output(output: Any) -> str | List[dict]:
class ToolNode(RunnableCallable):
"""A node that runs the tools called in the last AIMessage.
It can be used either in StateGraph with a "messages" key or in MessageGraph. If
multiple tool calls are requested, they will be run in parallel. The output will be
It can be used either in StateGraph with a "messages" key (or a custom key passed via ToolNode's 'messages_key').
If multiple tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
The `ToolNode` is roughly analogous to:
@@ -102,12 +102,14 @@ def __init__(
name: str = "tools",
tags: Optional[list[str]] = None,
handle_tool_errors: Optional[bool] = True,
messages_key: str = "messages",
) -> None:
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
self.tools_by_name: Dict[str, BaseTool] = {}
self.tool_to_state_args: Dict[str, Dict[str, Optional[str]]] = {}
self.tool_to_store_arg: Dict[str, Optional[str]] = {}
self.handle_tool_errors = handle_tool_errors
self.messages_key = messages_key
for tool_ in tools:
if not isinstance(tool_, BaseTool):
tool_ = cast(BaseTool, create_tool(tool_))
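
For context, a minimal sketch of how the new messages_key parameter can be exercised end to end; the multiply tool and the "my_messages" key are illustrative assumptions, not part of this change:

from langchain_core.messages import AIMessage
from langchain_core.tools import tool

from langgraph.prebuilt import ToolNode


@tool
def multiply(a: int, b: int) -> int:
    """Multiply a and b."""
    return a * b


# ToolNode reads the last AIMessage from the "my_messages" state key and writes
# its ToolMessages back under that same key instead of the default "messages".
tool_node = ToolNode([multiply], messages_key="my_messages")

result = tool_node.invoke(
    {
        "my_messages": [
            AIMessage(
                content="",
                tool_calls=[{"name": "multiply", "args": {"a": 2, "b": 3}, "id": "call_1"}],
            )
        ]
    }
)
# result == {"my_messages": [ToolMessage(content="6", name="multiply", tool_call_id="call_1")]}
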
@@ -131,7 +133,7 @@ def _func(
with get_executor_for_config(config) as executor:
outputs = [*executor.map(self._run_one, tool_calls, config_list)]
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {"messages": outputs}
return outputs if output_type == "list" else {self.messages_key: outputs}

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
@@ -163,7 +165,7 @@ async def _afunc(
*(self._arun_one(call, config) for call in tool_calls)
)
# TypedDict, pydantic, dataclass, etc. should all be able to load from dict
return outputs if output_type == "list" else {"messages": outputs}
return outputs if output_type == "list" else {self.messages_key: outputs}

def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
if invalid_tool_message := self._validate_tool_call(call):
@@ -214,10 +216,10 @@ def _parse_input(
if isinstance(input, list):
output_type = "list"
message: AnyMessage = input[-1]
elif isinstance(input, dict) and (messages := input.get("messages", [])):
elif isinstance(input, dict) and (messages := input.get(self.messages_key, [])):
output_type = "dict"
message = messages[-1]
elif messages := getattr(input, "messages", None):
elif messages := getattr(input, self.messages_key, None):
# Assume dataclass-like state that can coerce from dict
output_type = "dict"
message = messages[-1]
@@ -256,10 +258,10 @@ def _inject_state(
required_fields = list(state_args.values())
if (
len(required_fields) == 1
and required_fields[0] == "messages"
and required_fields[0] == self.messages_key
or required_fields[0] is None
):
input = {"messages": input}
input = {self.messages_key: input}
else:
err_msg = (
f"Invalid input to ToolNode. Tool {tool_call['name']} requires "
@@ -325,6 +327,7 @@ def _inject_tool_args(

def tools_condition(
state: Union[list[AnyMessage], dict[str, Any], BaseModel],
messages_key: str = "messages",
) -> Literal["tools", "__end__"]:
"""Use in the conditional_edge to route to the ToolNode if the last message
@@ -377,9 +380,9 @@ def tools_condition(
"""
if isinstance(state, list):
ai_message = state[-1]
elif isinstance(state, dict) and (messages := state.get("messages", [])):
elif isinstance(state, dict) and (messages := state.get(messages_key, [])):
ai_message = messages[-1]
elif messages := getattr(state, "messages", []):
elif messages := getattr(state, messages_key, []):
ai_message = messages[-1]
else:
raise ValueError(f"No messages found in input state to tool_edge: {state}")
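
For context, a minimal sketch of the routing behavior the new messages_key argument enables; the "subgraph_messages" key and the tool call below are illustrative, and the graph wiring mirrors the test added in this commit:

from langchain_core.messages import AIMessage

from langgraph.prebuilt import tools_condition

# The last message carries tool calls, so the edge routes to "tools".
state = {
    "subgraph_messages": [
        AIMessage(
            content="",
            tool_calls=[{"name": "multiply", "args": {"a": 2, "b": 3}, "id": "call_1"}],
        )
    ]
}
assert tools_condition(state, messages_key="subgraph_messages") == "tools"

# Once the last AIMessage has no tool calls, the edge routes to "__end__".
state = {"subgraph_messages": [AIMessage(content="done")]}
assert tools_condition(state, messages_key="subgraph_messages") == "__end__"

# Inside a graph the key can be bound up front with functools.partial, e.g.
# builder.add_conditional_edges(
#     "agent", partial(tools_condition, messages_key="subgraph_messages")
# )
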
53 changes: 50 additions & 3 deletions libs/langgraph/tests/test_prebuilt.py
@@ -1,5 +1,6 @@
import dataclasses
import json
from functools import partial
from typing import (
Annotated,
Any,
@@ -35,8 +36,13 @@
from typing_extensions import TypedDict

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, ValidationNode, create_react_agent
from langgraph.graph import START, MessagesState, StateGraph, add_messages
from langgraph.prebuilt import (
ToolNode,
ValidationNode,
create_react_agent,
tools_condition,
)
from langgraph.prebuilt.tool_node import InjectedState, InjectedStore
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
@@ -46,7 +52,7 @@
IS_LANGCHAIN_CORE_030_OR_GREATER,
awith_checkpointer,
)
from tests.messages import _AnyIdHumanMessage
from tests.messages import _AnyIdHumanMessage, _AnyIdToolMessage

pytestmark = pytest.mark.anyio

@@ -826,6 +832,47 @@ def get_day_list(days: list[str]) -> list[str]:
assert outputs[0].content == json.dumps(data, ensure_ascii=False)


def test_tool_node_messages_key() -> None:
@dec_tool
def add(a: int, b: int):
"""Adds a and b."""
return a + b

model = FakeToolCallingModel(
tool_calls=[[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")]]
)

class State(TypedDict):
subgraph_messages: Annotated[list[AnyMessage], add_messages]

def call_model(state: State):
response = model.invoke(state["subgraph_messages"])
model.tool_calls = []
return {"subgraph_messages": response}

builder = StateGraph(State)
builder.add_node("agent", call_model)
builder.add_node("tools", ToolNode([add], messages_key="subgraph_messages"))
builder.add_conditional_edges(
"agent", partial(tools_condition, messages_key="subgraph_messages")
)
builder.add_edge(START, "agent")
builder.add_edge("tools", "agent")

graph = builder.compile()
result = graph.invoke({"subgraph_messages": [HumanMessage(content="hi")]})
assert result["subgraph_messages"] == [
_AnyIdHumanMessage(content="hi"),
AIMessage(
content="hi",
id="0",
tool_calls=[ToolCall(name=add.name, args={"a": 1, "b": 2}, id="test_id")],
),
_AnyIdToolMessage(content="3", name=add.name, tool_call_id="test_id"),
AIMessage(content="hi-hi-3", id="1"),
]


async def test_return_direct() -> None:
@dec_tool(return_direct=True)
def tool_return_direct(input: str) -> str:
