Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Checkpointers] MemorySaver: refrain from overwriting writes #2399

Merged
merged 4 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import asyncio

Check notice on line 1 in libs/checkpoint/langgraph/checkpoint/memory/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 57.3 ms +- 1.6 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 48.9 ms +- 0.5 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 82.8 ms +- 1.6 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 90.0 ms +- 1.0 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 525 ms +- 10 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 476 ms +- 6 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 855 ms +- 53 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 886 ms +- 16 ms ......................................... react_agent_10x: Mean +- std dev: 30.1 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.1 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 46.0 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.0 ms +- 0.5 ms ......................................... react_agent_100x: Mean +- std dev: 336 ms +- 7 ms ......................................... react_agent_100x_sync: Mean +- std dev: 268 ms +- 2 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 933 ms +- 8 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 841 ms +- 10 ms ......................................... wide_state_25x300: Mean +- std dev: 23.6 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.1 ms +- 0.3 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 277 ms +- 3 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 265 ms +- 3 ms ......................................... wide_state_15x600: Mean +- std dev: 27.5 ms +- 0.5 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.3 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 478 ms +- 4 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 465 ms +- 8 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.5 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.3 ms +- 0.2 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 311 ms +- 3 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 298 ms +- 4 ms

Check notice on line 1 in libs/checkpoint/langgraph/checkpoint/memory/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_100x_checkpoint | 894 ms | 855 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 543 ms | 525 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 905 ms | 886 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 91.6 ms | 90.0 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 947 ms | 933 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 853 ms | 841 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 49.5 ms | 48.9 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 83.8 ms | 82.8 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 482 ms | 476 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 46.5 ms | 46.0 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 57.7 ms | 57.3 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x | 30.3 ms | 30.1 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 270 ms | 268 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.4 ms | 17.3 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 480 ms | 478 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 299 ms | 298 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 23.7 ms | 23.6 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint_sync | 467 ms | 465 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 279 ms | 277 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 22.2 ms | 22.1 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.4 ms | 17.3 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 312 ms | 311 ms: 1.00x faster | +-----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.01x faster | +-----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (6): wide_state_9x1200, wide_state_15x600, react_agent_100x, wide_state_25x300_sync, react_agent_10x_checkpoint_sync, wide_state_25x300_checkpoint_sync
import random
from collections import defaultdict
from contextlib import AbstractAsyncContextManager, AbstractContextManager
Expand All @@ -19,6 +19,9 @@
get_checkpoint_id,
)
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol
import logging

_LOGGER = logging.getLogger(__name__)

Check failure on line 24 in libs/checkpoint/langgraph/checkpoint/memory/__init__.py

View workflow job for this annotation

GitHub Actions / cd libs/checkpoint / lint #3.12

Ruff (I001)

langgraph/checkpoint/memory/__init__.py:1:1: I001 Import block is un-sorted or un-formatted


class MemorySaver(
Expand Down Expand Up @@ -366,6 +369,16 @@
outer_key = (thread_id, checkpoint_ns, checkpoint_id)
for idx, (c, v) in enumerate(writes):
inner_key = (task_id, WRITES_IDX_MAP.get(c, idx))
if (
inner_key[1] >= 0
and (outer_writes := self.writes.get(outer_key))
and inner_key in outer_writes
):
_LOGGER.debug(
f"Skipping duplicate writes for [{outer_key}][{inner_key}]"
)
continue

self.writes[outer_key][inner_key] = (task_id, c, self.serde.dumps_typed(v))

async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint"
version = "2.0.2"
version = "2.0.3"
description = "Library with base interfaces for LangGraph checkpoint savers."
authors = []
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph"
version = "0.2.45"
version = "0.2.46"
description = "Building stateful, multi-actor applications with LLMs"
authors = []
license = "MIT"
Expand Down
80 changes: 80 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,86 @@ def test_fork_always_re_runs_nodes(
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_run_from_checkpoint_id_retains_previous_writes(
request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class MyState(TypedDict):
myval: Annotated[int, operator.add]
otherval: bool

class Anode:
def __init__(self):
self.switch = False

def __call__(self, state: MyState):
self.switch = not self.switch
return {"myval": 2 if self.switch else 1, "otherval": self.switch}

builder = StateGraph(MyState)
thenode = Anode() # Fun.
builder.add_node("node_one", thenode)
builder.add_node("node_two", thenode)
builder.add_edge(START, "node_one")

def _getedge(src: str):
swap = "node_one" if src == "node_two" else "node_two"

def _edge(st: MyState) -> Literal["__end__", "node_one", "node_two"]:
if st["myval"] > 3:
return END
if st["otherval"]:
return swap
return src

return _edge

builder.add_conditional_edges("node_one", _getedge("node_one"))
builder.add_conditional_edges("node_two", _getedge("node_two"))
graph = builder.compile(checkpointer=checkpointer)

thread_id = uuid.uuid4()
thread1 = {"configurable": {"thread_id": str(thread_id)}}

result = graph.invoke({"myval": 1}, thread1)
assert result["myval"] == 4
history = [c for c in graph.get_state_history(thread1)]

assert len(history) == 4
assert history[-1].values == {"myval": 0}
assert history[0].values == {"myval": 4, "otherval": False}

second_run_config = {
**thread1,
"configurable": {
**thread1["configurable"],
"checkpoint_id": history[1].config["configurable"]["checkpoint_id"],
},
}
second_result = graph.invoke(None, second_run_config)
assert second_result == {"myval": 5, "otherval": True}

new_history = [
c
for c in graph.get_state_history(
{"configurable": {"thread_id": str(thread_id), "checkpoint_ns": ""}}
)
]

assert len(new_history) == len(history) + 1
for original, new in zip(history, new_history[1:]):
assert original.values == new.values
assert original.next == new.next
assert original.metadata["step"] == new.metadata["step"]

def _get_tasks(hist: list, start: int):
return [h.tasks for h in hist[start:]]

assert _get_tasks(new_history, 1) == _get_tasks(history, 0)


def test_invoke_two_processes_in_dict_out(mocker: MockerFixture) -> None:
add_one = mocker.Mock(side_effect=lambda x: x + 1)
one = Channel.subscribe_to("input") | add_one | Channel.write_to("inbox")
Expand Down
Loading