diff --git a/libs/langgraph/langgraph/pregel/remote.py b/libs/langgraph/langgraph/pregel/remote.py index f7ae8c9b6..48a594066 100644 --- a/libs/langgraph/langgraph/pregel/remote.py +++ b/libs/langgraph/langgraph/pregel/remote.py @@ -67,10 +67,28 @@ def __init__( """ self.name = name self.config = config - self.client = client or get_client(url=url, api_key=api_key, headers=headers) - self.sync_client = sync_client or get_sync_client( - url=url, api_key=api_key, headers=headers - ) + + if client is None and url is not None: + client = get_client(url=url, api_key=api_key, headers=headers) + self.client = client + + if sync_client is None and url is not None: + sync_client = get_sync_client(url=url, api_key=api_key, headers=headers) + self.sync_client = sync_client + + def _validate_client(self) -> LangGraphClient: + if self.client is None: + raise ValueError( + "Async client is not initialized: please provide `url` or `client` when initializing `RemoteGraph`." + ) + return self.client + + def _validate_sync_client(self) -> SyncLangGraphClient: + if self.sync_client is None: + raise ValueError( + "Sync client is not initialized: please provide `url` or `sync_client` when initializing `RemoteGraph`." + ) + return self.sync_client def copy(self, update: dict[str, Any]) -> Self: attrs = {**self.__dict__, **update} @@ -103,7 +121,8 @@ def get_graph( *, xray: Union[int, bool] = False, ) -> DrawableGraph: - graph = self.sync_client.assistants.get_graph( + sync_client = self._validate_sync_client() + graph = sync_client.assistants.get_graph( assistant_id=self.name, xray=xray, ) @@ -118,7 +137,8 @@ async def aget_graph( *, xray: Union[int, bool] = False, ) -> DrawableGraph: - graph = await self.client.assistants.get_graph( + client = self._validate_client() + graph = await client.assistants.get_graph( assistant_id=self.name, xray=xray, ) @@ -248,9 +268,10 @@ def _sanitize_obj(obj: Any) -> Any: def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: + sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) - state = self.sync_client.threads.get_state( + state = sync_client.threads.get_state( thread_id=merged_config["configurable"]["thread_id"], checkpoint=self._get_checkpoint(merged_config), subgraphs=subgraphs, @@ -260,9 +281,10 @@ def get_state( async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: + client = self._validate_client() merged_config = merge_configs(self.config, config) - state = await self.client.threads.get_state( + state = await client.threads.get_state( thread_id=merged_config["configurable"]["thread_id"], checkpoint=self._get_checkpoint(merged_config), subgraphs=subgraphs, @@ -277,9 +299,10 @@ def get_state_history( before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> Iterator[StateSnapshot]: + sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) - states = self.sync_client.threads.get_history( + states = sync_client.threads.get_history( thread_id=merged_config["configurable"]["thread_id"], limit=limit if limit else 10, before=self._get_checkpoint(before), @@ -297,9 +320,10 @@ async def aget_state_history( before: Optional[RunnableConfig] = None, limit: Optional[int] = None, ) -> AsyncIterator[StateSnapshot]: + client = self._validate_client() merged_config = merge_configs(self.config, config) - states = await self.client.threads.get_history( + states = await client.threads.get_history( thread_id=merged_config["configurable"]["thread_id"], limit=limit if limit else 10, before=self._get_checkpoint(before), @@ -315,9 +339,10 @@ def update_state( values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None, ) -> RunnableConfig: + sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) - response: dict = self.sync_client.threads.update_state( # type: ignore + response: dict = sync_client.threads.update_state( # type: ignore thread_id=merged_config["configurable"]["thread_id"], values=values, as_node=as_node, @@ -331,9 +356,10 @@ async def aupdate_state( values: Optional[Union[dict[str, Any], Any]], as_node: Optional[str] = None, ) -> RunnableConfig: + client = self._validate_client() merged_config = merge_configs(self.config, config) - response: dict = await self.client.threads.update_state( # type: ignore + response: dict = await client.threads.update_state( # type: ignore thread_id=merged_config["configurable"]["thread_id"], values=values, as_node=as_node, @@ -382,11 +408,12 @@ def stream( interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, ) -> Iterator[Union[dict[str, Any], Any]]: + sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) - for chunk in self.sync_client.runs.stream( + for chunk in sync_client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, @@ -429,11 +456,12 @@ async def astream( interrupt_after: Optional[Union[All, Sequence[str]]] = None, subgraphs: bool = False, ) -> AsyncIterator[Union[dict[str, Any], Any]]: + client = self._validate_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) stream_modes, req_updates, req_single = self._get_stream_modes(stream_mode) - async for chunk in self.client.runs.stream( + async for chunk in client.runs.stream( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, @@ -490,10 +518,11 @@ def invoke( interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, ) -> Union[dict[str, Any], Any]: + sync_client = self._validate_sync_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) - return self.sync_client.runs.wait( + return sync_client.runs.wait( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input, @@ -511,10 +540,11 @@ async def ainvoke( interrupt_before: Optional[Union[All, Sequence[str]]] = None, interrupt_after: Optional[Union[All, Sequence[str]]] = None, ) -> Union[dict[str, Any], Any]: + client = self._validate_client() merged_config = merge_configs(self.config, config) sanitized_config = self._sanitize_config(merged_config) - return await self.client.runs.wait( + return await client.runs.wait( thread_id=sanitized_config["configurable"].get("thread_id"), assistant_id=self.name, input=input,