Skip to content

Commit

Permalink
Merge pull request #2346 from langchain-ai/nc/4nov/send-eager
Browse files Browse the repository at this point in the history
lib: Execute Sends in the superstep that originated them (feature-flagged)
  • Loading branch information
nfcampos authored Nov 13, 2024
2 parents 6906e12 + 29f833b commit f111276
Show file tree
Hide file tree
Showing 32 changed files with 2,891 additions and 663 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/_test_langgraph.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@ jobs:
- "3.13"
core-version:
- "latest"
ff-send-v2:
- "false"
include:
- python-version: "3.11"
core-version: ">=0.2.42,<0.3.0"
- python-version: "3.11"
core-version: "latest"
ff-send-v2: "true"

defaults:
run:
working-directory: libs/langgraph
name: "test #${{ matrix.python-version }} (langchain-core: ${{ matrix.core-version }})"
name: "test #${{ matrix.python-version }} (langchain-core: ${{ matrix.core-version }}, ff-send-v2: ${{ matrix.ff-send-v2 }})"
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
Expand All @@ -52,6 +57,8 @@ jobs:
- name: Run tests
shell: bash
env:
LANGGRAPH_FF_SEND_V2: ${{ matrix.ff-send-v2 }}
run: |
make test
Expand Down
4 changes: 3 additions & 1 deletion libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import (
ERROR,
INTERRUPT,
RESUME,
SCHEDULED,
ChannelProtocol,
SendProtocol,
Expand Down Expand Up @@ -449,4 +451,4 @@ def get_checkpoint_id(config: RunnableConfig) -> Optional[str]:
conflicting with regular writes.
Each Checkpointer implementation should use this mapping in put_writes.
"""
WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2}
WRITES_IDX_MAP = {ERROR: -1, SCHEDULED: -2, INTERRUPT: -3, RESUME: -4}
18 changes: 1 addition & 17 deletions libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from zoneinfo import ZoneInfo

from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.types import CommandProtocol, SendProtocol
from langgraph.checkpoint.serde.types import SendProtocol
from langgraph.store.base import Item

LC_REVIVER = Reviver()
Expand Down Expand Up @@ -122,11 +122,6 @@ def _default(self, obj: Any) -> Union[str, dict[str, Any]]:
return self._encode_constructor_args(
obj.__class__, kwargs={"node": obj.node, "arg": obj.arg}
)
elif isinstance(obj, CommandProtocol):
return self._encode_constructor_args(
obj.__class__,
kwargs={k: getattr(obj, k) for k in obj.__all_slots__},
)
elif isinstance(obj, (bytes, bytearray)):
return self._encode_constructor_args(
obj.__class__, method="fromhex", args=(obj.hex(),)
Expand Down Expand Up @@ -407,17 +402,6 @@ def _msgpack_default(obj: Any) -> Union[str, msgpack.ExtType]:
(obj.__class__.__module__, obj.__class__.__name__, (obj.node, obj.arg)),
),
)
elif isinstance(obj, CommandProtocol):
return msgpack.ExtType(
EXT_CONSTRUCTOR_KW_ARGS,
_msgpack_enc(
(
obj.__class__.__module__,
obj.__class__.__name__,
{k: getattr(obj, k) for k in obj.__all_slots__},
),
),
)
elif dataclasses.is_dataclass(obj):
# doesn't use dataclasses.asdict to avoid deepcopy and recursion
return msgpack.ExtType(
Expand Down
11 changes: 2 additions & 9 deletions libs/checkpoint/langgraph/checkpoint/serde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
Protocol,
Sequence,
TypeVar,
Union,
runtime_checkable,
)

from typing_extensions import Self

ERROR = "__error__"
SCHEDULED = "__scheduled__"
INTERRUPT = "__interrupt__"
RESUME = "__resume__"
TASKS = "__pregel_tasks"

Value = TypeVar("Value", covariant=True)
Expand Down Expand Up @@ -49,11 +50,3 @@ def __hash__(self) -> int: ...
def __repr__(self) -> str: ...

def __eq__(self, value: object) -> bool: ...


@runtime_checkable
class CommandProtocol(Protocol):
# Mirrors langgraph.types.Command
update: Optional[dict[str, Any]]
send: Union[Any, Sequence[Any]]
__all_slots__: set[str]
2 changes: 1 addition & 1 deletion libs/langgraph/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ test:
exit $$EXIT_CODE

test_watch:
make start-postgres && poetry run ptw . -- --ff -v -x -n auto --dist worksteal --snapshot-update --tb short $(TEST); \
make start-postgres && poetry run ptw . -- --ff -vv -x -n auto --dist worksteal --snapshot-update --tb short $(TEST); \
EXIT_CODE=$$?; \
make stop-postgres; \
exit $$EXIT_CODE
Expand Down
11 changes: 11 additions & 0 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from os import getenv
from types import MappingProxyType
from typing import Any, Literal, Mapping, cast

Expand All @@ -10,6 +11,7 @@
# --- Empty read-only containers ---
EMPTY_MAP: Mapping[str, Any] = MappingProxyType({})
EMPTY_SEQ: tuple[str, ...] = tuple()
MISSING = object()

# --- Public constants ---
TAG_NOSTREAM = sys.intern("langsmith:nostream")
Expand All @@ -28,6 +30,8 @@
# for values passed as input to the graph
INTERRUPT = sys.intern("__interrupt__")
# for dynamic interrupts raised by nodes
RESUME = sys.intern("__resume__")
# for values passed to resume a node after an interrupt
ERROR = sys.intern("__error__")
# for errors raised by nodes
NO_WRITES = sys.intern("__no_writes__")
Expand Down Expand Up @@ -69,6 +73,8 @@
# holds the current checkpoint_ns, "" for root graph
CONFIG_KEY_NODE_FINISHED = sys.intern("__pregel_node_finished")
# callback to be called when a node is finished
CONFIG_KEY_RESUME_VALUE = sys.intern("__pregel_resume_value")
# holds the value that "answers" an interrupt() call

# --- Other constants ---
PUSH = sys.intern("__pregel_push")
Expand All @@ -81,12 +87,17 @@
# for checkpoint_ns, for each level, separates the namespace from the task_id
CONF = cast(Literal["configurable"], sys.intern("configurable"))
# key for the configurable dict in RunnableConfig
FF_SEND_V2 = getenv("LANGGRAPH_FF_SEND_V2", "false").lower() == "true"
# temporary flag to enable new Send semantics
NULL_TASK_ID = sys.intern("00000000-0000-0000-0000-000000000000")
# the task_id to use for writes that are not associated with a task

RESERVED = {
TAG_HIDDEN,
# reserved write keys
INPUT,
INTERRUPT,
RESUME,
ERROR,
NO_WRITES,
SCHEDULED,
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class NodeInterrupt(GraphInterrupt):
"""Raised by a node to interrupt execution."""

def __init__(self, value: Any) -> None:
super().__init__([Interrupt(value)])
super().__init__([Interrupt(value=value)])


class GraphDelegate(Exception):
Expand Down
30 changes: 14 additions & 16 deletions libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import inspect
import logging
import typing
Expand All @@ -14,7 +15,6 @@
Optional,
Sequence,
Type,
TypeVar,
Union,
cast,
get_args,
Expand Down Expand Up @@ -50,15 +50,13 @@
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import All, Checkpointer, Command, RetryPolicy
from langgraph.types import _DC_KWARGS, All, Checkpointer, Command, N, RetryPolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable

logger = logging.getLogger(__name__)

N = TypeVar("N")


def _warn_invalid_state_schema(schema: Union[Type[Any], Any]) -> None:
if isinstance(schema, type):
Expand All @@ -81,20 +79,20 @@ def _get_node_name(node: RunnableLike) -> str:
raise TypeError(f"Unsupported node type: {type(node)}")


class GraphCommand(Command, Generic[N]):
@dataclasses.dataclass(**_DC_KWARGS)
class GraphCommand(Generic[N], Command[N]):
"""One or more commands to update a StateGraph's state and go to, or send messages to nodes."""

__slots__ = ("goto",)
goto: Union[str, Sequence[str]] = ()

def __init__(
self,
*,
update: Optional[dict[str, Any]] = None,
goto: Union[str, Sequence[str]] = (),
send: Union[Send, Sequence[Send]] = (),
) -> None:
super().__init__(update=update, send=send)
self.goto = goto
def __repr__(self) -> str:
# get all non-None values
contents = ", ".join(
f"{key}={value!r}"
for key, value in dataclasses.asdict(self).items()
if value
)
return f"Command({contents})"


class StateNodeSpec(NamedTuple):
Expand Down Expand Up @@ -389,7 +387,7 @@ def add_node(
input = input_hint
if (
(rtn := hints.get("return"))
and get_origin(rtn) is GraphCommand
and get_origin(rtn) in (Command, GraphCommand)
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
Expand Down
Loading

0 comments on commit f111276

Please sign in to comment.