diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 1d9b2c9df..c0c5e0bcb 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -73,7 +73,7 @@ ) from langgraph.pregel.debug import tasks_w_writes from langgraph.pregel.io import read_channels -from langgraph.pregel.loop import AsyncPregelLoop, SyncPregelLoop +from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager from langgraph.pregel.read import PregelNode from langgraph.pregel.retry import RetryPolicy @@ -1148,15 +1148,14 @@ def stream( def output() -> Iterator: while stream: ns, mode, payload = stream.popleft() - if mode in stream_modes: - if subgraphs and isinstance(stream_mode, list): - yield (tuple(ns.split(NS_SEP)) if ns else (), mode, payload) - elif isinstance(stream_mode, list): - yield (mode, payload) - elif subgraphs: - yield (tuple(ns.split(NS_SEP)) if ns else (), payload) - else: - yield payload + if subgraphs and isinstance(stream_mode, list): + yield (ns, mode, payload) + elif isinstance(stream_mode, list): + yield (mode, payload) + elif subgraphs: + yield (ns, payload) + else: + yield payload config = ensure_config(self.config, config) callback_manager = get_callback_manager_for_config(config) @@ -1192,7 +1191,7 @@ def output() -> Iterator: with SyncPregelLoop( input, - stream=stream.append, + stream=StreamProtocol(stream.append, stream_modes), config=config, store=self.store, checkpointer=checkpointer, @@ -1329,15 +1328,14 @@ async def astream( def output() -> Iterator: while stream: ns, mode, payload = stream.popleft() - if mode in stream_modes: - if subgraphs and isinstance(stream_mode, list): - yield (tuple(ns.split(NS_SEP)) if ns else (), mode, payload) - elif isinstance(stream_mode, list): - yield (mode, payload) - elif subgraphs: - yield (tuple(ns.split(NS_SEP)) if ns else (), payload) - else: - yield payload + if subgraphs and isinstance(stream_mode, list): + yield (ns, mode, payload) + elif isinstance(stream_mode, list): + yield (mode, payload) + elif subgraphs: + yield (ns, payload) + else: + yield payload config = ensure_config(self.config, config) callback_manager = get_async_callback_manager_for_config(config) @@ -1381,7 +1379,7 @@ def output() -> Iterator: ) async with AsyncPregelLoop( input, - stream=stream.append, + stream=StreamProtocol(stream.append, stream_modes), config=config, store=self.store, checkpointer=checkpointer, diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 52401283f..f1db60e0e 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -9,11 +9,11 @@ Callable, ContextManager, Iterable, + Iterator, List, Literal, Mapping, Optional, - Protocol, Sequence, Tuple, Type, @@ -24,7 +24,7 @@ from langchain_core.callbacks import AsyncParentRunManager, ParentRunManager from langchain_core.runnables import RunnableConfig -from typing_extensions import Self +from typing_extensions import ParamSpec, Self from langgraph.channels.base import BaseChannel from langgraph.checkpoint.base import ( @@ -48,6 +48,7 @@ ERROR, INPUT, INTERRUPT, + NS_SEP, SCHEDULED, TAG_HIDDEN, TASKS, @@ -99,23 +100,37 @@ from langgraph.utils.config import patch_configurable V = TypeVar("V") +P = ParamSpec("P") INPUT_DONE = object() INPUT_RESUMING = object() EMPTY_SEQ = () SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED) -class StreamProtocol(Protocol): - def __call__(self, values: Iterable[Tuple[str, str, Any]]) -> None: ... +class StreamProtocol: + __slots__ = ("modes", "__call__") + + modes: Sequence[Literal["values", "updates", "debug"]] + + __call__: Callable[[Iterable[Tuple[str, str, Any]]], None] + + def __init__( + self, + __call__: Callable[[Iterable[Tuple[str, str, Any]]], None], + modes: Sequence[Literal["values", "updates", "debug"]], + ) -> None: + self.__call__ = __call__ + self.modes = modes class DuplexStream(StreamProtocol): - def __init__(self, *queues: StreamProtocol) -> None: - self.queues = queues + def __init__(self, *streams: StreamProtocol) -> None: + def __call__(value: Tuple[str, str, Any]) -> None: + for stream in streams: + if value[1] in stream.modes: + stream(value) - def __call__(self, value: Tuple[str, str, Any]) -> None: - for queue in self.queues: - queue(value) + super().__init__(__call__, {mode for s in streams for mode in s.modes}) class PregelLoop: @@ -150,6 +165,7 @@ class PregelLoop: channels: Mapping[str, BaseChannel] managed: ManagedValueMapping checkpoint: Checkpoint + checkpoint_ns: tuple[str, ...] checkpoint_config: RunnableConfig checkpoint_metadata: CheckpointMetadata checkpoint_pending_writes: List[PendingWrite] @@ -217,6 +233,11 @@ def __init__( ) else: self.checkpoint_config = config + self.checkpoint_ns = ( + tuple(self.config["configurable"].get("checkpoint_ns").split(NS_SEP)) + if self.config["configurable"].get("checkpoint_ns") + else () + ) def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None: """Put writes for a task, to be read by the next tick.""" @@ -293,8 +314,7 @@ def tick( self._update_mv(key, values) # produce values output self._emit( - (self.config["configurable"].get("checkpoint_ns", ""), "values", v) - for v in map_output_values(self.output_keys, writes, self.channels) + "values", map_output_values, self.output_keys, writes, self.channels ) # clear pending writes self.checkpoint_pending_writes.clear() @@ -344,17 +364,16 @@ def tick( # produce debug output if self._checkpointer_put_after_previous is not None: self._emit( - (self.config["configurable"].get("checkpoint_ns", ""), "debug", v) - for v in map_debug_checkpoint( - self.step - 1, # printing checkpoint for previous step - self.checkpoint_config, - self.channels, - self.stream_keys, - self.checkpoint_metadata, - self.checkpoint, - self.tasks.values(), - self.checkpoint_pending_writes, - ) + "debug", + map_debug_checkpoint, + self.step - 1, # printing checkpoint for previous step + self.checkpoint_config, + self.channels, + self.stream_keys, + self.checkpoint_metadata, + self.checkpoint, + self.tasks.values(), + self.checkpoint_pending_writes, ) # if no more tasks, we're done @@ -413,10 +432,7 @@ def tick( return False # produce debug output - self._emit( - (self.config["configurable"].get("checkpoint_ns", ""), "debug", v) - for v in map_debug_tasks(self.step, self.tasks.values()) - ) + self._emit("debug", map_debug_tasks, self.step, self.tasks.values()) # debug flag if self.debug: @@ -444,8 +460,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: self.checkpoint["versions_seen"][INTERRUPT][k] = version # produce values output self._emit( - (self.config["configurable"].get("checkpoint_ns", ""), "values", v) - for v in map_output_values(self.output_keys, True, self.channels) + "values", map_output_values, self.output_keys, True, self.channels ) # map inputs to channel updates elif input_writes := deque(map_input(input_keys, self.input)): @@ -563,11 +578,19 @@ def _suppress_interrupt( # suppress interrupt return True - def _emit(self, values: Sequence[tuple[str, str, Any]]) -> None: + def _emit( + self, + mode: str, + values: Callable[P, Iterator[Any]], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: if self.stream is None: return - for v in values: - self.stream(v) + if mode not in self.stream.modes: + return + for v in values(*args, **kwargs): + self.stream((self.checkpoint_ns, mode, v)) def _output_writes( self, task_id: str, writes: Sequence[tuple[str, Any]], *, cached: bool = False @@ -579,17 +602,19 @@ def _output_writes( return if writes[0][0] != ERROR and writes[0][0] != INTERRUPT: self._emit( - (self.config["configurable"].get("checkpoint_ns", ""), "updates", v) - for v in map_output_updates( - self.output_keys, [(task, writes)], cached - ) + "updates", + map_output_updates, + self.output_keys, + [(task, writes)], + cached, ) if not cached: self._emit( - (self.config["configurable"].get("checkpoint_ns", ""), "debug", v) - for v in map_debug_task_results( - self.step, (task, writes), self.stream_keys - ) + "debug", + map_debug_task_results, + self.step, + (task, writes), + self.stream_keys, )