Skip to content

Commit

Permalink
Merge pull request #1732 from langchain-ai/nc/16sep/stream-selected
Browse files (browse the repository at this point in the history)
perf: In PregelLoop, only emit stream values requested by caller
  • Loading branch information
nfcampos authored Sep 17, 2024
2 parents f8ba4c3 + f505afe commit 0bf4bf8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 60 deletions.
40 changes: 19 additions & 21 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
)
from langgraph.pregel.debug import tasks_w_writes
from langgraph.pregel.io import read_channels
from langgraph.pregel.loop import AsyncPregelLoop, SyncPregelLoop
from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop
from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager
from langgraph.pregel.read import PregelNode
from langgraph.pregel.retry import RetryPolicy
Expand Down Expand Up @@ -1148,15 +1148,14 @@ def stream(
def output() -> Iterator:
while stream:
ns, mode, payload = stream.popleft()
if mode in stream_modes:
if subgraphs and isinstance(stream_mode, list):
yield (tuple(ns.split(NS_SEP)) if ns else (), mode, payload)
elif isinstance(stream_mode, list):
yield (mode, payload)
elif subgraphs:
yield (tuple(ns.split(NS_SEP)) if ns else (), payload)
else:
yield payload
if subgraphs and isinstance(stream_mode, list):
yield (ns, mode, payload)
elif isinstance(stream_mode, list):
yield (mode, payload)
elif subgraphs:
yield (ns, payload)
else:
yield payload

config = ensure_config(self.config, config)
callback_manager = get_callback_manager_for_config(config)
Expand Down Expand Up @@ -1192,7 +1191,7 @@ def output() -> Iterator:

with SyncPregelLoop(
input,
stream=stream.append,
stream=StreamProtocol(stream.append, stream_modes),
config=config,
store=self.store,
checkpointer=checkpointer,
Expand Down Expand Up @@ -1329,15 +1328,14 @@ async def astream(
def output() -> Iterator:
while stream:
ns, mode, payload = stream.popleft()
if mode in stream_modes:
if subgraphs and isinstance(stream_mode, list):
yield (tuple(ns.split(NS_SEP)) if ns else (), mode, payload)
elif isinstance(stream_mode, list):
yield (mode, payload)
elif subgraphs:
yield (tuple(ns.split(NS_SEP)) if ns else (), payload)
else:
yield payload
if subgraphs and isinstance(stream_mode, list):
yield (ns, mode, payload)
elif isinstance(stream_mode, list):
yield (mode, payload)
elif subgraphs:
yield (ns, payload)
else:
yield payload

config = ensure_config(self.config, config)
callback_manager = get_async_callback_manager_for_config(config)
Expand Down Expand Up @@ -1381,7 +1379,7 @@ def output() -> Iterator:
)
async with AsyncPregelLoop(
input,
stream=stream.append,
stream=StreamProtocol(stream.append, stream_modes),
config=config,
store=self.store,
checkpointer=checkpointer,
Expand Down
103 changes: 64 additions & 39 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
Callable,
ContextManager,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Type,
Expand All @@ -24,7 +24,7 @@

from langchain_core.callbacks import AsyncParentRunManager, ParentRunManager
from langchain_core.runnables import RunnableConfig
from typing_extensions import Self
from typing_extensions import ParamSpec, Self

from langgraph.channels.base import BaseChannel
from langgraph.checkpoint.base import (
Expand All @@ -48,6 +48,7 @@
ERROR,
INPUT,
INTERRUPT,
NS_SEP,
SCHEDULED,
TAG_HIDDEN,
TASKS,
Expand Down Expand Up @@ -99,23 +100,37 @@
from langgraph.utils.config import patch_configurable

V = TypeVar("V")
P = ParamSpec("P")
INPUT_DONE = object()
INPUT_RESUMING = object()
EMPTY_SEQ = ()
SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED)


class StreamProtocol(Protocol):
def __call__(self, values: Iterable[Tuple[str, str, Any]]) -> None: ...
class StreamProtocol:
__slots__ = ("modes", "__call__")

modes: Sequence[Literal["values", "updates", "debug"]]

__call__: Callable[[Iterable[Tuple[str, str, Any]]], None]

def __init__(
self,
__call__: Callable[[Iterable[Tuple[str, str, Any]]], None],
modes: Sequence[Literal["values", "updates", "debug"]],
) -> None:
self.__call__ = __call__
self.modes = modes


class DuplexStream(StreamProtocol):
def __init__(self, *queues: StreamProtocol) -> None:
self.queues = queues
def __init__(self, *streams: StreamProtocol) -> None:
def __call__(value: Tuple[str, str, Any]) -> None:
for stream in streams:
if value[1] in stream.modes:
stream(value)

def __call__(self, value: Tuple[str, str, Any]) -> None:
for queue in self.queues:
queue(value)
super().__init__(__call__, {mode for s in streams for mode in s.modes})


class PregelLoop:
Expand Down Expand Up @@ -150,6 +165,7 @@ class PregelLoop:
channels: Mapping[str, BaseChannel]
managed: ManagedValueMapping
checkpoint: Checkpoint
checkpoint_ns: tuple[str, ...]
checkpoint_config: RunnableConfig
checkpoint_metadata: CheckpointMetadata
checkpoint_pending_writes: List[PendingWrite]
Expand Down Expand Up @@ -217,6 +233,11 @@ def __init__(
)
else:
self.checkpoint_config = config
self.checkpoint_ns = (
tuple(self.config["configurable"].get("checkpoint_ns").split(NS_SEP))
if self.config["configurable"].get("checkpoint_ns")
else ()
)

def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None:
"""Put writes for a task, to be read by the next tick."""
Expand Down Expand Up @@ -293,8 +314,7 @@ def tick(
self._update_mv(key, values)
# produce values output
self._emit(
(self.config["configurable"].get("checkpoint_ns", ""), "values", v)
for v in map_output_values(self.output_keys, writes, self.channels)
"values", map_output_values, self.output_keys, writes, self.channels
)
# clear pending writes
self.checkpoint_pending_writes.clear()
Expand Down Expand Up @@ -344,17 +364,16 @@ def tick(
# produce debug output
if self._checkpointer_put_after_previous is not None:
self._emit(
(self.config["configurable"].get("checkpoint_ns", ""), "debug", v)
for v in map_debug_checkpoint(
self.step - 1, # printing checkpoint for previous step
self.checkpoint_config,
self.channels,
self.stream_keys,
self.checkpoint_metadata,
self.checkpoint,
self.tasks.values(),
self.checkpoint_pending_writes,
)
"debug",
map_debug_checkpoint,
self.step - 1, # printing checkpoint for previous step
self.checkpoint_config,
self.channels,
self.stream_keys,
self.checkpoint_metadata,
self.checkpoint,
self.tasks.values(),
self.checkpoint_pending_writes,
)

# if no more tasks, we're done
Expand Down Expand Up @@ -413,10 +432,7 @@ def tick(
return False

# produce debug output
self._emit(
(self.config["configurable"].get("checkpoint_ns", ""), "debug", v)
for v in map_debug_tasks(self.step, self.tasks.values())
)
self._emit("debug", map_debug_tasks, self.step, self.tasks.values())

# debug flag
if self.debug:
Expand Down Expand Up @@ -444,8 +460,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None:
self.checkpoint["versions_seen"][INTERRUPT][k] = version
# produce values output
self._emit(
(self.config["configurable"].get("checkpoint_ns", ""), "values", v)
for v in map_output_values(self.output_keys, True, self.channels)
"values", map_output_values, self.output_keys, True, self.channels
)
# map inputs to channel updates
elif input_writes := deque(map_input(input_keys, self.input)):
Expand Down Expand Up @@ -563,11 +578,19 @@ def _suppress_interrupt(
# suppress interrupt
return True

def _emit(self, values: Sequence[tuple[str, str, Any]]) -> None:
def _emit(
self,
mode: str,
values: Callable[P, Iterator[Any]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
if self.stream is None:
return
for v in values:
self.stream(v)
if mode not in self.stream.modes:
return
for v in values(*args, **kwargs):
self.stream((self.checkpoint_ns, mode, v))

def _output_writes(
self, task_id: str, writes: Sequence[tuple[str, Any]], *, cached: bool = False
Expand All @@ -579,17 +602,19 @@ def _output_writes(
return
if writes[0][0] != ERROR and writes[0][0] != INTERRUPT:
self._emit(
(self.config["configurable"].get("checkpoint_ns", ""), "updates", v)
for v in map_output_updates(
self.output_keys, [(task, writes)], cached
)
"updates",
map_output_updates,
self.output_keys,
[(task, writes)],
cached,
)
if not cached:
self._emit(
(self.config["configurable"].get("checkpoint_ns", ""), "debug", v)
for v in map_debug_task_results(
self.step, (task, writes), self.stream_keys
)
"debug",
map_debug_task_results,
self.step,
(task, writes),
self.stream_keys,
)


Expand Down

0 comments on commit 0bf4bf8

Please sign in to comment.