Commit
Merge pull request #1735 from langchain-ai/nc/16sep/stream-subgraph-in-progress

Stream subgraph output while it executes
nfcampos authored Sep 17, 2024
2 parents bfd18f7 + f59435a commit c4d4d61
Showing 6 changed files with 162 additions and 25 deletions.
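
What the change looks like from the caller's side (a hedged sketch, not part of the diff, assuming the usual langgraph.graph imports): with subgraphs=True, astream() now yields (namespace, update) pairs from inside a subgraph as soon as each subgraph node finishes, instead of only after the whole subgraph completes.

import asyncio
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph

class State(TypedDict):
    my_key: Annotated[str, operator.add]

async def child_fast(state: State):
    return {"my_key": " fast"}

async def child_slow(state: State):
    await asyncio.sleep(0.5)
    return {"my_key": " slow"}

child = StateGraph(State)
child.add_node("child_fast", child_fast)
child.add_node("child_slow", child_slow)
child.add_edge(START, "child_fast")
child.add_edge("child_fast", "child_slow")
child.add_edge("child_slow", END)

parent = StateGraph(State)
parent.add_node("child", child.compile())
parent.add_edge(START, "child")
parent.add_edge("child", END)
graph = parent.compile()

async def main() -> None:
    # Before this change, the "child_fast" update would only surface after the
    # whole child subgraph finished; now it is yielded immediately, with a
    # namespace like ("child:<task-id>",) identifying the subgraph.
    async for namespace, update in graph.astream({"my_key": "hi"}, subgraphs=True):
        print(namespace, update)

asyncio.run(main())

The new test_stream_subgraphs_during_execution test at the bottom of this diff asserts exactly this timing behaviour.
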
14 changes: 10 additions & 4 deletions libs/langgraph/langgraph/pregel/__init__.py
@@ -83,6 +83,7 @@
from langgraph.pregel.validate import validate_graph, validate_keys
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.utils.aio import Queue
from langgraph.utils.config import (
ensure_config,
merge_configs,
@@ -1323,11 +1324,14 @@ async def astream(
```
"""

- stream = deque()
+ stream = Queue()

def output() -> Iterator:
- while stream:
- ns, mode, payload = stream.popleft()
+ while True:
+ try:
+ ns, mode, payload = stream.get_nowait()
+ except asyncio.QueueEmpty:
+ break
if subgraphs and isinstance(stream_mode, list):
yield (ns, mode, payload)
elif isinstance(stream_mode, list):
@@ -1337,6 +1341,7 @@ def output() -> Iterator:
else:
yield payload

aioloop = asyncio.get_event_loop()
config = ensure_config(self.config, config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
@@ -1379,7 +1384,7 @@ def output() -> Iterator:
)
async with AsyncPregelLoop(
input,
- stream=StreamProtocol(stream.append, stream_modes),
+ stream=StreamProtocol(stream.put_nowait, stream_modes),
config=config,
store=self.store,
checkpointer=checkpointer,
@@ -1412,6 +1417,7 @@ def output() -> Iterator:
loop.tasks.values(),
timeout=self.step_timeout,
retry_policy=self.retry_policy,
get_waiter=lambda: aioloop.create_task(stream.wait()),
):
# emit output
for o in output():
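
The core of the change above is replacing the deque, which was effectively drained only between steps, with a Queue that tasks write to as they produce output. A minimal sketch of that producer/consumer shape, using plain asyncio and illustrative names only (the real code wakes up via a waiter future on the queue rather than polling):

import asyncio

async def main() -> None:
    stream: asyncio.Queue = asyncio.Queue()

    async def node(name: str, delay: float) -> None:
        # producer side: equivalent of StreamProtocol calling stream.put_nowait
        await asyncio.sleep(delay)
        stream.put_nowait((name, "update"))

    tasks = [
        asyncio.create_task(node("inner_1", 0.1)),
        asyncio.create_task(node("outer_1", 0.3)),
    ]
    while not all(t.done() for t in tasks) or not stream.empty():
        # consumer side: emit whatever is already available, without blocking,
        # much like the new output() helper draining with get_nowait()
        while True:
            try:
                print(stream.get_nowait())
            except asyncio.QueueEmpty:
                break
        await asyncio.sleep(0.05)  # placeholder for the waiter-based wake-up

asyncio.run(main())
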
54 changes: 37 additions & 17 deletions libs/langgraph/langgraph/pregel/runner.py
@@ -100,29 +100,37 @@ async def atick(
reraise: bool = True,
timeout: Optional[float] = None,
retry_policy: Optional[RetryPolicy] = None,
get_waiter: Optional[Callable[[], asyncio.Future[None]]] = None,
) -> AsyncIterator[None]:
loop = asyncio.get_event_loop()
# give control back to the caller
yield
# add waiter task if requested
if get_waiter is not None:
futures: dict[asyncio.Future, Optional[PregelExecutableTask]] = {
get_waiter(): None
}
else:
futures = {}
# execute tasks, and wait for one to fail or all to finish.
# each task is independent from all other concurrent tasks
# yield updates/debug output as each task finishes
- futures = {
- self.submit(
- arun_with_retry,
- task,
- retry_policy,
- stream=self.use_astream,
- __name__=task.name,
- __cancel_on_exit__=True,
- __reraise_on_exit__=reraise,
- ): task
- for task in tasks
- if not task.writes
- }
+ for task in tasks:
+ if not task.writes:
+ futures[
+ self.submit(
+ arun_with_retry,
+ task,
+ retry_policy,
+ stream=self.use_astream,
+ __name__=task.name,
+ __cancel_on_exit__=True,
+ __reraise_on_exit__=reraise,
+ )
+ ] = task
all_futures = futures.copy()
end_time = timeout + loop.time() if timeout else None
- while futures:
+ while len(futures) > (1 if get_waiter is not None else 0):
done, _ = await asyncio.wait(
futures,
return_when=asyncio.FIRST_COMPLETED,
@@ -132,6 +140,10 @@ async def atick(
break # timed out
for fut in done:
task = futures.pop(fut)
if task is None:
# waiter task finished, schedule another
futures[get_waiter()] = None
continue
if exc := _exception(fut):
if isinstance(exc, GraphInterrupt):
# save interrupt to checkpointer
@@ -156,6 +168,9 @@ async def atick(
break
# give control back to the caller
yield
# cancel waiter task
for fut in futures:
fut.cancel()
# panic on failure or timeout
_panic_or_proceed(
all_futures, timeout_exc_cls=asyncio.TimeoutError, panic=reraise
@@ -187,15 +202,20 @@ def _exception(


def _panic_or_proceed(
- futs: Union[set[concurrent.futures.Future[Any]], set[asyncio.Task[Any]]],
+ futs: Union[
+ dict[concurrent.futures.Future, Optional[PregelExecutableTask]],
+ dict[asyncio.Future, Optional[PregelExecutableTask]],
+ ],
*,
timeout_exc_cls: Type[Exception] = TimeoutError,
panic: bool = True,
) -> None:
done: set[Union[concurrent.futures.Future[Any], asyncio.Task[Any]]] = set()
inflight: set[Union[concurrent.futures.Future[Any], asyncio.Task[Any]]] = set()
- for fut in futs:
- if fut.done():
+ for fut, val in futs.items():
+ if val is None:
+ continue
+ elif fut.done():
done.add(fut)
else:
inflight.add(fut)
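
The atick() changes above mix one extra "waiter" future (mapped to None) into the dict of running task futures, so asyncio.wait(..., FIRST_COMPLETED) returns either when a task finishes or when new output is enqueued; the waiter is re-armed after every wake-up and cancelled once the real tasks are done. A self-contained sketch of that pattern, in plain asyncio with illustrative names (not LangGraph APIs):

import asyncio
from typing import Optional

async def demo() -> None:
    queue: asyncio.Queue = asyncio.Queue()

    async def task_work() -> str:
        await asyncio.sleep(0.3)
        return "result"

    async def wait_for_output() -> None:
        # stand-in for Queue.wait(): wake up when something is enqueued,
        # but put it back so the consumer can still read it
        item = await queue.get()
        queue.put_nowait(item)

    futures: dict[asyncio.Future, Optional[str]] = {
        asyncio.create_task(wait_for_output()): None,  # the waiter
        asyncio.create_task(task_work()): "node",      # a real task
    }
    asyncio.get_running_loop().call_later(0.1, queue.put_nowait, "early output")

    while len(futures) > 1:  # only the waiter left means all real tasks are done
        done, _ = await asyncio.wait(futures, return_when=asyncio.FIRST_COMPLETED)
        for fut in done:
            tag = futures.pop(fut)
            if tag is None:
                print("emit:", queue.get_nowait())  # output arrived mid-step
                futures[asyncio.create_task(wait_for_output())] = None  # re-arm
            else:
                print("done:", tag, fut.result())
    for fut in futures:  # cancel the leftover waiter
        fut.cancel()

asyncio.run(demo())
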
35 changes: 35 additions & 0 deletions libs/langgraph/langgraph/utils/aio.py
@@ -0,0 +1,35 @@
import asyncio
import sys

PY_310 = sys.version_info >= (3, 10)


class Queue(asyncio.Queue):
async def wait(self):
"""If queue is empty, wait until an item is available.
Copied from Queue.get(), removing the call to .get_nowait(),
ie. this doesn't consume the item, just waits for it.
"""
while self.empty():
if PY_310:
getter = self._get_loop().create_future()
else:
getter = self._loop.create_future()
self._getters.append(getter)
try:
await getter
except:
getter.cancel() # Just in case getter is not done yet.
try:
# Clean self._getters from canceled getters.
self._getters.remove(getter)
except ValueError:
# The getter could be removed from self._getters by a
# previous put_nowait call.
pass
if not self.empty() and not getter.cancelled():
# We were woken up by put_nowait(), but can't take
# the call. Wake up the next in line.
self._wakeup_next(self._getters)
raise
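
A brief usage sketch for the helper above (assuming it is importable as langgraph.utils.aio.Queue, per the file path in this diff): wait() returns once an item is available but, unlike get(), does not consume it, which is what lets the runner use it purely as a wake-up signal.

import asyncio

from langgraph.utils.aio import Queue

async def demo() -> None:
    q = Queue()
    asyncio.get_running_loop().call_later(0.1, q.put_nowait, "payload")
    await q.wait()                        # returns once "payload" is enqueued
    assert q.get_nowait() == "payload"    # wait() did not consume the item

asyncio.run(demo())
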
6 changes: 3 additions & 3 deletions libs/langgraph/poetry.lock

Generated file; diff not rendered.

2 changes: 1 addition & 1 deletion libs/langgraph/pyproject.toml
@@ -28,7 +28,7 @@ pytest-repeat = "^0.9.3"
langgraph-checkpoint = {path = "../checkpoint", develop = true}
langgraph-checkpoint-sqlite = {path = "../checkpoint-sqlite", develop = true}
langgraph-checkpoint-postgres = {path = "../checkpoint-postgres", develop = true}
- psycopg = {extras = ["binary"], version = ">=3.0.0"}
+ psycopg = {extras = ["binary"], version = ">=3.0.0", python = ">=3.10"}
uvloop = "^0.20.0"
pyperf = "^2.7.0"
py-spy = "^0.3.14"
76 changes: 76 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
@@ -4,6 +4,7 @@
import sys
from collections import Counter
from contextlib import asynccontextmanager, contextmanager
from time import perf_counter
from typing import (
Annotated,
Any,
@@ -363,6 +364,7 @@ async def iambad(input: State) -> None:
assert awhiles == 2


@pytest.mark.repeat(10)
async def test_step_timeout_on_stream_hang() -> None:
inner_task_cancelled = False

@@ -6958,6 +6960,80 @@ async def side(state: State):
assert times_called == 1


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_stream_subgraphs_during_execution(checkpointer_name: str) -> None:
class InnerState(TypedDict):
my_key: Annotated[str, operator.add]
my_other_key: str

async def inner_1(state: InnerState):
return {"my_key": "got here", "my_other_key": state["my_key"]}

async def inner_2(state: InnerState):
await asyncio.sleep(0.5)
return {
"my_key": " and there",
"my_other_key": state["my_key"],
}

inner = StateGraph(InnerState)
inner.add_node("inner_1", inner_1)
inner.add_node("inner_2", inner_2)
inner.add_edge("inner_1", "inner_2")
inner.set_entry_point("inner_1")
inner.set_finish_point("inner_2")

class State(TypedDict):
my_key: Annotated[str, operator.add]

async def outer_1(state: State):
await asyncio.sleep(0.2)
return {"my_key": " and parallel"}

async def outer_2(state: State):
return {"my_key": " and back again"}

graph = StateGraph(State)
graph.add_node("inner", inner.compile())
graph.add_node("outer_1", outer_1)
graph.add_node("outer_2", outer_2)

graph.add_edge(START, "inner")
graph.add_edge(START, "outer_1")
graph.add_edge(["inner", "outer_1"], "outer_2")
graph.add_edge("outer_2", END)

async with awith_checkpointer(checkpointer_name) as checkpointer:
app = graph.compile(checkpointer=checkpointer)

start = perf_counter()
chunks: list[tuple[float, Any]] = []
config = {"configurable": {"thread_id": "2"}}
async for c in app.astream({"my_key": ""}, config, subgraphs=True):
chunks.append((round(perf_counter() - start, 1), c))

assert chunks == [
# arrives before "inner" finishes
(
0.0,
(
(AnyStr("inner:"),),
{"inner_1": {"my_key": "got here", "my_other_key": ""}},
),
),
(0.2, ((), {"outer_1": {"my_key": " and parallel"}})),
(
0.5,
(
(AnyStr("inner:"),),
{"inner_2": {"my_key": " and there", "my_other_key": "got here"}},
),
),
(0.5, ((), {"inner": {"my_key": "got here and there"}})),
(0.5, ((), {"outer_2": {"my_key": " and back again"}})),
]


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_nested_graph_interrupts_parallel(checkpointer_name: str) -> None:
class InnerState(TypedDict):

