Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docstring to RemoteGraph #2217

Merged
merged 7 commits into from
Oct 29, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 193 additions & 1 deletion libs/langgraph/langgraph/pregel/remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (

Check notice on line 1 in libs/langgraph/langgraph/pregel/remote.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 48.1 ms +- 0.8 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 45.8 ms +- 3.0 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 75.9 ms +- 1.5 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 83.7 ms +- 0.9 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 470 ms +- 11 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 423 ms +- 5 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 788 ms +- 36 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 823 ms +- 16 ms ......................................... react_agent_10x: Mean +- std dev: 28.9 ms +- 0.7 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.2 ms +- 1.6 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 47.1 ms +- 3.1 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 37.1 ms +- 3.2 ms ......................................... react_agent_100x: Mean +- std dev: 327 ms +- 13 ms ......................................... react_agent_100x_sync: Mean +- std dev: 260 ms +- 12 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 905 ms +- 9 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 811 ms +- 7 ms ......................................... wide_state_25x300: Mean +- std dev: 18.1 ms +- 0.3 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 10.8 ms +- 0.1 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 272 ms +- 5 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 259 ms +- 2 ms ......................................... wide_state_15x600: Mean +- std dev: 20.9 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 12.4 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 470 ms +- 9 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 465 ms +- 13 ms ......................................... wide_state_9x1200: Mean +- std dev: 20.9 ms +- 0.4 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 12.5 ms +- 0.1 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 303 ms +- 3 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 300 ms +- 13 ms

Check notice on line 1 in libs/langgraph/langgraph/pregel/remote.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +=========================================+=========+=======================+ | fanout_to_subgraph_100x_checkpoint | 887 ms | 788 ms: 1.13x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 48.8 ms | 47.1 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_checkpoint_sync | 854 ms | 823 ms: 1.04x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x | 485 ms | 470 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 22.9 ms | 22.2 ms: 1.03x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600 | 21.4 ms | 20.9 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint_sync | 265 ms | 259 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200 | 21.3 ms | 20.9 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint_sync | 306 ms | 300 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint_sync | 474 ms | 465 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 85.3 ms | 83.7 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 12.7 ms | 12.5 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 308 ms | 303 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_checkpoint | 478 ms | 470 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint_sync | 824 ms | 811 ms: 1.02x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300 | 18.4 ms | 18.1 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x | 331 ms | 327 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 12.6 ms | 12.4 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_checkpoint | 275 ms | 272 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_10x | 29.3 ms | 28.9 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 10.9 ms | 10.8 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 428 ms | 423 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | react_agent_100x_checkpoint | 915 ms | 905 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint | 76.7 ms | 75.9 ms: 1.01x faster | +-----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x | 48.4 ms | 48.1 ms: 1.01x faster | +---------------------------------------
Any,
AsyncIterator,
Iterator,
Expand Down Expand Up @@ -46,6 +46,16 @@


class RemoteGraph(PregelProtocol):
"""The RemoteGraph class is a client implementation for calling remote
APIs that implement the Open Agent API Specification.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Open Agent API Specification"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed to LangGraph Server API specification.


For example, the RemoteGraph class can be used to call LangGraph Cloud
APIs.

RemoteGraph behaves the same way as a Graph and can be used directly as
a node in another Graph.
"""

name: str

def __init__(
Expand All @@ -63,7 +73,17 @@
"""Specify `url`, `api_key`, and/or `headers` to create default sync and async clients.

If `client` or `sync_client` are provided, they will be used instead of the default clients.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would also add that if at least one of the 3 (url/client/sync_client) needs to be specified. something similar to how i describe in this how-to guide maybe https://github.com/langchain-ai/langgraph/pull/2218/files

See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients.
See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients. At least
one of `url`, `client`, or `sync_client` must be provided.

Args:
name: The name of the graph.
url: The URL of the remote API.
api_key: The API key to use for authentication. If not provided, it will be read from the environment (LANGGRAPH_API_KEY, LANGSMITH_API_KEY, or LANGCHAIN_API_KEY).
headers: Additional headers to include in the requests.
client: A LangGraphClient instance to use instead of creating a default client.
sync_client: A SyncLangGraphClient instance to use instead of creating a default client.
config: An optional RunnableConfig instance with additional configuration.
"""
self.name = name
self.config = config
Expand Down Expand Up @@ -121,6 +141,19 @@
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
"""Get graph by graph name.

This method calls GET /assistants/{assistant_id}/graph.

Args:
config: This parameter is not used.
xray: Include graph representation of subgraphs. If an integer
value is provided, only subgraphs with a depth less than or
equal to the value will be included.

Returns:
DrawableGraph: The graph information for the assistant in JSON format.
"""
sync_client = self._validate_sync_client()
graph = sync_client.assistants.get_graph(
assistant_id=self.name,
Expand All @@ -137,6 +170,19 @@
*,
xray: Union[int, bool] = False,
) -> DrawableGraph:
"""Get graph by graph name.

This method calls GET /assistants/{assistant_id}/graph.

Args:
config: This parameter is not used.
xray: Include graph representation of subgraphs. If an integer
value is provided, only subgraphs with a depth less than or
equal to the value will be included.

Returns:
DrawableGraph: The graph information for the assistant in JSON format.
"""
client = self._validate_client()
graph = await client.assistants.get_graph(
assistant_id=self.name,
Expand Down Expand Up @@ -268,6 +314,20 @@
def get_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
"""Get the state of a thread.

This method calls POST /threads/{thread_id}/state/checkpoint if a
checkpoint is specified in the config or GET /threads/{thread_id}/state
if no checkpoint is specified.

Args:
config: A RunnableConfig that includes `thread_id` in the
`configurable` field.
subgraphs: Include subgraphs in the state.

Returns:
StateSnapshot: The latest state of the thread.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)

Expand All @@ -281,6 +341,20 @@
async def aget_state(
self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
"""Get the state of a thread.

This method calls POST /threads/{thread_id}/state/checkpoint if a
checkpoint is specified in the config or GET /threads/{thread_id}/state
if no checkpoint is specified.

Args:
config: A RunnableConfig that includes `thread_id` in the
`configurable` field.
subgraphs: Include subgraphs in the state.

Returns:
StateSnapshot: The latest state of the thread.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)

Expand All @@ -299,6 +373,20 @@
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Iterator[StateSnapshot]:
"""Get the state history of a thread.

This method calls POST /threads/{thread_id}/history.

Args:
config: A RunnableConfig that includes `thread_id` in the
`configurable` field.
filter: Metadata to filter on.
before: A RunnableConfig that includes checkpoint metadata.
limit: Max number of states to return.

Returns:
Iterator[StateSnapshot]: States of the thread.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)

Expand All @@ -320,6 +408,20 @@
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncIterator[StateSnapshot]:
"""Get the state history of a thread.

This method calls POST /threads/{thread_id}/history.

Args:
config: A RunnableConfig that includes `thread_id` in the
`configurable` field.
filter: Metadata to filter on.
before: A RunnableConfig that includes checkpoint metadata.
limit: Max number of states to return.

Returns:
Iterator[StateSnapshot]: States of the thread.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)

Expand All @@ -339,6 +441,19 @@
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
"""Update the state of a thread.

This method calls POST /threads/{thread_id}/state.

Args:
config: A RunnableConfig that includes `thread_id` in the
`configurable` field.
values: Values to update to the state.
as_node: Update the state as if this node had just executed.

Returns:
RunnableConfig: RunnableConfig for the updated thread.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)

Expand All @@ -356,6 +471,19 @@
values: Optional[Union[dict[str, Any], Any]],
as_node: Optional[str] = None,
) -> RunnableConfig:
"""Update the state of a thread.

This method calls POST /threads/{thread_id}/state.

Args:
config: A RunnableConfig that includes `thread_id` in the
`configurable` field.
values: Values to update to the state.
as_node: Update the state as if this node had just executed.

Returns:
RunnableConfig: RunnableConfig for the updated thread.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)

