Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test caching #5

Merged
merged 1 commit on Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ jobs:
TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }}
LANGSMITH_API_KEY: ${{ secrets.LANGSMITH_API_KEY }}
LANGSMITH_TRACING: true
LANGSMITH_TEST_CACHE: tests/cassettes
run: |
uv run pytest tests/integration_tests
39 changes: 17 additions & 22 deletions src/react_agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import Dict, List, Literal, cast

from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
Expand Down Expand Up @@ -36,25 +35,21 @@ async def call_model(
"""
configuration = Configuration.from_runnable_config(config)

# Create a prompt template. Customize this to change the agent's behavior.
prompt = ChatPromptTemplate.from_messages(
[("system", configuration.system_prompt), ("placeholder", "{messages}")]
)

# Initialize the model with tool binding. Change the model or add more tools here.
model = load_chat_model(configuration.model).bind_tools(TOOLS)

# Prepare the input for the model, including the current system time
message_value = await prompt.ainvoke(
{
"messages": state.messages,
"system_time": datetime.now(tz=timezone.utc).isoformat(),
},
config,
# Format the system prompt. Customize this to change the agent's behavior.
system_message = configuration.system_prompt.format(
system_time=datetime.now(tz=timezone.utc).isoformat()
)

# Get the model's response
response = cast(AIMessage, await model.ainvoke(message_value, config))
response = cast(
AIMessage,
await model.ainvoke(
[{"role": "system", "content": system_message}, *state.messages], config
),
)

# Handle the case when it's the last step and the model still wants to use a tool
if state.is_last_step and response.tool_calls:
Expand All @@ -73,15 +68,15 @@ async def call_model(

# Define a new graph

workflow = StateGraph(State, input=InputState, config_schema=Configuration)
builder = StateGraph(State, input=InputState, config_schema=Configuration)

# Define the two nodes we will cycle between
workflow.add_node(call_model)
workflow.add_node("tools", ToolNode(TOOLS))
builder.add_node(call_model)
builder.add_node("tools", ToolNode(TOOLS))

# Set the entrypoint as `call_model`
# This means that this node is the first one called
workflow.add_edge("__start__", "call_model")
builder.add_edge("__start__", "call_model")


def route_model_output(state: State) -> Literal["__end__", "tools"]:
Expand All @@ -108,7 +103,7 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:


# Add a conditional edge to determine the next step after `call_model`
workflow.add_conditional_edges(
builder.add_conditional_edges(
"call_model",
# After call_model finishes running, the next node(s) are scheduled
# based on the output from route_model_output
Expand All @@ -117,11 +112,11 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:

# Add a normal edge from `tools` to `call_model`
# This creates a cycle: after using tools, we always return to the model
workflow.add_edge("tools", "call_model")
builder.add_edge("tools", "call_model")

# Compile the workflow into an executable graph
# Compile the builder into an executable graph
# You can customize this by adding interrupt points for state updates
graph = workflow.compile(
graph = builder.compile(
interrupt_before=[], # Add node names here to update state before they're called
interrupt_after=[], # Add node names here to update state after they're called
)
Expand Down
Loading
Loading