diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 81257c883..d2d693bef 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -1164,7 +1164,7 @@ async def aupdate_state( managed, ): # no values, just clear all tasks - if values is None and as_node is None: + if values is None and as_node == END: if saved is not None: # tasks for this checkpoint next_tasks = prepare_next_tasks( @@ -1217,6 +1217,25 @@ async def aupdate_state( return patch_checkpoint_map( next_config, saved.metadata if saved else None ) + # no values, copy checkpoint + if values is None and as_node is None: + next_checkpoint = create_checkpoint(checkpoint, None, step) + # copy checkpoint + next_config = await checkpointer.aput( + checkpoint_config, + next_checkpoint, + { + **checkpoint_metadata, + "source": "update", + "step": step + 1, + "writes": {}, + "parents": saved.metadata.get("parents", {}) if saved else {}, + }, + {}, + ) + return patch_checkpoint_map( + next_config, saved.metadata if saved else None + ) # apply pending writes, if not on specific checkpoint if ( CONFIG_KEY_CHECKPOINT_ID not in config[CONF] diff --git a/libs/langgraph/pyproject.toml b/libs/langgraph/pyproject.toml index 81af93cf3..8619fde99 100644 --- a/libs/langgraph/pyproject.toml +++ b/libs/langgraph/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langgraph" -version = "0.2.49" +version = "0.2.50" description = "Building stateful, multi-actor applications with LLMs" authors = [] license = "MIT" diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index fa8e1d7f1..a31e444e1 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -404,7 +404,7 @@ async def tool_two_node(s: State) -> State: ) # clear the interrupt and next tasks - await tool_two.aupdate_state(thread1, None) + await tool_two.aupdate_state(thread1, None, as_node=END) # interrupt is cleared, as well as the next tasks tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( @@ -426,6 +426,181 @@ async def tool_two_node(s: State) -> State: ) +@pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled") +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.11+ is required for async contextvars support", +) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_copy_checkpoint(checkpointer_name: str) -> None: + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + def tool_one(s: State) -> State: + return {"my_key": " one"} + + tool_two_node_count = 0 + + def tool_two_node(s: State) -> State: + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + return {"my_key": answer} + + def start(state: State) -> list[Union[Send, str]]: + return ["tool_two", Send("tool_one", state)] + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy()) + tool_two_graph.add_node("tool_one", tool_one) + tool_two_graph.set_conditional_entry_point(start) + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert await tool_two.ainvoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value one", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value one"} + + assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { + "my_key": "value one all good", + "market": "US", + } + + async with awith_checkpointer(checkpointer_name) as checkpointer: + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + await tool_two.ainvoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread2 + ) + ] == [ + { + "tool_one": {"my_key": " one"}, + }, + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:")], + ), + ) + }, + ] + # resume with answer + assert [ + c async for c in tool_two.astream(Command(resume=" my answer"), thread2) + ] == [ + {"tool_two": {"my_key": " my answer"}}, + ] + + # flow: interrupt -> clear tasks + thread1 = {"configurable": {"thread_id": "1"}} + # stop when about to enter node + assert await tool_two.ainvoke( + {"my_key": "value ⛰️", "market": "DE"}, thread1 + ) == { + "my_key": "value ⛰️ one", + "market": "DE", + } + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️ one", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:")], + ), + ), + ), + ), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1, limit=2) + ][-1].config, + ) + # clear the interrupt and next tasks + await tool_two.aupdate_state(thread1, None) + # interrupt is cleared, next task is kept + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️ one", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=(), + ), + ), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1, limit=2) + ][-1].config, + ) + + @pytest.mark.skipif( sys.version_info < (3, 11), reason="Python 3.11+ is required for async contextvars support",