Expand Down Expand Up @@ -408,6 +536,23 @@
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
) -> Iterator[Union[dict[str, Any], Any]]:
"""Create a run and stream the results.

This method calls POST /threads/{thread_id}/runs/stream if a `thread_id`
is speciffed in the `configurable` field of the config or
POST /runs/stream otherwise.

Args:
input: Input to the graph.
config: A RunnableConfig for graph invocation.
stream_mode: Stream mode(s) to use.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.
subgraphs: Stream from subgraphs.

Yields:
Iterator[Union[dict[str, Any], Any]]: The output of the graph.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
Expand Down Expand Up @@ -456,6 +601,23 @@
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
subgraphs: bool = False,
) -> AsyncIterator[Union[dict[str, Any], Any]]:
"""Create a run and stream the results.

This method calls POST /threads/{thread_id}/runs/stream if a `thread_id`
is speciffed in the `configurable` field of the config or
POST /runs/stream otherwise.

Args:
input: Input to the graph.
config: A RunnableConfig for graph invocation.
stream_mode: Stream mode(s) to use.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.
subgraphs: Stream from subgraphs.

Yields:
Iterator[Union[dict[str, Any], Any]]: The output of the graph.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
Expand Down Expand Up @@ -518,6 +680,21 @@
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]:
"""Create a run, wait until it finishes and return the final state.

This method calls POST /threads/{thread_id}/runs/wait if a `thread_id`
is speciffed in the `configurable` field of the config or
POST /runs/wait otherwise.

Args:
input: Input to the graph.
config: A RunnableConfig for graph invocation.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.

Returns:
Union[dict[str, Any], Any]: The output of the graph.
"""
sync_client = self._validate_sync_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
Expand All @@ -540,6 +717,21 @@
interrupt_before: Optional[Union[All, Sequence[str]]] = None,
interrupt_after: Optional[Union[All, Sequence[str]]] = None,
) -> Union[dict[str, Any], Any]:
"""Create a run, wait until it finishes and return the final state.

This method calls POST /threads/{thread_id}/runs/wait if a `thread_id`
is speciffed in the `configurable` field of the config or
POST /runs/wait otherwise.

Args:
input: Input to the graph.
config: A RunnableConfig for graph invocation.
interrupt_before: Interrupt the graph before these nodes.
interrupt_after: Interrupt the graph after these nodes.

Returns:
Union[dict[str, Any], Any]: The output of the graph.
"""
client = self._validate_client()
merged_config = merge_configs(self.config, config)
sanitized_config = self._sanitize_config(merged_config)
Expand Down
Loading