Skip to content

Commit

Permalink
lib: Add interrupts to stream_mode=updates
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Oct 11, 2024
1 parent 0557fb0 commit c6a450b
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 18 deletions.
11 changes: 3 additions & 8 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,7 @@ def tick(
# after execution, check if we should interrupt
if should_interrupt(self.checkpoint, interrupt_after, self.tasks.values()):
self.status = "interrupt_after"
if self.is_nested:
raise GraphInterrupt()
else:
return False
raise GraphInterrupt()
else:
return False

Expand Down Expand Up @@ -441,10 +438,7 @@ def tick(
# before execution, check if we should interrupt
if should_interrupt(self.checkpoint, interrupt_before, self.tasks.values()):
self.status = "interrupt_before"
if self.is_nested:
raise GraphInterrupt()
else:
return False
raise GraphInterrupt()

# produce debug output
self._emit("debug", map_debug_tasks, self.step, self.tasks.values())
Expand Down Expand Up @@ -598,6 +592,7 @@ def _suppress_interrupt(
self.output = read_channels(self.channels, self.output_keys)
if suppress:
# suppress interrupt
self._emit("updates", lambda: iter([{INTERRUPT: exc_value.args[0]}]))
return True

def _emit(
Expand Down
39 changes: 37 additions & 2 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ def test_invoke_two_processes_in_out_interrupt(
]
assert [c for c in app.stream(None, history[2].config, stream_mode="updates")] == [
{"one": {"inbox": 4}},
{"__interrupt__": ()},
]


Expand Down Expand Up @@ -3198,6 +3199,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -3295,6 +3297,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]

with assert_ctx_once():
Expand Down Expand Up @@ -3365,6 +3368,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -3460,6 +3464,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]

app_w_interrupt.update_state(
Expand Down Expand Up @@ -3520,7 +3525,9 @@ def should_continue(data: AgentState) -> str:

assert [
c for c in app_w_interrupt.stream({"input": "what is weather in sf"}, config)
] == []
] == [
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
values={
Expand All @@ -3542,6 +3549,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -3587,6 +3595,7 @@ def should_continue(data: AgentState) -> str:
],
}
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -3641,6 +3650,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]

# test w interrupt after all
Expand All @@ -3661,6 +3671,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -3706,6 +3717,7 @@ def should_continue(data: AgentState) -> str:
],
}
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -3760,6 +3772,7 @@ def should_continue(data: AgentState) -> str:
),
}
},
{"__interrupt__": ()},
]


Expand Down Expand Up @@ -4630,6 +4643,7 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState:
)
}
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -4759,6 +4773,7 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState:
)
},
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -5130,6 +5145,7 @@ def should_continue(messages):
id="ai1",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -5241,6 +5257,7 @@ def should_continue(messages):
id="ai2",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -5360,6 +5377,7 @@ def should_continue(messages):
id="ai1",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -5471,6 +5489,7 @@ def should_continue(messages):
id="ai2",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -5856,6 +5875,7 @@ class State(TypedDict):
id="ai1",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -5967,6 +5987,7 @@ class State(TypedDict):
id="ai2",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -6088,6 +6109,7 @@ class State(TypedDict):
id="ai1",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -6199,6 +6221,7 @@ class State(TypedDict):
id="ai2",
)
},
{"__interrupt__": ()},
]

assert app_w_interrupt.get_state(config) == StateSnapshot(
Expand Down Expand Up @@ -7653,6 +7676,7 @@ def qa(data: State) -> State:
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]

assert [c for c in app_w_interrupt.stream(None, config)] == [
Expand All @@ -7672,6 +7696,7 @@ def qa(data: State) -> State:
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]

app_w_interrupt.update_state(config, {"docs": ["doc5"]})
Expand Down Expand Up @@ -7785,6 +7810,7 @@ def rewrite_query_then(data: State) -> Literal["retriever_two"]:
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]

assert [c for c in app_w_interrupt.stream(None, config)] == [
Expand Down Expand Up @@ -7941,6 +7967,7 @@ def decider(data: State) -> str:
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]

with assert_ctx_once():
Expand Down Expand Up @@ -8109,6 +8136,7 @@ def decider(data: State) -> str:
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]

with assert_ctx_once():
Expand Down Expand Up @@ -8217,6 +8245,7 @@ def qa(data: State) -> State:
{"analyzer_one": {"query": "analyzed: query: what is weather in sf"}},
{"retriever_two": {"docs": ["doc3", "doc4"]}},
{"retriever_one": {"docs": ["doc1", "doc2"]}},
{"__interrupt__": ()},
]

assert [c for c in app_w_interrupt.stream(None, config)] == [
Expand Down Expand Up @@ -8779,6 +8808,7 @@ def outer_2(state: State):
# we got to parallel node first
((), {"outer_1": {"my_key": " and parallel"}}),
((AnyStr("inner:"),), {"inner_1": {"my_key": "got here", "my_other_key": ""}}),
((), {"__interrupt__": ()}),
]
assert [*app.stream(None, config)] == [
{"outer_1": {"my_key": " and parallel"}, "__metadata__": {"cached": True}},
Expand Down Expand Up @@ -8898,6 +8928,7 @@ def parent_2(state: State):
config = {"configurable": {"thread_id": "2"}}
assert [*app.stream({"my_key": "my value"}, config)] == [
{"parent_1": {"my_key": "hi my value"}},
{"__interrupt__": ()},
]
assert [*app.stream(None, config)] == [
{"child": {"my_key": "hi my value here and there"}},
Expand Down Expand Up @@ -9493,6 +9524,7 @@ def parent_2(state: State):
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_1": {"my_key": "hi my value here"}},
),
((), {"__interrupt__": ()}),
]
# get state without subgraphs
outer_state = app.get_state(config)
Expand Down Expand Up @@ -10192,7 +10224,8 @@ def parent_2(state: State):
(
(AnyStr("child:"), AnyStr("child_1:")),
{"grandchild_1": {"my_key": "hi my value here"}},
)
),
((), {"__interrupt__": ()}),
]


Expand Down Expand Up @@ -10644,6 +10677,7 @@ def weather_graph(state: RouterState):
] == [
((), {"router_node": {"route": "weather"}}),
((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
((), {"__interrupt__": ()}),
]

# check current state
Expand Down Expand Up @@ -10732,6 +10766,7 @@ def weather_graph(state: RouterState):
] == [
((), {"router_node": {"route": "weather"}}),
((AnyStr("weather_graph:"),), {"model_node": {"city": "San Francisco"}}),
((), {"__interrupt__": ()}),
]
state = graph.get_state(config, subgraphs=True)
assert state == StateSnapshot(
Expand Down
Loading

0 comments on commit c6a450b

Please sign in to comment.