diff --git a/libs/checkpoint-duckdb/langgraph/store/duckdb/__init__.py b/libs/checkpoint-duckdb/langgraph/store/duckdb/__init__.py new file mode 100644 index 000000000..058c64bf7 --- /dev/null +++ b/libs/checkpoint-duckdb/langgraph/store/duckdb/__init__.py @@ -0,0 +1,4 @@ +from langgraph.store.duckdb.aio import AsyncDuckDBStore +from langgraph.store.duckdb.base import DuckDBStore + +__all__ = ["AsyncDuckDBStore", "DuckDBStore"] diff --git a/libs/checkpoint-duckdb/langgraph/store/duckdb/aio.py b/libs/checkpoint-duckdb/langgraph/store/duckdb/aio.py new file mode 100644 index 000000000..d6fd7dd89 --- /dev/null +++ b/libs/checkpoint-duckdb/langgraph/store/duckdb/aio.py @@ -0,0 +1,195 @@ +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import ( + AsyncIterator, + Iterable, + Sequence, + cast, +) + +import duckdb +from langgraph.store.base import GetOp, ListNamespacesOp, Op, PutOp, Result, SearchOp +from langgraph.store.base.batch import AsyncBatchedBaseStore +from langgraph.store.duckdb.base import ( + BaseDuckDBStore, + _convert_ns, + _group_ops, + _row_to_item, +) + +logger = logging.getLogger(__name__) + + +class AsyncDuckDBStore(AsyncBatchedBaseStore, BaseDuckDBStore): + def __init__( + self, + conn: duckdb.DuckDBPyConnection, + ) -> None: + super().__init__() + self.conn = conn + self.loop = asyncio.get_running_loop() + + async def abatch(self, ops: Iterable[Op]) -> list[Result]: + grouped_ops, num_ops = _group_ops(ops) + results: list[Result] = [None] * num_ops + + tasks = [] + + if GetOp in grouped_ops: + tasks.append( + self._batch_get_ops( + cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results + ) + ) + + if PutOp in grouped_ops: + tasks.append( + self._batch_put_ops( + cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]) + ) + ) + + if SearchOp in grouped_ops: + tasks.append( + self._batch_search_ops( + cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), + results, + ) + ) + + if ListNamespacesOp in grouped_ops: + tasks.append( + self._batch_list_namespaces_ops( + cast( + Sequence[tuple[int, ListNamespacesOp]], + grouped_ops[ListNamespacesOp], + ), + results, + ) + ) + + await asyncio.gather(*tasks) + + return results + + def batch(self, ops: Iterable[Op]) -> list[Result]: + return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() + + async def _batch_get_ops( + self, + get_ops: Sequence[tuple[int, GetOp]], + results: list[Result], + ) -> None: + cursors = [] + for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): + cur = self.conn.cursor() + await asyncio.to_thread(cur.execute, query, params) + cursors.append((cur, namespace, items)) + + for cur, namespace, items in cursors: + rows = await asyncio.to_thread(cur.fetchall) + key_to_row = {row[1]: row for row in rows} + for idx, key in items: + row = key_to_row.get(key) + if row: + results[idx] = _row_to_item(namespace, row) + else: + results[idx] = None + + async def _batch_put_ops( + self, + put_ops: Sequence[tuple[int, PutOp]], + ) -> None: + queries = self._get_batch_PUT_queries(put_ops) + for query, params in queries: + cur = self.conn.cursor() + await asyncio.to_thread(cur.execute, query, params) + + async def _batch_search_ops( + self, + search_ops: Sequence[tuple[int, SearchOp]], + results: list[Result], + ) -> None: + queries = self._get_batch_search_queries(search_ops) + cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] + + for (query, params), (idx, _) in zip(queries, search_ops): + cur = self.conn.cursor() + 
await asyncio.to_thread(cur.execute, query, params) + cursors.append((cur, idx)) + + for cur, idx in cursors: + rows = await asyncio.to_thread(cur.fetchall) + items = [_row_to_item(_convert_ns(row[0]), row) for row in rows] + results[idx] = items + + async def _batch_list_namespaces_ops( + self, + list_ops: Sequence[tuple[int, ListNamespacesOp]], + results: list[Result], + ) -> None: + queries = self._get_batch_list_namespaces_queries(list_ops) + cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] + for (query, params), (idx, _) in zip(queries, list_ops): + cur = self.conn.cursor() + await asyncio.to_thread(cur.execute, query, params) + cursors.append((cur, idx)) + + for cur, idx in cursors: + rows = cast(list[tuple], await asyncio.to_thread(cur.fetchall)) + namespaces = [_convert_ns(row[0]) for row in rows] + results[idx] = namespaces + + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, + conn_string: str, + ) -> AsyncIterator["AsyncDuckDBStore"]: + """Create a new AsyncDuckDBStore instance from a connection string. + + Args: + conn_string (str): The DuckDB connection info string. + + Returns: + AsyncDuckDBStore: A new AsyncDuckDBStore instance. + """ + with duckdb.connect(conn_string) as conn: + yield AsyncDuckDBStore(conn) + + async def setup(self) -> None: + """Set up the store database asynchronously. + + This method creates the necessary tables in the DuckDB database if they don't + already exist and runs database migrations. It is called automatically when needed and should not be called + directly by the user. + """ + cur = self.conn.cursor() + try: + await asyncio.to_thread( + cur.execute, "SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1" + ) + row = await asyncio.to_thread(cur.fetchone) + if row is None: + version = -1 + else: + version = row[0] + except duckdb.CatalogException: + version = -1 + # Create store_migrations table if it doesn't exist + await asyncio.to_thread( + cur.execute, + """ + CREATE TABLE IF NOT EXISTS store_migrations ( + v INTEGER PRIMARY KEY + ) + """, + ) + for v, migration in enumerate( + self.MIGRATIONS[version + 1 :], start=version + 1 + ): + await asyncio.to_thread(cur.execute, migration) + await asyncio.to_thread( + cur.execute, "INSERT INTO store_migrations (v) VALUES (?)", (v,) + ) diff --git a/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py b/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py new file mode 100644 index 000000000..e0fb57067 --- /dev/null +++ b/libs/checkpoint-duckdb/langgraph/store/duckdb/base.py @@ -0,0 +1,391 @@ +import asyncio +import json +import logging +from collections import defaultdict +from contextlib import contextmanager +from typing import ( + Any, + Generic, + Iterable, + Iterator, + Sequence, + TypeVar, + Union, + cast, +) + +import duckdb +from langgraph.store.base import ( + BaseStore, + GetOp, + Item, + ListNamespacesOp, + Op, + PutOp, + Result, + SearchOp, +) + +logger = logging.getLogger(__name__) + + +MIGRATIONS = [ + """ +CREATE TABLE IF NOT EXISTS store ( + prefix TEXT NOT NULL, + key TEXT NOT NULL, + value JSON NOT NULL, + created_at TIMESTAMP DEFAULT now(), + updated_at TIMESTAMP DEFAULT now(), + PRIMARY KEY (prefix, key) +); +""", + """ +CREATE INDEX IF NOT EXISTS store_prefix_idx ON store (prefix); +""", +] + +C = TypeVar("C", bound=duckdb.DuckDBPyConnection) + + +class BaseDuckDBStore(Generic[C]): + MIGRATIONS = MIGRATIONS + conn: C + + def _get_batch_GET_ops_queries( + self, + get_ops: Sequence[tuple[int, GetOp]], + ) -> list[tuple[str, tuple, tuple[str, 
...], list]]: + namespace_groups = defaultdict(list) + for idx, op in get_ops: + namespace_groups[op.namespace].append((idx, op.key)) + results = [] + for namespace, items in namespace_groups.items(): + _, keys = zip(*items) + keys_to_query = ",".join(["?"] * len(keys)) + query = f""" + SELECT prefix, key, value, created_at, updated_at + FROM store + WHERE prefix = ? AND key IN ({keys_to_query}) + """ + params = (_namespace_to_text(namespace), *keys) + results.append((query, params, namespace, items)) + return results + + def _get_batch_PUT_queries( + self, + put_ops: Sequence[tuple[int, PutOp]], + ) -> list[tuple[str, Sequence]]: + inserts: list[PutOp] = [] + deletes: list[PutOp] = [] + for _, op in put_ops: + if op.value is None: + deletes.append(op) + else: + inserts.append(op) + + queries: list[tuple[str, Sequence]] = [] + + if deletes: + namespace_groups: dict[tuple[str, ...], list[str]] = defaultdict(list) + for op in deletes: + namespace_groups[op.namespace].append(op.key) + for namespace, keys in namespace_groups.items(): + placeholders = ",".join(["?"] * len(keys)) + query = ( + f"DELETE FROM store WHERE prefix = ? AND key IN ({placeholders})" + ) + params = (_namespace_to_text(namespace), *keys) + queries.append((query, params)) + if inserts: + values = [] + insertion_params = [] + for op in inserts: + values.append("(?, ?, ?, now(), now())") + insertion_params.extend( + [ + _namespace_to_text(op.namespace), + op.key, + json.dumps(op.value), + ] + ) + values_str = ",".join(values) + query = f""" + INSERT INTO store (prefix, key, value, created_at, updated_at) + VALUES {values_str} + ON CONFLICT (prefix, key) DO UPDATE + SET value = EXCLUDED.value, updated_at = now() + """ + queries.append((query, insertion_params)) + + return queries + + def _get_batch_search_queries( + self, + search_ops: Sequence[tuple[int, SearchOp]], + ) -> list[tuple[str, Sequence]]: + queries: list[tuple[str, Sequence]] = [] + for _, op in search_ops: + query = """ + SELECT prefix, key, value, created_at, updated_at + FROM store + WHERE prefix LIKE ? + """ + params: list = [f"{_namespace_to_text(op.namespace_prefix)}%"] + + if op.filter: + filter_conditions = [] + for key, value in op.filter.items(): + filter_conditions.append(f"json_extract(value, '$.{key}') = ?") + params.append(json.dumps(value)) + query += " AND " + " AND ".join(filter_conditions) + + query += " ORDER BY updated_at DESC LIMIT ? OFFSET ?" + params.extend([op.limit, op.offset]) + + queries.append((query, params)) + return queries + + def _get_batch_list_namespaces_queries( + self, + list_ops: Sequence[tuple[int, ListNamespacesOp]], + ) -> list[tuple[str, Sequence]]: + queries: list[tuple[str, Sequence]] = [] + for _, op in list_ops: + query = """ + WITH split_prefix AS ( + SELECT + prefix, + string_split(prefix, '.') AS parts + FROM store + ) + SELECT DISTINCT ON (truncated_prefix) + CASE + WHEN ? 
IS NOT NULL THEN + array_to_string(array_slice(parts, 1, ?), '.') + ELSE prefix + END AS truncated_prefix, + prefix + FROM split_prefix + """ + params: list[Any] = [op.max_depth, op.max_depth] + + conditions = [] + if op.match_conditions: + for condition in op.match_conditions: + if condition.match_type == "prefix": + conditions.append("prefix LIKE ?") + params.append( + f"{_namespace_to_text(condition.path, handle_wildcards=True)}%" + ) + elif condition.match_type == "suffix": + conditions.append("prefix LIKE ?") + params.append( + f"%{_namespace_to_text(condition.path, handle_wildcards=True)}" + ) + else: + logger.warning( + f"Unknown match_type in list_namespaces: {condition.match_type}" + ) + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += " ORDER BY prefix LIMIT ? OFFSET ?" + params.extend([op.limit, op.offset]) + queries.append((query, params)) + + return queries + + +class DuckDBStore(BaseStore, BaseDuckDBStore[duckdb.DuckDBPyConnection]): + def __init__( + self, + conn: duckdb.DuckDBPyConnection, + ) -> None: + super().__init__() + self.conn = conn + + def batch(self, ops: Iterable[Op]) -> list[Result]: + grouped_ops, num_ops = _group_ops(ops) + results: list[Result] = [None] * num_ops + + if GetOp in grouped_ops: + self._batch_get_ops( + cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results + ) + + if PutOp in grouped_ops: + self._batch_put_ops(cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp])) + + if SearchOp in grouped_ops: + self._batch_search_ops( + cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), + results, + ) + + if ListNamespacesOp in grouped_ops: + self._batch_list_namespaces_ops( + cast( + Sequence[tuple[int, ListNamespacesOp]], + grouped_ops[ListNamespacesOp], + ), + results, + ) + + return results + + async def abatch(self, ops: Iterable[Op]) -> list[Result]: + return await asyncio.get_running_loop().run_in_executor(None, self.batch, ops) + + def _batch_get_ops( + self, + get_ops: Sequence[tuple[int, GetOp]], + results: list[Result], + ) -> None: + cursors = [] + for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): + cur = self.conn.cursor() + cur.execute(query, params) + cursors.append((cur, namespace, items)) + + for cur, namespace, items in cursors: + rows = cur.fetchall() + key_to_row = {row[1]: row for row in rows} + for idx, key in items: + row = key_to_row.get(key) + if row: + results[idx] = _row_to_item(namespace, row) + else: + results[idx] = None + + def _batch_put_ops( + self, + put_ops: Sequence[tuple[int, PutOp]], + ) -> None: + queries = self._get_batch_PUT_queries(put_ops) + for query, params in queries: + cur = self.conn.cursor() + cur.execute(query, params) + + def _batch_search_ops( + self, + search_ops: Sequence[tuple[int, SearchOp]], + results: list[Result], + ) -> None: + queries = self._get_batch_search_queries(search_ops) + cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] + + for (query, params), (idx, _) in zip(queries, search_ops): + cur = self.conn.cursor() + cur.execute(query, params) + cursors.append((cur, idx)) + + for cur, idx in cursors: + rows = cur.fetchall() + items = [_row_to_item(_convert_ns(row[0]), row) for row in rows] + results[idx] = items + + def _batch_list_namespaces_ops( + self, + list_ops: Sequence[tuple[int, ListNamespacesOp]], + results: list[Result], + ) -> None: + queries = self._get_batch_list_namespaces_queries(list_ops) + cursors: list[tuple[duckdb.DuckDBPyConnection, int]] = [] + for (query, params), (idx, _) in zip(queries, 
list_ops):
+            cur = self.conn.cursor()
+            cur.execute(query, params)
+            cursors.append((cur, idx))
+
+        for cur, idx in cursors:
+            rows = cast(list[tuple], cur.fetchall())
+            namespaces = [_convert_ns(row[0]) for row in rows]
+            results[idx] = namespaces
+
+    @classmethod
+    @contextmanager
+    def from_conn_string(
+        cls,
+        conn_string: str,
+    ) -> Iterator["DuckDBStore"]:
+        """Create a new DuckDBStore instance from a connection string.
+
+        Args:
+            conn_string (str): The DuckDB connection info string.
+
+        Returns:
+            DuckDBStore: A new DuckDBStore instance.
+        """
+        with duckdb.connect(conn_string) as conn:
+            yield cls(conn=conn)
+
+    def setup(self) -> None:
+        """Set up the store database.
+
+        This method creates the necessary tables in the DuckDB database if they
+        don't already exist and runs database migrations. It is called
+        automatically when needed and should not be called directly by the user.
+        """
+        with self.conn.cursor() as cur:
+            try:
+                cur.execute("SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1")
+                row = cur.fetchone()
+                if row is None:
+                    version = -1
+                else:
+                    version = row[0]
+            except duckdb.CatalogException:
+                version = -1
+                # Create store_migrations table if it doesn't exist
+                cur.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS store_migrations (
+                        v INTEGER PRIMARY KEY
+                    )
+                    """
+                )
+            for v, migration in enumerate(
+                self.MIGRATIONS[version + 1 :], start=version + 1
+            ):
+                cur.execute(migration)
+                cur.execute("INSERT INTO store_migrations (v) VALUES (?)", (v,))
+
+
+def _namespace_to_text(
+    namespace: tuple[str, ...], handle_wildcards: bool = False
+) -> str:
+    """Convert namespace tuple to text string."""
+    if handle_wildcards:
+        namespace = tuple("%" if val == "*" else val for val in namespace)
+    return ".".join(namespace)
+
+
+def _row_to_item(
+    namespace: tuple[str, ...],
+    row: tuple,
+) -> Item:
+    """Convert a row from the database into an Item."""
+    _, key, val, created_at, updated_at = row
+    return Item(
+        value=val if isinstance(val, dict) else json.loads(val),
+        key=key,
+        namespace=namespace,
+        created_at=created_at,
+        updated_at=updated_at,
+    )
+
+
+def _group_ops(ops: Iterable[Op]) -> tuple[dict[type, list[tuple[int, Op]]], int]:
+    grouped_ops: dict[type, list[tuple[int, Op]]] = defaultdict(list)
+    tot = 0
+    for idx, op in enumerate(ops):
+        grouped_ops[type(op)].append((idx, op))
+        tot += 1
+    return grouped_ops, tot
+
+
+def _convert_ns(namespace: Union[str, list]) -> tuple[str, ...]:
+    if isinstance(namespace, list):
+        return tuple(namespace)
+    return tuple(namespace.split("."))
diff --git a/libs/checkpoint-duckdb/langgraph/store/py.typed b/libs/checkpoint-duckdb/langgraph/store/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/libs/checkpoint-duckdb/pyproject.toml b/libs/checkpoint-duckdb/pyproject.toml
index 89453d914..5dde4a33b 100644
--- a/libs/checkpoint-duckdb/pyproject.toml
+++ b/libs/checkpoint-duckdb/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langgraph-checkpoint-duckdb"
-version = "1.0.0"
+version = "2.0.0"
 description = "Library with a DuckDB implementation of LangGraph checkpoint saver."
authors = [] license = "MIT" diff --git a/libs/checkpoint-duckdb/tests/test_async_store.py b/libs/checkpoint-duckdb/tests/test_async_store.py new file mode 100644 index 000000000..140807c08 --- /dev/null +++ b/libs/checkpoint-duckdb/tests/test_async_store.py @@ -0,0 +1,517 @@ +# type: ignore +import uuid +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp +from langgraph.store.duckdb import AsyncDuckDBStore + + +class MockCursor: + def __init__(self, fetch_result: Any) -> None: + self.fetch_result = fetch_result + self.execute = MagicMock() + self.fetchall = MagicMock(return_value=self.fetch_result) + + +class MockConnection: + def __init__(self) -> None: + self.cursor = MagicMock() + + +@pytest.fixture +def mock_connection() -> MockConnection: + return MockConnection() + + +@pytest.fixture +async def store(mock_connection: MockConnection) -> AsyncDuckDBStore: + duck_db_store = AsyncDuckDBStore(mock_connection) + await duck_db_store.setup() + return duck_db_store + + +async def test_abatch_order(store: AsyncDuckDBStore) -> None: + mock_connection = store.conn + mock_get_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ( + "test.bar", + "key2", + '{"data": "value2"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_search_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_list_namespaces_cursor = MockCursor( + [ + ("test",), + ] + ) + + failures = [] + + def cursor_side_effect() -> Any: + cursor = MagicMock() + + def execute_side_effect(query: str, *params: Any) -> None: + # My super sophisticated database. + if "WHERE prefix = ? 
AND key" in query: + cursor.fetchall = mock_get_cursor.fetchall + elif "SELECT prefix, key, value" in query: + cursor.fetchall = mock_search_cursor.fetchall + elif "SELECT DISTINCT ON (truncated_prefix)" in query: + cursor.fetchall = mock_list_namespaces_cursor.fetchall + elif "INSERT INTO " in query: + pass + else: + e = ValueError(f"Unmatched query: {query}") + failures.append(e) + raise e + + cursor.execute = MagicMock(side_effect=execute_side_effect) + return cursor + + mock_connection.cursor.side_effect = cursor_side_effect # type: ignore + + ops = [ + GetOp(namespace=("test",), key="key1"), + PutOp(namespace=("test",), key="key2", value={"data": "value2"}), + SearchOp( + namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 + ), + ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0), + GetOp(namespace=("test",), key="key3"), + ] + results = await store.abatch(ops) + assert not failures + assert len(results) == 5 + assert isinstance(results[0], Item) + assert isinstance(results[0].value, dict) + assert results[0].value == {"data": "value1"} + assert results[0].key == "key1" + assert results[1] is None + assert isinstance(results[2], list) + assert len(results[2]) == 1 + assert isinstance(results[3], list) + assert results[3] == [("test",)] + assert results[4] is None + + ops_reordered = [ + SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), + GetOp(namespace=("test",), key="key2"), + ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0), + PutOp(namespace=("test",), key="key3", value={"data": "value3"}), + GetOp(namespace=("test",), key="key1"), + ] + + results_reordered = await store.abatch(ops_reordered) + assert not failures + assert len(results_reordered) == 5 + assert isinstance(results_reordered[0], list) + assert len(results_reordered[0]) == 1 + assert isinstance(results_reordered[1], Item) + assert results_reordered[1].value == {"data": "value2"} + assert results_reordered[1].key == "key2" + assert isinstance(results_reordered[2], list) + assert results_reordered[2] == [("test",)] + assert results_reordered[3] is None + assert isinstance(results_reordered[4], Item) + assert results_reordered[4].value == {"data": "value1"} + assert results_reordered[4].key == "key1" + + +async def test_batch_get_ops(store: AsyncDuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ( + "test.bar", + "key2", + '{"data": "value2"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_connection.cursor.return_value = mock_cursor + + ops = [ + GetOp(namespace=("test",), key="key1"), + GetOp(namespace=("test",), key="key2"), + GetOp(namespace=("test",), key="key3"), + ] + + results = await store.abatch(ops) + + assert len(results) == 3 + assert results[0] is not None + assert results[1] is not None + assert results[2] is None + assert results[0].key == "key1" + assert results[1].key == "key2" + + +async def test_batch_put_ops(store: AsyncDuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor([]) + mock_connection.cursor.return_value = mock_cursor + + ops = [ + PutOp(namespace=("test",), key="key1", value={"data": "value1"}), + PutOp(namespace=("test",), key="key2", value={"data": "value2"}), + PutOp(namespace=("test",), key="key3", value=None), + ] + + results = await store.abatch(ops) + + assert len(results) == 3 + assert all(result is None for result in results) + 
assert mock_cursor.execute.call_count == 2 + + +async def test_batch_search_ops(store: AsyncDuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ( + "test.bar", + "key2", + '{"data": "value2"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_connection.cursor.return_value = mock_cursor + + ops = [ + SearchOp( + namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 + ), + SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), + ] + + results = await store.abatch(ops) + + assert len(results) == 2 + assert len(results[0]) == 2 + assert len(results[1]) == 2 + + +async def test_batch_list_namespaces_ops(store: AsyncDuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor([("test.namespace1",), ("test.namespace2",)]) + mock_connection.cursor.return_value = mock_cursor + + ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)] + + results = await store.abatch(ops) + + assert len(results) == 1 + assert results[0] == [("test", "namespace1"), ("test", "namespace2")] + + +# The following use the actual DB connection + + +async def test_basic_store_ops() -> None: + async with AsyncDuckDBStore.from_conn_string(":memory:") as store: + await store.setup() + namespace = ("test", "documents") + item_id = "doc1" + item_value = {"title": "Test Document", "content": "Hello, World!"} + + await store.aput(namespace, item_id, item_value) + item = await store.aget(namespace, item_id) + + assert item + assert item.namespace == namespace + assert item.key == item_id + assert item.value == item_value + + updated_value = { + "title": "Updated Test Document", + "content": "Hello, LangGraph!", + } + await store.aput(namespace, item_id, updated_value) + updated_item = await store.aget(namespace, item_id) + + assert updated_item.value == updated_value + assert updated_item.updated_at > item.updated_at + different_namespace = ("test", "other_documents") + item_in_different_namespace = await store.aget(different_namespace, item_id) + assert item_in_different_namespace is None + + new_item_id = "doc2" + new_item_value = {"title": "Another Document", "content": "Greetings!"} + await store.aput(namespace, new_item_id, new_item_value) + + search_results = await store.asearch(["test"], limit=10) + items = search_results + assert len(items) == 2 + assert any(item.key == item_id for item in items) + assert any(item.key == new_item_id for item in items) + + namespaces = await store.alist_namespaces(prefix=["test"]) + assert ("test", "documents") in namespaces + + await store.adelete(namespace, item_id) + await store.adelete(namespace, new_item_id) + deleted_item = await store.aget(namespace, item_id) + assert deleted_item is None + + deleted_item = await store.aget(namespace, new_item_id) + assert deleted_item is None + + empty_search_results = await store.asearch(["test"], limit=10) + assert len(empty_search_results) == 0 + + +async def test_list_namespaces() -> None: + async with AsyncDuckDBStore.from_conn_string(":memory:") as store: + await store.setup() + test_pref = str(uuid.uuid4()) + test_namespaces = [ + (test_pref, "test", "documents", "public", test_pref), + (test_pref, "test", "documents", "private", test_pref), + (test_pref, "test", "images", "public", test_pref), + (test_pref, "test", "images", "private", test_pref), + (test_pref, "prod", "documents", "public", test_pref), + ( + test_pref, + 
"prod", + "documents", + "some", + "nesting", + "public", + test_pref, + ), + (test_pref, "prod", "documents", "private", test_pref), + ] + + for namespace in test_namespaces: + await store.aput(namespace, "dummy", {"content": "dummy"}) + + prefix_result = await store.alist_namespaces(prefix=[test_pref, "test"]) + assert len(prefix_result) == 4 + assert all([ns[1] == "test" for ns in prefix_result]) + + specific_prefix_result = await store.alist_namespaces( + prefix=[test_pref, "test", "documents"] + ) + assert len(specific_prefix_result) == 2 + assert all([ns[1:3] == ("test", "documents") for ns in specific_prefix_result]) + + suffix_result = await store.alist_namespaces(suffix=["public", test_pref]) + assert len(suffix_result) == 4 + assert all(ns[-2] == "public" for ns in suffix_result) + + prefix_suffix_result = await store.alist_namespaces( + prefix=[test_pref, "test"], suffix=["public", test_pref] + ) + assert len(prefix_suffix_result) == 2 + assert all( + ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result + ) + + wildcard_prefix_result = await store.alist_namespaces( + prefix=[test_pref, "*", "documents"] + ) + assert len(wildcard_prefix_result) == 5 + assert all(ns[2] == "documents" for ns in wildcard_prefix_result) + + wildcard_suffix_result = await store.alist_namespaces( + suffix=["*", "public", test_pref] + ) + assert len(wildcard_suffix_result) == 4 + assert all(ns[-2] == "public" for ns in wildcard_suffix_result) + wildcard_single = await store.alist_namespaces( + suffix=["some", "*", "public", test_pref] + ) + assert len(wildcard_single) == 1 + assert wildcard_single[0] == ( + test_pref, + "prod", + "documents", + "some", + "nesting", + "public", + test_pref, + ) + + max_depth_result = await store.alist_namespaces(max_depth=3) + assert all([len(ns) <= 3 for ns in max_depth_result]) + max_depth_result = await store.alist_namespaces( + max_depth=4, prefix=[test_pref, "*", "documents"] + ) + assert ( + len(set(tuple(res) for res in max_depth_result)) + == len(max_depth_result) + == 5 + ) + + limit_result = await store.alist_namespaces(prefix=[test_pref], limit=3) + assert len(limit_result) == 3 + + offset_result = await store.alist_namespaces(prefix=[test_pref], offset=3) + assert len(offset_result) == len(test_namespaces) - 3 + + empty_prefix_result = await store.alist_namespaces(prefix=[test_pref]) + assert len(empty_prefix_result) == len(test_namespaces) + assert set(tuple(ns) for ns in empty_prefix_result) == set( + tuple(ns) for ns in test_namespaces + ) + + for namespace in test_namespaces: + await store.adelete(namespace, "dummy") + + +async def test_search(): + async with AsyncDuckDBStore.from_conn_string(":memory:") as store: + await store.setup() + test_namespaces = [ + ("test_search", "documents", "user1"), + ("test_search", "documents", "user2"), + ("test_search", "reports", "department1"), + ("test_search", "reports", "department2"), + ] + test_items = [ + {"title": "Doc 1", "author": "John Doe", "tags": ["important"]}, + {"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]}, + {"title": "Report A", "author": "John Doe", "tags": ["final"]}, + {"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]}, + ] + empty = await store.asearch( + ( + "scoped", + "assistant_id", + "shared", + "6c5356f6-63ab-4158-868d-cd9fd14c736e", + ), + limit=10, + offset=0, + ) + assert len(empty) == 0 + + for namespace, item in zip(test_namespaces, test_items): + await store.aput(namespace, f"item_{namespace[-1]}", item) + + docs_result = await 
store.asearch(["test_search", "documents"]) + assert len(docs_result) == 2 + assert all([item.namespace[1] == "documents" for item in docs_result]), [ + item.namespace for item in docs_result + ] + + reports_result = await store.asearch(["test_search", "reports"]) + assert len(reports_result) == 2 + assert all(item.namespace[1] == "reports" for item in reports_result) + + limited_result = await store.asearch(["test_search"], limit=2) + assert len(limited_result) == 2 + offset_result = await store.asearch(["test_search"]) + assert len(offset_result) == 4 + + offset_result = await store.asearch(["test_search"], offset=2) + assert len(offset_result) == 2 + assert all(item not in limited_result for item in offset_result) + + john_doe_result = await store.asearch( + ["test_search"], filter={"author": "John Doe"} + ) + assert len(john_doe_result) == 2 + assert all(item.value["author"] == "John Doe" for item in john_doe_result) + + draft_result = await store.asearch(["test_search"], filter={"tags": ["draft"]}) + assert len(draft_result) == 2 + assert all("draft" in item.value["tags"] for item in draft_result) + + page1 = await store.asearch(["test_search"], limit=2, offset=0) + page2 = await store.asearch(["test_search"], limit=2, offset=2) + all_items = page1 + page2 + assert len(all_items) == 4 + assert len(set(item.key for item in all_items)) == 4 + empty = await store.asearch( + ( + "scoped", + "assistant_id", + "shared", + "again", + "maybe", + "some-long", + "6be5cb0e-2eb4-42e6-bb6b-fba3c269db25", + ), + limit=10, + offset=0, + ) + assert len(empty) == 0 + + # Test with a namespace beginning with a number (like a UUID) + uuid_namespace = (str(uuid.uuid4()), "documents") + uuid_item_id = "uuid_doc" + uuid_item_value = { + "title": "UUID Document", + "content": "This document has a UUID namespace.", + } + + # Insert the item with the UUID namespace + await store.aput(uuid_namespace, uuid_item_id, uuid_item_value) + + # Retrieve the item to verify it was stored correctly + retrieved_item = await store.aget(uuid_namespace, uuid_item_id) + assert retrieved_item is not None + assert retrieved_item.namespace == uuid_namespace + assert retrieved_item.key == uuid_item_id + assert retrieved_item.value == uuid_item_value + + # Search for the item using the UUID namespace + search_result = await store.asearch([uuid_namespace[0]]) + assert len(search_result) == 1 + assert search_result[0].key == uuid_item_id + assert search_result[0].value == uuid_item_value + + # Clean up: delete the item with the UUID namespace + await store.adelete(uuid_namespace, uuid_item_id) + + # Verify the item was deleted + deleted_item = await store.aget(uuid_namespace, uuid_item_id) + assert deleted_item is None + + for namespace in test_namespaces: + await store.adelete(namespace, f"item_{namespace[-1]}") diff --git a/libs/checkpoint-duckdb/tests/test_store.py b/libs/checkpoint-duckdb/tests/test_store.py new file mode 100644 index 000000000..47d2d573c --- /dev/null +++ b/libs/checkpoint-duckdb/tests/test_store.py @@ -0,0 +1,457 @@ +# type: ignore +import uuid +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp +from langgraph.store.duckdb import DuckDBStore + + +class MockCursor: + def __init__(self, fetch_result: Any) -> None: + self.fetch_result = fetch_result + self.execute = MagicMock() + self.fetchall = MagicMock(return_value=self.fetch_result) + + +class MockConnection: + def 
__init__(self) -> None: + self.cursor = MagicMock() + + +@pytest.fixture +def mock_connection() -> MockConnection: + return MockConnection() + + +@pytest.fixture +def store(mock_connection: MockConnection) -> DuckDBStore: + duck_db_store = DuckDBStore(mock_connection) + duck_db_store.setup() + return duck_db_store + + +def test_batch_order(store: DuckDBStore) -> None: + mock_connection = store.conn + mock_get_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ( + "test.bar", + "key2", + '{"data": "value2"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_search_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_list_namespaces_cursor = MockCursor( + [ + ("test",), + ] + ) + + failures = [] + + def cursor_side_effect() -> Any: + cursor = MagicMock() + + def execute_side_effect(query: str, *params: Any) -> None: + # My super sophisticated database. + if "WHERE prefix = ? AND key" in query: + cursor.fetchall = mock_get_cursor.fetchall + elif "SELECT prefix, key, value" in query: + cursor.fetchall = mock_search_cursor.fetchall + elif "SELECT DISTINCT ON (truncated_prefix)" in query: + cursor.fetchall = mock_list_namespaces_cursor.fetchall + elif "INSERT INTO " in query: + pass + else: + e = ValueError(f"Unmatched query: {query}") + failures.append(e) + raise e + + cursor.execute = MagicMock(side_effect=execute_side_effect) + return cursor + + mock_connection.cursor.side_effect = cursor_side_effect + + ops = [ + GetOp(namespace=("test",), key="key1"), + PutOp(namespace=("test",), key="key2", value={"data": "value2"}), + SearchOp( + namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 + ), + ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0), + GetOp(namespace=("test",), key="key3"), + ] + results = store.batch(ops) + assert not failures + assert len(results) == 5 + assert isinstance(results[0], Item) + assert isinstance(results[0].value, dict) + assert results[0].value == {"data": "value1"} + assert results[0].key == "key1" + assert results[1] is None + assert isinstance(results[2], list) + assert len(results[2]) == 1 + assert isinstance(results[3], list) + assert results[3] == [("test",)] + assert results[4] is None + + ops_reordered = [ + SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), + GetOp(namespace=("test",), key="key2"), + ListNamespacesOp(match_conditions=None, max_depth=None, limit=5, offset=0), + PutOp(namespace=("test",), key="key3", value={"data": "value3"}), + GetOp(namespace=("test",), key="key1"), + ] + + results_reordered = store.batch(ops_reordered) + assert not failures + assert len(results_reordered) == 5 + assert isinstance(results_reordered[0], list) + assert len(results_reordered[0]) == 1 + assert isinstance(results_reordered[1], Item) + assert results_reordered[1].value == {"data": "value2"} + assert results_reordered[1].key == "key2" + assert isinstance(results_reordered[2], list) + assert results_reordered[2] == [("test",)] + assert results_reordered[3] is None + assert isinstance(results_reordered[4], Item) + assert results_reordered[4].value == {"data": "value1"} + assert results_reordered[4].key == "key1" + + +def test_batch_get_ops(store: DuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ( + "test.bar", + "key2", 
+ '{"data": "value2"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_connection.cursor.return_value = mock_cursor + + ops = [ + GetOp(namespace=("test",), key="key1"), + GetOp(namespace=("test",), key="key2"), + GetOp(namespace=("test",), key="key3"), + ] + + results = store.batch(ops) + + assert len(results) == 3 + assert results[0] is not None + assert results[1] is not None + assert results[2] is None + assert results[0].key == "key1" + assert results[1].key == "key2" + + +def test_batch_put_ops(store: DuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor([]) + mock_connection.cursor.return_value = mock_cursor + + ops = [ + PutOp(namespace=("test",), key="key1", value={"data": "value1"}), + PutOp(namespace=("test",), key="key2", value={"data": "value2"}), + PutOp(namespace=("test",), key="key3", value=None), + ] + + results = store.batch(ops) + + assert len(results) == 3 + assert all(result is None for result in results) + assert mock_cursor.execute.call_count == 2 + + +def test_batch_search_ops(store: DuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor( + [ + ( + "test.foo", + "key1", + '{"data": "value1"}', + datetime.now(), + datetime.now(), + ), + ( + "test.bar", + "key2", + '{"data": "value2"}', + datetime.now(), + datetime.now(), + ), + ] + ) + mock_connection.cursor.return_value = mock_cursor + + ops = [ + SearchOp( + namespace_prefix=("test",), filter={"data": "value1"}, limit=10, offset=0 + ), + SearchOp(namespace_prefix=("test",), filter=None, limit=5, offset=0), + ] + + results = store.batch(ops) + + assert len(results) == 2 + assert len(results[0]) == 2 + assert len(results[1]) == 2 + + +def test_batch_list_namespaces_ops(store: DuckDBStore) -> None: + mock_connection = store.conn + mock_cursor = MockCursor([("test.namespace1",), ("test.namespace2",)]) + mock_connection.cursor.return_value = mock_cursor + + ops = [ListNamespacesOp(match_conditions=None, max_depth=None, limit=10, offset=0)] + + results = store.batch(ops) + + assert len(results) == 1 + assert results[0] == [("test", "namespace1"), ("test", "namespace2")] + + +def test_basic_store_ops() -> None: + with DuckDBStore.from_conn_string(":memory:") as store: + store.setup() + namespace = ("test", "documents") + item_id = "doc1" + item_value = {"title": "Test Document", "content": "Hello, World!"} + + store.put(namespace, item_id, item_value) + item = store.get(namespace, item_id) + + assert item + assert item.namespace == namespace + assert item.key == item_id + assert item.value == item_value + + updated_value = { + "title": "Updated Test Document", + "content": "Hello, LangGraph!", + } + store.put(namespace, item_id, updated_value) + updated_item = store.get(namespace, item_id) + + assert updated_item.value == updated_value + assert updated_item.updated_at > item.updated_at + different_namespace = ("test", "other_documents") + item_in_different_namespace = store.get(different_namespace, item_id) + assert item_in_different_namespace is None + + new_item_id = "doc2" + new_item_value = {"title": "Another Document", "content": "Greetings!"} + store.put(namespace, new_item_id, new_item_value) + + search_results = store.search(["test"], limit=10) + items = search_results + assert len(items) == 2 + assert any(item.key == item_id for item in items) + assert any(item.key == new_item_id for item in items) + + namespaces = store.list_namespaces(prefix=["test"]) + assert ("test", "documents") in namespaces + + store.delete(namespace, item_id) + 
store.delete(namespace, new_item_id) + deleted_item = store.get(namespace, item_id) + assert deleted_item is None + + deleted_item = store.get(namespace, new_item_id) + assert deleted_item is None + + empty_search_results = store.search(["test"], limit=10) + assert len(empty_search_results) == 0 + + +def test_list_namespaces() -> None: + with DuckDBStore.from_conn_string(":memory:") as store: + store.setup() + test_pref = str(uuid.uuid4()) + test_namespaces = [ + (test_pref, "test", "documents", "public", test_pref), + (test_pref, "test", "documents", "private", test_pref), + (test_pref, "test", "images", "public", test_pref), + (test_pref, "test", "images", "private", test_pref), + (test_pref, "prod", "documents", "public", test_pref), + ( + test_pref, + "prod", + "documents", + "some", + "nesting", + "public", + test_pref, + ), + (test_pref, "prod", "documents", "private", test_pref), + ] + + for namespace in test_namespaces: + store.put(namespace, "dummy", {"content": "dummy"}) + + prefix_result = store.list_namespaces(prefix=[test_pref, "test"]) + assert len(prefix_result) == 4 + assert all([ns[1] == "test" for ns in prefix_result]) + + specific_prefix_result = store.list_namespaces( + prefix=[test_pref, "test", "documents"] + ) + assert len(specific_prefix_result) == 2 + assert all([ns[1:3] == ("test", "documents") for ns in specific_prefix_result]) + + suffix_result = store.list_namespaces(suffix=["public", test_pref]) + assert len(suffix_result) == 4 + assert all(ns[-2] == "public" for ns in suffix_result) + + prefix_suffix_result = store.list_namespaces( + prefix=[test_pref, "test"], suffix=["public", test_pref] + ) + assert len(prefix_suffix_result) == 2 + assert all( + ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result + ) + + wildcard_prefix_result = store.list_namespaces( + prefix=[test_pref, "*", "documents"] + ) + assert len(wildcard_prefix_result) == 5 + assert all(ns[2] == "documents" for ns in wildcard_prefix_result) + + wildcard_suffix_result = store.list_namespaces( + suffix=["*", "public", test_pref] + ) + assert len(wildcard_suffix_result) == 4 + assert all(ns[-2] == "public" for ns in wildcard_suffix_result) + wildcard_single = store.list_namespaces( + suffix=["some", "*", "public", test_pref] + ) + assert len(wildcard_single) == 1 + assert wildcard_single[0] == ( + test_pref, + "prod", + "documents", + "some", + "nesting", + "public", + test_pref, + ) + + max_depth_result = store.list_namespaces(max_depth=3) + assert all([len(ns) <= 3 for ns in max_depth_result]) + + max_depth_result = store.list_namespaces( + max_depth=4, prefix=[test_pref, "*", "documents"] + ) + assert ( + len(set(tuple(res) for res in max_depth_result)) + == len(max_depth_result) + == 5 + ) + + limit_result = store.list_namespaces(prefix=[test_pref], limit=3) + assert len(limit_result) == 3 + + offset_result = store.list_namespaces(prefix=[test_pref], offset=3) + assert len(offset_result) == len(test_namespaces) - 3 + + empty_prefix_result = store.list_namespaces(prefix=[test_pref]) + assert len(empty_prefix_result) == len(test_namespaces) + assert set(tuple(ns) for ns in empty_prefix_result) == set( + tuple(ns) for ns in test_namespaces + ) + + for namespace in test_namespaces: + store.delete(namespace, "dummy") + + +def test_search(): + with DuckDBStore.from_conn_string(":memory:") as store: + store.setup() + test_namespaces = [ + ("test_search", "documents", "user1"), + ("test_search", "documents", "user2"), + ("test_search", "reports", "department1"), + ("test_search", 
"reports", "department2"), + ] + test_items = [ + {"title": "Doc 1", "author": "John Doe", "tags": ["important"]}, + {"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]}, + {"title": "Report A", "author": "John Doe", "tags": ["final"]}, + {"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]}, + ] + + for namespace, item in zip(test_namespaces, test_items): + store.put(namespace, f"item_{namespace[-1]}", item) + + docs_result = store.search(["test_search", "documents"]) + assert len(docs_result) == 2 + assert all( + [item.namespace[1] == "documents" for item in docs_result] + ), docs_result + + reports_result = store.search(["test_search", "reports"]) + assert len(reports_result) == 2 + assert all(item.namespace[1] == "reports" for item in reports_result) + + limited_result = store.search(["test_search"], limit=2) + assert len(limited_result) == 2 + offset_result = store.search(["test_search"]) + assert len(offset_result) == 4 + + offset_result = store.search(["test_search"], offset=2) + assert len(offset_result) == 2 + assert all(item not in limited_result for item in offset_result) + + john_doe_result = store.search(["test_search"], filter={"author": "John Doe"}) + assert len(john_doe_result) == 2 + assert all(item.value["author"] == "John Doe" for item in john_doe_result) + + draft_result = store.search(["test_search"], filter={"tags": ["draft"]}) + assert len(draft_result) == 2 + assert all("draft" in item.value["tags"] for item in draft_result) + + page1 = store.search(["test_search"], limit=2, offset=0) + page2 = store.search(["test_search"], limit=2, offset=2) + all_items = page1 + page2 + assert len(all_items) == 4 + assert len(set(item.key for item in all_items)) == 4 + + for namespace in test_namespaces: + store.delete(namespace, f"item_{namespace[-1]}") diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py index 71d497012..dda7321d0 100644 --- a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py @@ -44,7 +44,6 @@ def __init__( super().__init__() self._deserializer = deserializer self.conn = conn - self.conn = conn self.loop = asyncio.get_running_loop() async def abatch(self, ops: Iterable[Op]) -> list[Result]: diff --git a/libs/langgraph/tests/conftest.py b/libs/langgraph/tests/conftest.py index b8a00f1bd..bef083a53 100644 --- a/libs/langgraph/tests/conftest.py +++ b/libs/langgraph/tests/conftest.py @@ -18,6 +18,7 @@ from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver from langgraph.store.base import BaseStore +from langgraph.store.duckdb import AsyncDuckDBStore, DuckDBStore from langgraph.store.memory import InMemoryStore from langgraph.store.postgres import AsyncPostgresStore, PostgresStore from tests.memory_assert import MemorySaverAssertImmutable @@ -266,6 +267,13 @@ async def _store_postgres_aio(): await conn.execute(f"DROP DATABASE {database}") +@asynccontextmanager +async def _store_duckdb_aio(): + async with AsyncDuckDBStore.from_conn_string(":memory:") as store: + await store.setup() + yield store + + @pytest.fixture(scope="function") def store_postgres(): database = f"test_{uuid4().hex[:16]}" @@ -283,6 +291,13 @@ def store_postgres(): conn.execute(f"DROP DATABASE {database}") +@pytest.fixture(scope="function") +def store_duckdb(): + with DuckDBStore.from_conn_string(":memory:") as store: + store.setup() + yield store + + 
@pytest.fixture(scope="function") def store_in_memory(): yield InMemoryStore() @@ -297,6 +312,9 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]: elif store_name == "postgres_aio": async with _store_postgres_aio() as store: yield store + elif store_name == "duckdb_aio": + async with _store_duckdb_aio() as store: + yield store else: raise NotImplementedError(f"Unknown store {store_name}") @@ -321,5 +339,5 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]: *ALL_CHECKPOINTERS_ASYNC, None, ] -ALL_STORES_SYNC = ["in_memory", "postgres"] -ALL_STORES_ASYNC = ["in_memory", "postgres_aio"] +ALL_STORES_SYNC = ["in_memory", "postgres", "duckdb"] +ALL_STORES_ASYNC = ["in_memory", "postgres_aio", "duckdb_aio"]
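A minimal usage sketch of the synchronous store this diff adds, mirroring the calls made in `test_basic_store_ops`. The namespace and payload values are illustrative; the `from_conn_string(":memory:")` context manager and the explicit `setup()` call follow the tests above.

```python
# Sketch only: exercises the public BaseStore API that DuckDBStore implements.
from langgraph.store.duckdb import DuckDBStore

with DuckDBStore.from_conn_string(":memory:") as store:
    store.setup()  # applies MIGRATIONS and records versions in store_migrations
    namespace = ("test", "documents")
    store.put(namespace, "doc1", {"title": "Test Document"})
    item = store.get(namespace, "doc1")
    assert item is not None and item.value["title"] == "Test Document"
    assert ("test", "documents") in store.list_namespaces(prefix=["test"])
    store.delete(namespace, "doc1")
```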
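Both stores route everything through the batch path, and results are positional: one slot per op, with `PutOp` slots left as `None`. A hedged async sketch following the op shapes used in `test_abatch_order`; note that ops within one `abatch` call run concurrently, so a `SearchOp` is not guaranteed to observe a `PutOp` from the same batch.

```python
# Sketch only: the op constructors are the ones imported by the tests above.
import asyncio

from langgraph.store.base import GetOp, ListNamespacesOp, PutOp, SearchOp
from langgraph.store.duckdb import AsyncDuckDBStore


async def main() -> None:
    async with AsyncDuckDBStore.from_conn_string(":memory:") as store:
        await store.setup()
        await store.aput(("test",), "key1", {"data": "value1"})
        results = await store.abatch(
            [
                GetOp(namespace=("test",), key="key1"),
                PutOp(namespace=("test",), key="key2", value={"data": "value2"}),
                SearchOp(namespace_prefix=("test",), filter=None, limit=10, offset=0),
                ListNamespacesOp(
                    match_conditions=None, max_depth=None, limit=10, offset=0
                ),
            ]
        )
        # One result per op, in op order: Item, None, list[Item], list[tuple].
        item, put_result, hits, namespaces = results
        assert put_result is None and item.key == "key1"


asyncio.run(main())
```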
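The namespace-listing SQL is the subtlest part of the patch: `string_split` breaks each prefix on `.`, `array_slice` truncates it to `max_depth` segments, `DISTINCT ON` deduplicates the truncated form, and `_namespace_to_text(..., handle_wildcards=True)` turns `*` segments into SQL `%` for `LIKE` matching. A small sketch of the resulting behavior, with made-up namespaces:

```python
# Sketch only: expected output is [("prod", "docs"), ("test", "docs")].
from langgraph.store.duckdb import DuckDBStore

with DuckDBStore.from_conn_string(":memory:") as store:
    store.setup()
    for ns in [
        ("prod", "docs", "public"),
        ("prod", "docs", "private"),
        ("test", "docs", "public"),
    ]:
        store.put(ns, "dummy", {"content": "dummy"})
    # "*" matches any single segment; max_depth truncates and deduplicates.
    print(store.list_namespaces(prefix=["*", "docs"], max_depth=2))
```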