Skip to content

Commit

Permalink
Merge pull request #1776 from langchain-ai/nc/19sep/stream-mode-custom
Browse files Browse the repository at this point in the history
Add stream_mode=custom
  • Loading branch information
nfcampos authored Sep 20, 2024
2 parents b9fe384 + 531890e commit b03d9ae
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 22 deletions.
8 changes: 7 additions & 1 deletion libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import Any, Literal
from types import MappingProxyType
from typing import Any, Literal, Mapping

INPUT = "__input__"
CONFIG_KEY_SEND = "__pregel_send"
CONFIG_KEY_READ = "__pregel_read"
CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer"
CONFIG_KEY_STREAM = "__pregel_stream"
CONFIG_KEY_STREAM_WRITER = "__pregel_stream_writer"
CONFIG_KEY_STORE = "__pregel_store"
CONFIG_KEY_RESUMING = "__pregel_resuming"
CONFIG_KEY_TASK_ID = "__pregel_task_id"
Expand Down Expand Up @@ -34,6 +36,8 @@
CONFIG_KEY_READ,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_STREAM,
CONFIG_KEY_STREAM_WRITER,
CONFIG_KEY_STORE,
CONFIG_KEY_RESUMING,
CONFIG_KEY_TASK_ID,
Expand All @@ -51,6 +55,8 @@
NS_SEP = "|"
NS_END = ":"

EMPTY_MAP: Mapping[str, Any] = MappingProxyType({})


class Send:
"""A message or packet to send to a specific node in the graph.
Expand Down
15 changes: 13 additions & 2 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
CONFIG_KEY_RESUMING,
CONFIG_KEY_SEND,
CONFIG_KEY_STREAM,
CONFIG_KEY_STREAM_WRITER,
CONFIG_KEY_TASK_ID,
INTERRUPT,
NS_END,
Expand Down Expand Up @@ -1219,6 +1220,11 @@ def output() -> Iterator:
run_manager.inheritable_handlers.append(
StreamMessagesHandler(stream.put)
)
# set up custom stream mode
if "custom" in stream_modes:
config["configurable"][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put(
((), "custom", c)
)
with SyncPregelLoop(
input,
stream=StreamProtocol(stream.put, stream_modes),
Expand All @@ -1240,7 +1246,7 @@ def output() -> Iterator:
if subgraphs:
loop.config["configurable"][CONFIG_KEY_STREAM] = loop.stream
# enable concurrent streaming
if subgraphs or "messages" in stream_modes:
if subgraphs or "messages" in stream_modes or "custom" in stream_modes:
# we are careful to have a single waiter live at any one time
# because on exit we increment semaphore count by exactly 1
waiter: Optional[concurrent.futures.Future] = None
Expand Down Expand Up @@ -1435,6 +1441,11 @@ def output() -> Iterator:
run_manager.inheritable_handlers.append(
StreamMessagesHandler(stream.put_nowait)
)
# set up custom stream mode
if "custom" in stream_modes:
config["configurable"][CONFIG_KEY_STREAM_WRITER] = (
lambda c: stream.put_nowait(((), "custom", c))
)
async with AsyncPregelLoop(
input,
stream=StreamProtocol(stream.put_nowait, stream_modes),
Expand All @@ -1456,7 +1467,7 @@ def output() -> Iterator:
if subgraphs:
loop.config["configurable"][CONFIG_KEY_STREAM] = loop.stream
# enable concurrent streaming
if subgraphs or "messages" in stream_modes:
if subgraphs or "messages" in stream_modes or "custom" in stream_modes:

def get_waiter() -> asyncio.Task[None]:
return aioloop.create_task(stream.wait())
Expand Down
8 changes: 7 additions & 1 deletion libs/langgraph/langgraph/pregel/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,18 @@ class StateSnapshot(NamedTuple):

All = Literal["*"]

StreamMode = Literal["values", "updates", "debug", "messages"]
StreamMode = Literal["values", "updates", "debug", "messages", "custom"]
"""How the stream method should emit outputs.
- 'values': Emit all values of the state for each step.
- 'updates': Emit only the node name(s) and updates
that were returned by the node(s) **after** each step.
- 'debug': Emit debug events for each step.
- 'messages': Emit LLM messages token-by-token.
- 'custom': Emit custom output `write: StreamWriter` kwarg of each node.
"""

StreamWriter = Callable[[Any], None]
"""Callable that accepts a single argument and writes it to the output stream.
Always injected into nodes if requested,
but it's a no-op when not using stream_mode="custom"."""
48 changes: 38 additions & 10 deletions libs/langgraph/langgraph/utils/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import Input, accepts_config
from langchain_core.runnables.utils import Input
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from typing_extensions import TypeGuard

from langgraph.constants import CONFIG_KEY_STREAM_WRITER
from langgraph.pregel.types import StreamWriter
from langgraph.utils.config import (
ensure_config,
get_async_callback_manager_for_config,
Expand All @@ -57,6 +59,19 @@ class StrEnum(str, enum.Enum):

ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)

KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
(
sys.intern("writer"),
(StreamWriter, inspect.Parameter.empty),
CONFIG_KEY_STREAM_WRITER,
lambda _: None,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations."""

VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


class RunnableCallable(Runnable):
"""A much simpler version of RunnableLambda that requires sync and async functions."""
Expand Down Expand Up @@ -86,15 +101,22 @@ def __init__(
except AttributeError:
pass
self.func = func
if func is not None:
self.func_accepts_config = accepts_config(func)
self.afunc = afunc
if afunc is not None:
self.afunc_accepts_config = accepts_config(afunc)
self.tags = tags
self.kwargs = kwargs
self.trace = trace
self.recurse = recurse
# check signature
if func is None and afunc is None:
raise ValueError("At least one of func or afunc must be provided.")
params = inspect.signature(cast(Callable, func or afunc)).parameters
self.func_accepts_config = "config" in params
self.func_accepts: dict[str, bool] = {}
for kw, typ, _, _ in KWARGS_CONFIG_KEYS:
p = params.get(kw)
self.func_accepts[kw] = (
p is not None and p.annotation in typ and p.kind in VALID_KINDS
)

def __repr__(self) -> str:
repr_args = {
Expand All @@ -113,11 +135,14 @@ def invoke(
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
if config is None:
config = ensure_config()
kwargs = {**self.kwargs, **kwargs}
if self.func_accepts_config:
kwargs["config"] = config
if config is None:
config = ensure_config()
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
if self.func_accepts[kw]:
kwargs[kw] = config["configurable"].get(ck, defv)
context = copy_context()
if self.trace:
callback_manager = get_callback_manager_for_config(config, self.tags)
Expand Down Expand Up @@ -149,11 +174,14 @@ async def ainvoke(
) -> Any:
if not self.afunc:
return self.invoke(input, config)
kwargs = {**self.kwargs, **kwargs}
if self.afunc_accepts_config:
kwargs["config"] = config
if config is None:
config = ensure_config()
kwargs = {**self.kwargs, **kwargs}
if self.func_accepts_config:
kwargs["config"] = config
for kw, _, ck, defv in KWARGS_CONFIG_KEYS:
if self.func_accepts[kw]:
kwargs[kw] = config["configurable"].get(ck, defv)
context = copy_context()
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
Expand Down
21 changes: 17 additions & 4 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
StateSnapshot,
)
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.types import PregelTask
from langgraph.pregel.types import PregelTask, StreamWriter
from langgraph.store.memory import MemoryStore
from tests.any_str import AnyDict, AnyStr, AnyVersion, UnsortedSequence
from tests.conftest import ALL_CHECKPOINTERS_SYNC, SHOULD_CHECK_SNAPSHOTS
Expand Down Expand Up @@ -10440,11 +10440,13 @@ def get_weather(city: str):
class SubGraphState(MessagesState):
city: str

def model_node(state: SubGraphState):
def model_node(state: SubGraphState, writer: StreamWriter):
writer(" very")
result = weather_model.invoke(state["messages"])
return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]}

def weather_node(state: SubGraphState):
def weather_node(state: SubGraphState, writer: StreamWriter):
writer(" good")
result = get_weather.invoke({"city": state["city"]})
return {"messages": [{"role": "assistant", "content": result}]}

Expand Down Expand Up @@ -10479,7 +10481,8 @@ class Router(TypedDict):
]
)

def router_node(state: RouterState):
def router_node(state: RouterState, writer: StreamWriter):
writer("I'm")
system_message = "Classify the incoming query as either about weather or not."
messages = [{"role": "system", "content": system_message}] + state["messages"]
route = router_model.invoke(messages)
Expand Down Expand Up @@ -10510,8 +10513,18 @@ def weather_graph(state: RouterState):
assert graph.get_graph(xray=1).draw_mermaid() == snapshot

config = {"configurable": {"thread_id": "1"}}
thread2 = {"configurable": {"thread_id": "2"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}

# run with custom output
assert [c for c in graph.stream(inputs, thread2, stream_mode="custom")] == [
"I'm",
" very",
]
assert [c for c in graph.stream(None, thread2, stream_mode="custom")] == [
" good",
]

# run until interrupt
assert [
c
Expand Down
25 changes: 21 additions & 4 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
StateSnapshot,
)
from langgraph.pregel.retry import RetryPolicy
from langgraph.pregel.types import PregelTask
from langgraph.pregel.types import PregelTask, StreamWriter
from langgraph.store.memory import MemoryStore
from tests.any_str import AnyDict, AnyStr, AnyVersion, UnsortedSequence
from tests.conftest import (
Expand Down Expand Up @@ -9051,11 +9051,13 @@ def get_weather(city: str):
class SubGraphState(MessagesState):
city: str

def model_node(state: SubGraphState):
def model_node(state: SubGraphState, writer: StreamWriter):
writer(" very")
result = weather_model.invoke(state["messages"])
return {"city": cast(AIMessage, result).tool_calls[0]["args"]["city"]}

def weather_node(state: SubGraphState):
def weather_node(state: SubGraphState, writer: StreamWriter):
writer(" good")
result = get_weather.invoke({"city": state["city"]})
return {"messages": [{"role": "assistant", "content": result}]}

Expand Down Expand Up @@ -9090,7 +9092,8 @@ class Router(TypedDict):
]
)

def router_node(state: RouterState):
def router_node(state: RouterState, writer: StreamWriter):
writer("I'm")
system_message = "Classify the incoming query as either about weather or not."
messages = [{"role": "system", "content": system_message}] + state["messages"]
route = router_model.invoke(messages)
Expand Down Expand Up @@ -9128,8 +9131,22 @@ def get_first_in_list():
assert graph.get_graph(xray=1).draw_mermaid() == snapshot

config = {"configurable": {"thread_id": "1"}}
thread2 = {"configurable": {"thread_id": "2"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}

# run with custom output
assert [
c async for c in graph.astream(inputs, thread2, stream_mode="custom")
] == [
"I'm",
" very",
]
assert [
c async for c in graph.astream(None, thread2, stream_mode="custom")
] == [
" good",
]

# run until interrupt
assert [
c
Expand Down

0 comments on commit b03d9ae

Please sign in to comment.