Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 17, 2024
1 parent 66465ef commit 66f97a9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 43 deletions.
24 changes: 12 additions & 12 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
from functools import partial
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type

from langchain_core.runnables import RunnableConfig

Expand Down Expand Up @@ -71,16 +71,16 @@ def __init__(
self,
*,
serde: Optional[SerializerProtocol] = None,
factory=defaultdict,
factory: Type[defaultdict] = defaultdict,
) -> None:
super().__init__(serde=serde)
self.storage = factory(lambda: defaultdict(dict))
self.writes = factory(dict)
self.stack = ExitStack()
try:
self.stack.enter_context(self.storage)
self.stack.enter_context(self.writes)
except TypeError:
self.stack.enter_context(self.storage) # type: ignore[arg-type]
self.stack.enter_context(self.writes) # type: ignore[arg-type]
except (TypeError, AttributeError):
pass

def __enter__(self) -> "MemorySaver":
Expand Down Expand Up @@ -506,14 +506,14 @@ class PersistentDict(defaultdict):
"""

def __init__(self, *args, filename: str, **kwds):
def __init__(self, *args: Any, filename: str, **kwds: Any) -> None:
self.flag = "c" # r=readonly, c=create, or n=new
self.mode = None # None or an octal triple like 0644
self.format = "pickle" # 'csv', 'json', or 'pickle'
self.filename = filename
super().__init__(*args, **kwds)

def sync(self):
def sync(self) -> None:
"Write dict to disk"
if self.flag == "r":
return
Expand All @@ -531,23 +531,23 @@ def sync(self):
if self.mode is not None:
os.chmod(self.filename, self.mode)

def close(self):
def close(self) -> None:
self.sync()
self.clear()

def __enter__(self):
def __enter__(self) -> "PersistentDict":
return self

def __exit__(self, *exc_info):
def __exit__(self, *exc_info: Any) -> None:
self.close()

def dump(self, fileobj):
def dump(self, fileobj: Any) -> None:
if self.format == "pickle":
pickle.dump(dict(self), fileobj, 2)
else:
raise NotImplementedError("Unknown format: " + repr(self.format))

def load(self):
def load(self) -> None:
# try formats from most restrictive to least restrictive
if self.flag == "n":
return
Expand Down
12 changes: 4 additions & 8 deletions libs/checkpoint/langgraph/checkpoint/serde/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,13 @@ class SerializerProtocol(Protocol):
Valid implementations include the `pickle`, `json` and `orjson` modules.
"""

def dumps(self, obj: Any) -> bytes:
...
def dumps(self, obj: Any) -> bytes: ...

def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
...
def dumps_typed(self, obj: Any) -> tuple[str, bytes]: ...

def loads(self, data: bytes) -> Any:
...
def loads(self, data: bytes) -> Any: ...

def loads_typed(self, data: tuple[str, bytes]) -> Any:
...
def loads_typed(self, data: tuple[str, bytes]) -> Any: ...


class SerializerCompat(SerializerProtocol):
Expand Down
30 changes: 10 additions & 20 deletions libs/checkpoint/langgraph/checkpoint/serde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,20 @@
class ChannelProtocol(Protocol[Value, Update, C]):
# Mirrors langgraph.channels.base.BaseChannel
@property
def ValueType(self) -> Any:
...
def ValueType(self) -> Any: ...

@property
def UpdateType(self) -> Any:
...
def UpdateType(self) -> Any: ...

def checkpoint(self) -> Optional[C]:
...
def checkpoint(self) -> Optional[C]: ...

def from_checkpoint(self, checkpoint: Optional[C]) -> Self:
...
def from_checkpoint(self, checkpoint: Optional[C]) -> Self: ...

def update(self, values: Sequence[Update]) -> bool:
...
def update(self, values: Sequence[Update]) -> bool: ...

def get(self) -> Value:
...
def get(self) -> Value: ...

def consume(self) -> bool:
...
def consume(self) -> bool: ...


@runtime_checkable
Expand All @@ -52,11 +45,8 @@ class SendProtocol(Protocol):
node: str
arg: Any

def __hash__(self) -> int:
...
def __hash__(self) -> int: ...

def __repr__(self) -> str:
...
def __repr__(self) -> str: ...

def __eq__(self, value: object) -> bool:
...
def __eq__(self, value: object) -> bool: ...
10 changes: 7 additions & 3 deletions libs/checkpoint/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ async def test_cannot_put_empty_namespace() -> None:
assert store.get(("langgraph", "foo"), "bar") is None

class MockAsyncBatchedStore(AsyncBatchedBaseStore):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._store = InMemoryStore()

Expand All @@ -340,13 +340,17 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:
await async_store.aput(("langgraph", "foo"), "bar", doc)

await async_store.aput(("foo", "langgraph", "foo"), "bar", doc)
assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")).value == doc
val = await async_store.aget(("foo", "langgraph", "foo"), "bar")
assert val is not None
assert val.value == doc
assert (await async_store.asearch(("foo", "langgraph", "foo")))[0].value == doc
await async_store.adelete(("foo", "langgraph", "foo"), "bar")
assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")) is None

await async_store.abatch([PutOp(("valid", "namespace"), "key", doc)])
assert (await async_store.aget(("valid", "namespace"), "key")).value == doc
val = await async_store.aget(("valid", "namespace"), "key")
assert val is not None
assert val.value == doc
assert (await async_store.asearch(("valid", "namespace")))[0].value == doc
await async_store.adelete(("valid", "namespace"), "key")
assert (await async_store.aget(("valid", "namespace"), "key")) is None

0 comments on commit 66f97a9

Please sign in to comment.