Skip to content

Commit

Permalink
0.2.50
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 15, 2024
1 parent 38d93a3 commit 5494855
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 3 deletions.
21 changes: 20 additions & 1 deletion libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,7 @@ async def aupdate_state(
managed,
):
# no values, just clear all tasks
if values is None and as_node is None:
if values is None and as_node == END:
if saved is not None:
# tasks for this checkpoint
next_tasks = prepare_next_tasks(
Expand Down Expand Up @@ -1217,6 +1217,25 @@ async def aupdate_state(
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# no values, copy checkpoint
if values is None and as_node is None:
next_checkpoint = create_checkpoint(checkpoint, None, step)
# copy checkpoint
next_config = await checkpointer.aput(
checkpoint_config,
next_checkpoint,
{
**checkpoint_metadata,
"source": "update",
"step": step + 1,
"writes": {},
"parents": saved.metadata.get("parents", {}) if saved else {},
},
{},
)
return patch_checkpoint_map(
next_config, saved.metadata if saved else None
)
# apply pending writes, if not on specific checkpoint
if (
CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph"
version = "0.2.49"
version = "0.2.50"
description = "Building stateful, multi-actor applications with LLMs"
authors = []
license = "MIT"
Expand Down
177 changes: 176 additions & 1 deletion libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ async def tool_two_node(s: State) -> State:
)

# clear the interrupt and next tasks
await tool_two.aupdate_state(thread1, None)
await tool_two.aupdate_state(thread1, None, as_node=END)
# interrupt is cleared, as well as the next tasks
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
Expand All @@ -426,6 +426,181 @@ async def tool_two_node(s: State) -> State:
)


@pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled")
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_copy_checkpoint(checkpointer_name: str) -> None:
class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

def tool_one(s: State) -> State:
return {"my_key": " one"}

tool_two_node_count = 0

def tool_two_node(s: State) -> State:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}

def start(state: State) -> list[Union[Send, str]]:
return ["tool_two", Send("tool_one", state)]

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", tool_two_node, retry=RetryPolicy())
tool_two_graph.add_node("tool_one", tool_one)
tool_two_graph.set_conditional_entry_point(start)
tool_two = tool_two_graph.compile()

tracer = FakeTracer()
assert await tool_two.ainvoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value one",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value one"}

assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == {
"my_key": "value one all good",
"market": "US",
}

async with awith_checkpointer(checkpointer_name) as checkpointer:
tool_two = tool_two_graph.compile(checkpointer=checkpointer)

# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
await tool_two.ainvoke({"my_key": "value", "market": "DE"})

# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c
async for c in tool_two.astream(
{"my_key": "value ⛰️", "market": "DE"}, thread2
)
] == [
{
"tool_one": {"my_key": " one"},
},
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
)
},
]
# resume with answer
assert [
c async for c in tool_two.astream(Command(resume=" my answer"), thread2)
] == [
{"tool_two": {"my_key": " my answer"}},
]

# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert await tool_two.ainvoke(
{"my_key": "value ⛰️", "market": "DE"}, thread1
) == {
"my_key": "value ⛰️ one",
"market": "DE",
}
assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:")],
),
),
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": {"tool_one": {"my_key": " one"}},
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)
# clear the interrupt and next tasks
await tool_two.aupdate_state(thread1, None)
# interrupt is cleared, next task is kept
tup = await tool_two.checkpointer.aget_tuple(thread1)
assert await tool_two.aget_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️ one", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(),
),
),
config=tup.config,
created_at=tup.checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[
c async for c in tool_two.checkpointer.alist(thread1, limit=2)
][-1].config,
)


@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Python 3.11+ is required for async contextvars support",
Expand Down

0 comments on commit 5494855

Please sign in to comment.