[Checkpointers] MemorySaver: refrain from overwriting writes #2399

Merged · 4 commits · Nov 13, 2024
4 changes: 4 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
@@ -1,4 +1,4 @@
 import asyncio

Check notice on line 1 in libs/checkpoint/langgraph/checkpoint/memory/__init__.py (GitHub Actions / benchmark)

Benchmark results:

fanout_to_subgraph_10x: Mean +- std dev: 57.6 ms +- 1.3 ms
fanout_to_subgraph_10x_sync: Mean +- std dev: 49.4 ms +- 0.5 ms
fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 83.5 ms +- 1.5 ms
fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 91.5 ms +- 1.1 ms
fanout_to_subgraph_100x: Mean +- std dev: 537 ms +- 11 ms
fanout_to_subgraph_100x_sync: Mean +- std dev: 481 ms +- 5 ms
fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 884 ms +- 40 ms
fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 902 ms +- 18 ms
react_agent_10x: Mean +- std dev: 30.3 ms +- 0.6 ms
react_agent_10x_sync: Mean +- std dev: 22.2 ms +- 0.3 ms
react_agent_10x_checkpoint: Mean +- std dev: 46.5 ms +- 0.8 ms
react_agent_10x_checkpoint_sync: Mean +- std dev: 36.5 ms +- 0.5 ms
react_agent_100x: Mean +- std dev: 340 ms +- 6 ms
react_agent_100x_sync: Mean +- std dev: 272 ms +- 3 ms
react_agent_100x_checkpoint: Mean +- std dev: 937 ms +- 15 ms
react_agent_100x_checkpoint_sync: Mean +- std dev: 836 ms +- 15 ms
wide_state_25x300: Mean +- std dev: 24.0 ms +- 0.4 ms
wide_state_25x300_sync: Mean +- std dev: 15.1 ms +- 0.2 ms
wide_state_25x300_checkpoint: Mean +- std dev: 276 ms +- 3 ms
wide_state_25x300_checkpoint_sync: Mean +- std dev: 266 ms +- 6 ms
wide_state_15x600: Mean +- std dev: 27.9 ms +- 0.4 ms
wide_state_15x600_sync: Mean +- std dev: 17.5 ms +- 0.2 ms
wide_state_15x600_checkpoint: Mean +- std dev: 477 ms +- 5 ms
wide_state_15x600_checkpoint_sync: Mean +- std dev: 464 ms +- 7 ms
wide_state_9x1200: Mean +- std dev: 27.7 ms +- 0.5 ms
wide_state_9x1200_sync: Mean +- std dev: 17.4 ms +- 0.5 ms
wide_state_9x1200_checkpoint: Mean +- std dev: 311 ms +- 5 ms
wide_state_9x1200_checkpoint_sync: Mean +- std dev: 297 ms +- 4 ms

Check notice on line 1 in libs/checkpoint/langgraph/checkpoint/memory/__init__.py (GitHub Actions / benchmark)

Comparison against main:

| Benchmark                          | main    | changes               |
|------------------------------------|---------|-----------------------|
| react_agent_100x_checkpoint_sync   | 853 ms  | 836 ms: 1.02x faster  |
| fanout_to_subgraph_100x            | 543 ms  | 537 ms: 1.01x faster  |
| fanout_to_subgraph_100x_checkpoint | 894 ms  | 884 ms: 1.01x faster  |
| react_agent_100x_checkpoint        | 947 ms  | 937 ms: 1.01x faster  |
| wide_state_25x300_checkpoint       | 279 ms  | 276 ms: 1.01x faster  |
| wide_state_15x600_checkpoint       | 480 ms  | 477 ms: 1.01x faster  |
| wide_state_9x1200_checkpoint_sync  | 299 ms  | 297 ms: 1.01x faster  |
| wide_state_15x600_checkpoint_sync  | 467 ms  | 464 ms: 1.01x faster  |
| fanout_to_subgraph_100x_sync       | 482 ms  | 481 ms: 1.00x faster  |
| wide_state_15x600_sync             | 17.4 ms | 17.5 ms: 1.00x slower |
| react_agent_100x_sync              | 270 ms  | 272 ms: 1.01x slower  |
| react_agent_100x                   | 337 ms  | 340 ms: 1.01x slower  |
| wide_state_15x600                  | 27.6 ms | 27.9 ms: 1.01x slower |
| wide_state_25x300                  | 23.7 ms | 24.0 ms: 1.01x slower |
| react_agent_10x_checkpoint_sync    | 36.0 ms | 36.5 ms: 1.01x slower |
| Geometric mean                     | (ref)   | 1.00x faster          |

Benchmark hidden because not significant (13): fanout_to_subgraph_10x_checkpoint, fanout_to_subgraph_10x, fanout_to_subgraph_100x_checkpoint_sync, wide_state_9x1200_checkpoint, fanout_to_subgraph_10x_sync, react_agent_10x, fanout_to_subgraph_10x_checkpoint_sync, react_agent_10x_sync, wide_state_25x300_checkpoint_sync, react_agent_10x_checkpoint, wide_state_25x300_sync, wide_state_9x1200_sync, wide_state_9x1200

 import random
 from collections import defaultdict
 from contextlib import AbstractAsyncContextManager, AbstractContextManager

@@ -364,8 +364,12 @@
         checkpoint_ns = config["configurable"]["checkpoint_ns"]
         checkpoint_id = config["configurable"]["checkpoint_id"]
         outer_key = (thread_id, checkpoint_ns, checkpoint_id)
+        outer_writes_ = self.writes.get(outer_key)
         for idx, (c, v) in enumerate(writes):
             inner_key = (task_id, WRITES_IDX_MAP.get(c, idx))
+            if inner_key[1] >= 0 and outer_writes_ and inner_key in outer_writes_:
+                continue
+
             self.writes[outer_key][inner_key] = (task_id, c, self.serde.dumps_typed(v))

     async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
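
The guard above is the whole fix: regular channel writes get a non-negative index, so once a (task_id, index) entry exists it is kept rather than overwritten when the same task runs again after resuming from an older checkpoint. Special writes keep a negative index and stay replaceable. Below is a minimal, self-contained sketch of that idea, not the library code itself; the class name InMemoryWrites and the concrete WRITES_IDX_MAP entries are placeholders for illustration.

from collections import defaultdict

# Stand-in for langgraph's WRITES_IDX_MAP; the real map assigns negative
# indices to special channels (errors, interrupts, ...). Values here are assumed.
WRITES_IDX_MAP = {"__error__": -1, "__interrupt__": -2}


class InMemoryWrites:
    """Hypothetical, stripped-down model of MemorySaver's writes store."""

    def __init__(self) -> None:
        # (thread_id, checkpoint_ns, checkpoint_id) -> {(task_id, idx): (task_id, channel, value)}
        self.writes: defaultdict = defaultdict(dict)

    def put_writes(self, outer_key: tuple, writes: list, task_id: str) -> None:
        existing = self.writes.get(outer_key)
        for idx, (channel, value) in enumerate(writes):
            inner_key = (task_id, WRITES_IDX_MAP.get(channel, idx))
            # Regular writes (non-negative index) are kept once recorded, so a
            # re-run after resuming from an older checkpoint cannot clobber them.
            if inner_key[1] >= 0 and existing and inner_key in existing:
                continue
            self.writes[outer_key][inner_key] = (task_id, channel, value)


store = InMemoryWrites()
key = ("thread-1", "", "checkpoint-1")
store.put_writes(key, [("myval", 2)], task_id="task-a")
store.put_writes(key, [("myval", 99)], task_id="task-a")  # ignored: write already exists
assert store.writes[key][("task-a", 0)] == ("task-a", "myval", 2)
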
2 changes: 1 addition & 1 deletion libs/checkpoint/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langgraph-checkpoint"
-version = "2.0.2"
+version = "2.0.3"
 description = "Library with base interfaces for LangGraph checkpoint savers."
 authors = []
 license = "MIT"
2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langgraph"
-version = "0.2.45"
+version = "0.2.46"
 description = "Building stateful, multi-actor applications with LLMs"
 authors = []
 license = "MIT"
80 changes: 80 additions & 0 deletions libs/langgraph/tests/test_pregel.py
@@ -1115,6 +1115,86 @@ def test_fork_always_re_runs_nodes(
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_run_from_checkpoint_id_retains_previous_writes(
    request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
    checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

    class MyState(TypedDict):
        myval: Annotated[int, operator.add]
        otherval: bool

    class Anode:
        def __init__(self):
            self.switch = False

        def __call__(self, state: MyState):
            self.switch = not self.switch
            return {"myval": 2 if self.switch else 1, "otherval": self.switch}

    builder = StateGraph(MyState)
    thenode = Anode()  # Fun.
    builder.add_node("node_one", thenode)
    builder.add_node("node_two", thenode)
    builder.add_edge(START, "node_one")

    def _getedge(src: str):
        swap = "node_one" if src == "node_two" else "node_two"

        def _edge(st: MyState) -> Literal["__end__", "node_one", "node_two"]:
            if st["myval"] > 3:
                return END
            if st["otherval"]:
                return swap
            return src

        return _edge

    builder.add_conditional_edges("node_one", _getedge("node_one"))
    builder.add_conditional_edges("node_two", _getedge("node_two"))
    graph = builder.compile(checkpointer=checkpointer)

    thread_id = uuid.uuid4()
    thread1 = {"configurable": {"thread_id": str(thread_id)}}

    result = graph.invoke({"myval": 1}, thread1)
    assert result["myval"] == 4
    history = [c for c in graph.get_state_history(thread1)]

    assert len(history) == 4
    assert history[-1].values == {"myval": 0}
    assert history[0].values == {"myval": 4, "otherval": False}

    second_run_config = {
        **thread1,
        "configurable": {
            **thread1["configurable"],
            "checkpoint_id": history[1].config["configurable"]["checkpoint_id"],
        },
    }
    second_result = graph.invoke(None, second_run_config)
    assert second_result == {"myval": 5, "otherval": True}

    new_history = [
        c
        for c in graph.get_state_history(
            {"configurable": {"thread_id": str(thread_id), "checkpoint_ns": ""}}
        )
    ]

    assert len(new_history) == len(history) + 1
    for original, new in zip(history, new_history[1:]):
        assert original.values == new.values
        assert original.next == new.next
        assert original.metadata["step"] == new.metadata["step"]

    def _get_tasks(hist: list, start: int):
        return [h.tasks for h in hist[start:]]

    assert _get_tasks(new_history, 1) == _get_tasks(history, 0)


def test_invoke_two_processes_in_dict_out(mocker: MockerFixture) -> None:
    add_one = mocker.Mock(side_effect=lambda x: x + 1)
    one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
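
For readers who want to reproduce the scenario outside the test suite, here is a small usage sketch (my own illustration, not code from this PR) of the public langgraph API as I understand it: fork a thread from an earlier checkpoint_id and check that the earlier history is still present. The state, node, and thread names (State, bump, "t1") are made up for the example.

import operator
from typing import Annotated, TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    total: Annotated[int, operator.add]


def bump(state: State) -> dict:
    # Each run adds 1 to the accumulating "total" channel.
    return {"total": 1}


builder = StateGraph(State)
builder.add_node("bump", bump)
builder.add_edge(START, "bump")
builder.add_edge("bump", END)
graph = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "t1"}}
graph.invoke({"total": 0}, config)
history = list(graph.get_state_history(config))  # newest first

# Re-run from the second-newest checkpoint; with this fix the original
# checkpoints (and their writes) remain visible in the thread's history.
fork = {
    "configurable": {
        "thread_id": "t1",
        "checkpoint_id": history[1].config["configurable"]["checkpoint_id"],
    }
}
graph.invoke(None, fork)
assert len(list(graph.get_state_history(config))) > len(history)
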
79 changes: 79 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
@@ -1960,6 +1960,85 @@ def reset(self):
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_run_from_checkpoint_id_retains_previous_writes(
    request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
    class MyState(TypedDict):
        myval: Annotated[int, operator.add]
        otherval: bool

    class Anode:
        def __init__(self):
            self.switch = False

        async def __call__(self, state: MyState):
            self.switch = not self.switch
            return {"myval": 2 if self.switch else 1, "otherval": self.switch}

    builder = StateGraph(MyState)
    thenode = Anode()  # Fun.
    builder.add_node("node_one", thenode)
    builder.add_node("node_two", thenode)
    builder.add_edge(START, "node_one")

    def _getedge(src: str):
        swap = "node_one" if src == "node_two" else "node_two"

        def _edge(st: MyState) -> Literal["__end__", "node_one", "node_two"]:
            if st["myval"] > 3:
                return END
            if st["otherval"]:
                return swap
            return src

        return _edge

    builder.add_conditional_edges("node_one", _getedge("node_one"))
    builder.add_conditional_edges("node_two", _getedge("node_two"))
    async with awith_checkpointer(checkpointer_name) as checkpointer:
        graph = builder.compile(checkpointer=checkpointer)

        thread_id = uuid.uuid4()
        thread1 = {"configurable": {"thread_id": str(thread_id)}}

        result = await graph.ainvoke({"myval": 1}, thread1)
        assert result["myval"] == 4
        history = [c async for c in graph.aget_state_history(thread1)]

        assert len(history) == 4
        assert history[-1].values == {"myval": 0}
        assert history[0].values == {"myval": 4, "otherval": False}

        second_run_config = {
            **thread1,
            "configurable": {
                **thread1["configurable"],
                "checkpoint_id": history[1].config["configurable"]["checkpoint_id"],
            },
        }
        second_result = await graph.ainvoke(None, second_run_config)
        assert second_result == {"myval": 5, "otherval": True}

        new_history = [
            c
            async for c in graph.aget_state_history(
                {"configurable": {"thread_id": str(thread_id), "checkpoint_ns": ""}}
            )
        ]

        assert len(new_history) == len(history) + 1
        for original, new in zip(history, new_history[1:]):
            assert original.values == new.values
            assert original.next == new.next
            assert original.metadata["step"] == new.metadata["step"]

        def _get_tasks(hist: list, start: int):
            return [h.tasks for h in hist[start:]]

        assert _get_tasks(new_history, 1) == _get_tasks(history, 0)


async def test_cond_edge_after_send() -> None:
    class Node:
        def __init__(self, name: str):