diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index aed19736d..dd8e86834 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -7,7 +7,7 @@ from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack from functools import partial from types import TracebackType -from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple +from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type from langchain_core.runnables import RunnableConfig @@ -71,16 +71,16 @@ def __init__( self, *, serde: Optional[SerializerProtocol] = None, - factory=defaultdict, + factory: Type[defaultdict] = defaultdict, ) -> None: super().__init__(serde=serde) self.storage = factory(lambda: defaultdict(dict)) self.writes = factory(dict) self.stack = ExitStack() try: - self.stack.enter_context(self.storage) - self.stack.enter_context(self.writes) - except TypeError: + self.stack.enter_context(self.storage) # type: ignore[arg-type] + self.stack.enter_context(self.writes) # type: ignore[arg-type] + except (TypeError, AttributeError): pass def __enter__(self) -> "MemorySaver": @@ -506,14 +506,14 @@ class PersistentDict(defaultdict): """ - def __init__(self, *args, filename: str, **kwds): + def __init__(self, *args: Any, filename: str, **kwds: Any) -> None: self.flag = "c" # r=readonly, c=create, or n=new self.mode = None # None or an octal triple like 0644 self.format = "pickle" # 'csv', 'json', or 'pickle' self.filename = filename super().__init__(*args, **kwds) - def sync(self): + def sync(self) -> None: "Write dict to disk" if self.flag == "r": return @@ -531,23 +531,23 @@ def sync(self): if self.mode is not None: os.chmod(self.filename, self.mode) - def close(self): + def close(self) -> None: self.sync() self.clear() - def __enter__(self): + def __enter__(self) -> "PersistentDict": return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any) -> None: self.close() - def dump(self, fileobj): + def dump(self, fileobj: Any) -> None: if self.format == "pickle": pickle.dump(dict(self), fileobj, 2) else: raise NotImplementedError("Unknown format: " + repr(self.format)) - def load(self): + def load(self) -> None: # try formats from most restrictive to least restrictive if self.flag == "n": return diff --git a/libs/checkpoint/langgraph/checkpoint/serde/base.py b/libs/checkpoint/langgraph/checkpoint/serde/base.py index de82b80c0..229837735 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/base.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/base.py @@ -12,17 +12,13 @@ class SerializerProtocol(Protocol): Valid implementations include the `pickle`, `json` and `orjson` modules. """ - def dumps(self, obj: Any) -> bytes: - ... + def dumps(self, obj: Any) -> bytes: ... - def dumps_typed(self, obj: Any) -> tuple[str, bytes]: - ... + def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ... - def loads(self, data: bytes) -> Any: - ... + def loads(self, data: bytes) -> Any: ... - def loads_typed(self, data: tuple[str, bytes]) -> Any: - ... + def loads_typed(self, data: tuple[str, bytes]) -> Any: ... class SerializerCompat(SerializerProtocol): diff --git a/libs/checkpoint/langgraph/checkpoint/serde/types.py b/libs/checkpoint/langgraph/checkpoint/serde/types.py index 039f654af..1df967b5f 100644 --- a/libs/checkpoint/langgraph/checkpoint/serde/types.py +++ b/libs/checkpoint/langgraph/checkpoint/serde/types.py @@ -23,27 +23,20 @@ class ChannelProtocol(Protocol[Value, Update, C]): # Mirrors langgraph.channels.base.BaseChannel @property - def ValueType(self) -> Any: - ... + def ValueType(self) -> Any: ... @property - def UpdateType(self) -> Any: - ... + def UpdateType(self) -> Any: ... - def checkpoint(self) -> Optional[C]: - ... + def checkpoint(self) -> Optional[C]: ... - def from_checkpoint(self, checkpoint: Optional[C]) -> Self: - ... + def from_checkpoint(self, checkpoint: Optional[C]) -> Self: ... - def update(self, values: Sequence[Update]) -> bool: - ... + def update(self, values: Sequence[Update]) -> bool: ... - def get(self) -> Value: - ... + def get(self) -> Value: ... - def consume(self) -> bool: - ... + def consume(self) -> bool: ... @runtime_checkable @@ -52,11 +45,8 @@ class SendProtocol(Protocol): node: str arg: Any - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... - def __eq__(self, value: object) -> bool: - ... + def __eq__(self, value: object) -> bool: ... diff --git a/libs/checkpoint/tests/test_store.py b/libs/checkpoint/tests/test_store.py index d9fbc5084..9d06281d0 100644 --- a/libs/checkpoint/tests/test_store.py +++ b/libs/checkpoint/tests/test_store.py @@ -314,7 +314,7 @@ async def test_cannot_put_empty_namespace() -> None: assert store.get(("langgraph", "foo"), "bar") is None class MockAsyncBatchedStore(AsyncBatchedBaseStore): - def __init__(self): + def __init__(self) -> None: super().__init__() self._store = InMemoryStore() @@ -340,13 +340,17 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: await async_store.aput(("langgraph", "foo"), "bar", doc) await async_store.aput(("foo", "langgraph", "foo"), "bar", doc) - assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")).value == doc + val = await async_store.aget(("foo", "langgraph", "foo"), "bar") + assert val is not None + assert val.value == doc assert (await async_store.asearch(("foo", "langgraph", "foo")))[0].value == doc await async_store.adelete(("foo", "langgraph", "foo"), "bar") assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")) is None await async_store.abatch([PutOp(("valid", "namespace"), "key", doc)]) - assert (await async_store.aget(("valid", "namespace"), "key")).value == doc + val = await async_store.aget(("valid", "namespace"), "key") + assert val is not None + assert val.value == doc assert (await async_store.asearch(("valid", "namespace")))[0].value == doc await async_store.adelete(("valid", "namespace"), "key") assert (await async_store.aget(("valid", "namespace"), "key")) is None