Skip to content

Commit

Permalink
ci: Enable mypy checks for checkpoint-sqlite lib
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 19, 2024
1 parent b8a8651 commit 6c0339c
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 53 deletions.
3 changes: 2 additions & 1 deletion libs/checkpoint-sqlite/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ lint lint_diff lint_package lint_tests:
poetry run ruff check .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)


class SqliteSaver(BaseCheckpointSaver):
class SqliteSaver(BaseCheckpointSaver[str]):
"""A checkpoint saver that stores checkpoints in a SQLite database.
Note:
Expand Down Expand Up @@ -487,6 +487,7 @@ async def aput(
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
"""Save a checkpoint to the database asynchronously.
Expand Down
5 changes: 3 additions & 2 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Expand All @@ -30,10 +31,10 @@
from langgraph.checkpoint.serde.types import ChannelProtocol
from langgraph.checkpoint.sqlite.utils import search_where

T = TypeVar("T", bound=callable)
T = TypeVar("T", bound=Callable)


class AsyncSqliteSaver(BaseCheckpointSaver):
class AsyncSqliteSaver(BaseCheckpointSaver[str]):
"""An asynchronous checkpoint saver that stores checkpoints in a SQLite database.
This class provides an asynchronous interface for saving and retrieving checkpoints
Expand Down
56 changes: 28 additions & 28 deletions libs/checkpoint-sqlite/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions libs/checkpoint-sqlite/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,13 @@ now = true
delay = 0.1
runner_args = ["--ff", "-v", "--tb", "short"]
patterns = ["*.py"]

[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"
14 changes: 8 additions & 6 deletions libs/checkpoint-sqlite/tests/test_aiosqlite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import pytest
from langchain_core.runnables import RunnableConfig

Expand All @@ -12,7 +14,7 @@

class TestAsyncSqliteSaver:
@pytest.fixture(autouse=True)
def setup(self):
def setup(self) -> None:
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
Expand Down Expand Up @@ -55,20 +57,20 @@ def setup(self):
}
self.metadata_3: CheckpointMetadata = {}

async def test_asearch(self):
async def test_asearch(self) -> None:
async with AsyncSqliteSaver.from_conn_string(":memory:") as saver:
await saver.aput(self.config_1, self.chkpnt_1, self.metadata_1, {})
await saver.aput(self.config_2, self.chkpnt_2, self.metadata_2, {})
await saver.aput(self.config_3, self.chkpnt_3, self.metadata_3, {})

# call method / assertions
query_1: CheckpointMetadata = {"source": "input"} # search by 1 key
query_2: CheckpointMetadata = {
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: CheckpointMetadata = {} # search by no keys, return all checkpoints
query_4: CheckpointMetadata = {"source": "update", "step": 1} # no match
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match

search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
assert len(search_results_1) == 1
Expand Down
34 changes: 19 additions & 15 deletions libs/checkpoint-sqlite/tests/test_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, cast

import pytest
from langchain_core.runnables import RunnableConfig

Expand All @@ -13,7 +15,7 @@

class TestSqliteSaver:
@pytest.fixture(autouse=True)
def setup(self):
def setup(self) -> None:
# objects for test setup
self.config_1: RunnableConfig = {
"configurable": {
Expand Down Expand Up @@ -56,7 +58,7 @@ def setup(self):
}
self.metadata_3: CheckpointMetadata = {}

def test_search(self):
def test_search(self) -> None:
with SqliteSaver.from_conn_string(":memory:") as saver:
# set up test
# save checkpoints
Expand All @@ -65,13 +67,13 @@ def test_search(self):
saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})

# call method / assertions
query_1: CheckpointMetadata = {"source": "input"} # search by 1 key
query_2: CheckpointMetadata = {
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: CheckpointMetadata = {} # search by no keys, return all checkpoints
query_4: CheckpointMetadata = {"source": "update", "step": 1} # no match
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match

search_results_1 = list(saver.list(None, filter=query_1))
assert len(search_results_1) == 1
Expand Down Expand Up @@ -99,16 +101,18 @@ def test_search(self):

# TODO: test before and limit params

def test_search_where(self):
def test_search_where(self) -> None:
# call method / assertions
expected_predicate_1 = "WHERE json_extract(CAST(metadata AS TEXT), '$.source') = ? AND json_extract(CAST(metadata AS TEXT), '$.step') = ? AND json_extract(CAST(metadata AS TEXT), '$.writes') = ? AND json_extract(CAST(metadata AS TEXT), '$.score') = ? AND checkpoint_id < ?"
expected_param_values_1 = ["input", 2, "{}", 1, "1"]
assert search_where(None, self.metadata_1, self.config_1) == (
assert search_where(
None, cast(dict[str, Any], self.metadata_1), self.config_1
) == (
expected_predicate_1,
expected_param_values_1,
)

def test_metadata_predicate(self):
def test_metadata_predicate(self) -> None:
# call method / assertions
expected_predicate_1 = [
"json_extract(CAST(metadata AS TEXT), '$.source') = ?",
Expand All @@ -122,26 +126,26 @@ def test_metadata_predicate(self):
"json_extract(CAST(metadata AS TEXT), '$.writes') = ?",
"json_extract(CAST(metadata AS TEXT), '$.score') IS ?",
]
expected_predicate_3 = []
expected_predicate_3: list[str] = []

expected_param_values_1 = ["input", 2, "{}", 1]
expected_param_values_2 = ["loop", 1, '{"foo":"bar"}', None]
expected_param_values_3 = []
expected_param_values_3: list[Any] = []

assert _metadata_predicate(self.metadata_1) == (
assert _metadata_predicate(cast(dict[str, Any], self.metadata_1)) == (
expected_predicate_1,
expected_param_values_1,
)
assert _metadata_predicate(self.metadata_2) == (
assert _metadata_predicate(cast(dict[str, Any], self.metadata_2)) == (
expected_predicate_2,
expected_param_values_2,
)
assert _metadata_predicate(self.metadata_3) == (
assert _metadata_predicate(cast(dict[str, Any], self.metadata_3)) == (
expected_predicate_3,
expected_param_values_3,
)

async def test_informative_async_errors(self):
async def test_informative_async_errors(self) -> None:
with SqliteSaver.from_conn_string(":memory:") as saver:
# call method / assertions
with pytest.raises(NotImplementedError, match="AsyncSqliteSaver"):
Expand Down

0 comments on commit 6c0339c

Please sign in to comment.