diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 01d53feab..5436d426d 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -484,7 +484,7 @@ def _prepare_state_snapshot( patch_checkpoint_map(saved.config, saved.metadata), saved.metadata, saved.checkpoint["ts"], - saved.parent_config, + patch_checkpoint_map(saved.parent_config, saved.metadata), tasks_w_writes( next_tasks.values(), saved.pending_writes, @@ -565,7 +565,7 @@ async def _aprepare_state_snapshot( patch_checkpoint_map(saved.config, saved.metadata), saved.metadata, saved.checkpoint["ts"], - saved.parent_config, + patch_checkpoint_map(saved.parent_config, saved.metadata), tasks_w_writes( next_tasks.values(), saved.pending_writes, diff --git a/libs/langgraph/langgraph/pregel/debug.py b/libs/langgraph/langgraph/pregel/debug.py index f70a6e8c4..d772e7cba 100644 --- a/libs/langgraph/langgraph/pregel/debug.py +++ b/libs/langgraph/langgraph/pregel/debug.py @@ -32,6 +32,7 @@ from langgraph.pregel.io import read_channels from langgraph.pregel.utils import find_subgraph_pregel from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot +from langgraph.utils.config import patch_checkpoint_map class TaskPayload(TypedDict): @@ -177,8 +178,8 @@ def map_debug_checkpoint( "timestamp": checkpoint["ts"], "step": step, "payload": { - "config": config, - "parent_config": parent_config, + "config": patch_checkpoint_map(config, metadata), + "parent_config": patch_checkpoint_map(parent_config, metadata), "values": read_channels(channels, stream_channels), "metadata": metadata, "next": [t.name for t in tasks], diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index dadc54b7a..d4f3a52c3 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -502,7 +502,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: ) def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: - # assign step + # assign step and parents metadata["step"] = self.step metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {}) # debug flag @@ -518,13 +518,14 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: self.checkpoint = create_checkpoint(self.checkpoint, self.channels, self.step) # bail if no checkpointer if self._checkpointer_put_after_previous is not None: + self.checkpoint_metadata = metadata + self.prev_checkpoint_config = ( self.checkpoint_config if CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF] and self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID] else None ) - self.checkpoint_metadata = metadata self.checkpoint_config = { **self.checkpoint_config, CONF: { diff --git a/libs/langgraph/langgraph/utils/config.py b/libs/langgraph/langgraph/utils/config.py index 3993df9b2..fe25b6d9a 100644 --- a/libs/langgraph/langgraph/utils/config.py +++ b/libs/langgraph/langgraph/utils/config.py @@ -36,9 +36,11 @@ def patch_configurable( def patch_checkpoint_map( - config: RunnableConfig, metadata: Optional[CheckpointMetadata] + config: Optional[RunnableConfig], metadata: Optional[CheckpointMetadata] ) -> RunnableConfig: - if parents := (metadata.get("parents") if metadata else None): + if config is None: + return config + elif parents := (metadata.get("parents") if metadata else None): conf = config[CONF] return patch_configurable( config, diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 348922e2b..a73083290 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -9092,6 +9092,9 @@ def outer_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, ), @@ -9250,6 +9253,9 @@ def outer_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), @@ -9279,6 +9285,9 @@ def outer_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, tasks=( @@ -9707,6 +9716,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, ) @@ -9770,6 +9786,15 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr( + re.compile(r"child:.+|child1:") + ): AnyStr(), + } + ), } }, ), @@ -9798,6 +9823,9 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, ), @@ -10056,6 +10084,9 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, tasks=(), @@ -10085,6 +10116,9 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, tasks=( @@ -10170,6 +10204,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, tasks=(), @@ -10208,6 +10249,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, tasks=( @@ -10253,6 +10301,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, tasks=( @@ -10444,6 +10499,12 @@ def edit(state: JokeState): "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("generate_joke:"): AnyStr(), + } + ), } }, tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), @@ -10476,6 +10537,12 @@ def edit(state: JokeState): "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("generate_joke:"): AnyStr(), + } + ), } }, tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), @@ -10931,6 +10998,12 @@ def weather_graph(state: RouterState): "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), } }, tasks=( @@ -11020,6 +11093,12 @@ def weather_graph(state: RouterState): "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), } }, tasks=(), @@ -11917,6 +11996,8 @@ def normalize_config(config: Optional[dict]) -> Optional[dict]: clean_config["thread_id"] = config["configurable"]["thread_id"] clean_config["checkpoint_id"] = config["configurable"]["checkpoint_id"] clean_config["checkpoint_ns"] = config["configurable"]["checkpoint_ns"] + if "checkpoint_map" in config["configurable"]: + clean_config["checkpoint_map"] = config["configurable"]["checkpoint_map"] return clean_config diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 96317d041..1c45d414f 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -7738,6 +7738,9 @@ def outer_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, ), @@ -7903,6 +7906,9 @@ def outer_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("inner:"): AnyStr()} + ), } }, tasks=( @@ -7934,6 +7940,9 @@ def outer_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("inner:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("inner:"): AnyStr()} + ), } }, tasks=( @@ -8328,6 +8337,9 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, ).tasks[0] @@ -8374,6 +8386,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, ) @@ -8439,6 +8458,15 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr( + re.compile(r"child:.+|child1:") + ): AnyStr(), + } + ), } }, ), @@ -8467,6 +8495,9 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, ), @@ -8732,6 +8763,9 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, tasks=(), @@ -8761,6 +8795,9 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr("child:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), } }, tasks=( @@ -8850,6 +8887,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, tasks=(), @@ -8888,6 +8932,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, tasks=( @@ -8933,6 +8984,13 @@ def parent_2(state: State): "thread_id": "1", "checkpoint_ns": AnyStr(), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), } }, tasks=( @@ -9568,6 +9626,12 @@ def get_first_in_list(): "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), } }, tasks=( @@ -9659,6 +9723,12 @@ def get_first_in_list(): "thread_id": "14", "checkpoint_ns": AnyStr("weather_graph:"), "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), } }, tasks=(), @@ -10143,6 +10213,8 @@ def normalize_config(config: Optional[dict]) -> Optional[dict]: clean_config["thread_id"] = config["configurable"]["thread_id"] clean_config["checkpoint_id"] = config["configurable"]["checkpoint_id"] clean_config["checkpoint_ns"] = config["configurable"]["checkpoint_ns"] + if "checkpoint_map" in config["configurable"]: + clean_config["checkpoint_map"] = config["configurable"]["checkpoint_map"] return clean_config