From 9ef98e169f34986f05309f8b402b340f38493255 Mon Sep 17 00:00:00 2001 From: Davide Fusca Date: Sun, 25 Aug 2024 15:16:19 +0200 Subject: [PATCH 1/5] Add 2 memory endpoints --- core/cat/memory/vector_memory_collection.py | 14 ++ core/cat/routes/memory.py | 133 ++++++++++++++++ .../tests/routes/memory/test_memory_points.py | 146 ++++++++++++++++++ 3 files changed, 293 insertions(+) diff --git a/core/cat/memory/vector_memory_collection.py b/core/cat/memory/vector_memory_collection.py index 259e7c54..de5136b1 100644 --- a/core/cat/memory/vector_memory_collection.py +++ b/core/cat/memory/vector_memory_collection.py @@ -268,6 +268,20 @@ def get_all_points(self): ) return all_points + + # Retrieve a set of points with an optional offset and limit. + def get_all_points_with_offset(self, limit:int=10000, offset:str=None): + # Retrieve the points and the next offset. + # To retrieve the first page set offset equal to None + + all_points, next_page_offset = self.client.scroll( + collection_name=self.collection_name, + with_vectors=True, + offset=offset, # Start from the given offset, or the beginning if None. + limit=limit # Limit the number of points retrieved to the specified limit. + ) + + return (all_points, next_page_offset) def db_is_remote(self): return isinstance(self.client._client, QdrantRemote) diff --git a/core/cat/routes/memory.py b/core/cat/routes/memory.py index f6d1c3b0..6f44affd 100644 --- a/core/cat/routes/memory.py +++ b/core/cat/routes/memory.py @@ -277,3 +277,136 @@ async def get_conversation_history( """Get the specified user's conversation history from working memory""" return {"history": stray.working_memory.history} + +# GET all the points from a single collection +@router.get("/collections/{collection_id}/points") +async def get_collections_points( + request: Request, + collection_id: str, + limit:int=Query( + default=100, + description="How many points to return" + ), + offset:str = Query( + default=None, + description="If provided (or not empty string) - skip points with ids less than given `offset`" + ), + stray=Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)), +) -> Dict: + """Retrieve all the points from a single collection + + + Example + ---------- + ``` + collection = "declarative" + res = requests.get( + f"http://localhost:1865/memory/collections/{collection}/points", + ) + json = res.json() + points = json["points"] + + for point in points: + payload = point["payload"] + vector = point["vector"] + print(payload) + print(vector) + ``` + + Example using offset + ---------- + ``` + # get all the points with limit 10 + limit = 10 + next_offset = "" + collection = "declarative" + + while True: + res = requests.get( + f"http://localhost:1865/memory/collections/{collection}/points?limit={limit}&offset={next_offset}", + ) + json = res.json() + points = json["points"] + next_offset = json["next_offset"] + + for point in points: + payload = point["payload"] + vector = point["vector"] + print(payload) + print(vector) + + if next_offset is None: + break + ``` + """ + + # check if collection exists + collections = list(stray.memory.vectors.collections.keys()) + if collection_id not in collections: + raise HTTPException( + status_code=400, detail={"error": f"Collection does not exist. Avaliable collections: {collections}"} + ) + + # if offset is empty string set to null + if offset == "": + offset = None + + memory_collection = stray.memory.vectors.collections[collection_id] + points, next_offset = memory_collection.get_all_points_with_offset(limit=limit,offset=offset) + + return { + "points":points, + "next_offset":next_offset + } + + +# GET all the points from all the collections +@router.get("/collections/points") +async def get_all_points( + request: Request, + limit:int=Query( + default=100, + description="How many points to return" + ), + stray=Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)) +) -> Dict: + """Retrieve all the points from all the collections + + + Example + ---------- + ``` + # get all the points no limit, by default is 100 + res = requests.get( + f"http://localhost:1865/memory/collections/points", + ) + json = res.json() + + for collection in json: + points = json[collection]["points"] + print(f"Collection {collection}") + + for point in points: + payload = point["payload"] + vector = point["vector"] + print(payload) + print(vector) + ``` + + """ + + # check if collection exists + result = {} + + collections = list(stray.memory.vectors.collections.keys()) + for collection in collections: + #for each collection fetch all the points and next offset + + memory_collection = stray.memory.vectors.collections[collection] + + points, _ = memory_collection.get_all_points_with_offset(limit=limit) + result[collection] = { + "points":points + } + + return result \ No newline at end of file diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index e083ed8f..a6076755 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -1,6 +1,7 @@ import pytest from tests.utils import send_websocket_message, get_declarative_memory_contents from tests.conftest import FAKE_TIMESTAMP +import time def test_point_deleted(client): # send websocket message @@ -162,3 +163,148 @@ def test_create_memory_point(client, patch_time_now, collection): assert memory["metadata"] == expected_metadata +# utility function that validates a list of points against an expected points payload +def _check_points(points, expected_points_payload): + # check length + assert len(points) == len(expected_points_payload) + # check all the points contains id and vector + for point in points: + assert "id" in point + assert "vector" in point + + # check points payload + points_payloads = [p["payload"] for p in points] + # sort the list and compare payload + points_payloads.sort(key=lambda p:p["page_content"]) + expected_points_payload.sort(key=lambda p:p["page_content"]) + assert points_payloads == expected_points_payload + + +@pytest.mark.parametrize("collection", ["episodic", "declarative"]) +def test_get_collection_points(client, patch_time_now, collection): + # create 100 points + n_points = 100 + new_points = [{"content": f"MIAO {i}!","metadata": {"custom_key": f"custom_key_{i}"}} for i in range(n_points) ] + + # Add points + for req_json in new_points: + res = client.post( + f"/memory/collections/{collection}/points", json=req_json + ) + assert res.status_code == 200 + + # get all the points no limit, by default is 100 + res = client.get( + f"/memory/collections/{collection}/points", + ) + assert res.status_code == 200 + json = res.json() + + points = json["points"] + offset = json["next_offset"] + + assert offset is None # the result should contains all the points so no offset + + expected_payloads = [ + {"page_content":p["content"], + "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} + } for p in new_points + ] + _check_points(points, expected_payloads) + + + +@pytest.mark.parametrize("collection", ["episodic", "declarative"]) +def test_get_collection_points_offset(client, patch_time_now, collection): + # create 200 points + n_points = 200 + new_points = [{"content": f"MIAO {i}!","metadata": {"custom_key": f"custom_key_{i}"}} for i in range(n_points) ] + + # Add points + for req_json in new_points: + res = client.post( + f"/memory/collections/{collection}/points", json=req_json + ) + assert res.status_code == 200 + + # get all the points with limit 10 + limit = 10 + next_offset = "" + all_points = [] + + while True: + res = client.get( + f"/memory/collections/{collection}/points?limit={limit}&offset={next_offset}", + ) + assert res.status_code == 200 + json = res.json() + points = json["points"] + next_offset = json["next_offset"] + assert len(points) == limit + + for point in points: + all_points.append(point) + + if next_offset is None: # break if no new data + break + + # create the expected payloads for all the points + expected_payloads = [ + {"page_content":p["content"], + "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} + } for p in new_points + ] + _check_points(all_points, expected_payloads) + + +def test_get_all_points(client,patch_time_now): + + # create 50 points for episodic + new_points_episodic = [{"content": f"MIAO {i}!","metadata": {"custom_key": f"custom_key_{i}_episodic"}} for i in range(50) ] + + # create 100 points for declarative + new_points_declarative = [{"content": f"MIAO {i}!","metadata": {"custom_key": f"custom_key_{i}_declarative"}} for i in range(100) ] + + for point in new_points_episodic: + res = client.post( + f"/memory/collections/episodic/points", json=point + ) + assert res.status_code == 200 + + for point in new_points_declarative: + res = client.post( + f"/memory/collections/declarative/points", json=point + ) + assert res.status_code == 200 + + # get the points from all the collection with default limit (100 points) + res = client.get( + f"/memory/collections/points", + ) + assert res.status_code == 200 + json = res.json() + + assert "episodic" in json + assert "declarative" in json + + #check episodic points + episodic_points = json["episodic"]["points"] + # create the expected payloads for all the points + expected_episodic_payloads = [ + {"page_content":p["content"], + "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} + } for p in new_points_episodic + ] + _check_points(episodic_points, expected_episodic_payloads) + + # check declarative points + declarative_points = json["declarative"]["points"] + # create the expected payloads for all the points + expected_declarative_payloads = [ + {"page_content":p["content"], + "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} + } for p in new_points_declarative + ] + _check_points(declarative_points, expected_declarative_payloads) + + From 5e8fce96639c78e20021dec3b91e66918ebbb183 Mon Sep 17 00:00:00 2001 From: Davide Fusca Date: Sun, 25 Aug 2024 15:32:44 +0200 Subject: [PATCH 2/5] fix f-string without any placeholders --- core/tests/routes/memory/test_memory_points.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index a6076755..9777a1dc 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -267,19 +267,19 @@ def test_get_all_points(client,patch_time_now): for point in new_points_episodic: res = client.post( - f"/memory/collections/episodic/points", json=point + "/memory/collections/episodic/points", json=point ) assert res.status_code == 200 for point in new_points_declarative: res = client.post( - f"/memory/collections/declarative/points", json=point + "/memory/collections/declarative/points", json=point ) assert res.status_code == 200 # get the points from all the collection with default limit (100 points) res = client.get( - f"/memory/collections/points", + "/memory/collections/points", ) assert res.status_code == 200 json = res.json() From 6307cd935e14fea57318c7598051502cf261a481 Mon Sep 17 00:00:00 2001 From: Davide Fusca Date: Sun, 25 Aug 2024 15:35:16 +0200 Subject: [PATCH 3/5] Remove unused import --- core/tests/routes/memory/test_memory_points.py | 1 - 1 file changed, 1 deletion(-) diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index 9777a1dc..28887237 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -1,7 +1,6 @@ import pytest from tests.utils import send_websocket_message, get_declarative_memory_contents from tests.conftest import FAKE_TIMESTAMP -import time def test_point_deleted(client): # send websocket message From 781dfddb7712b6e17d1986fc0ed93843ac3dace3 Mon Sep 17 00:00:00 2001 From: Davide Fusca Date: Thu, 5 Sep 2024 18:00:28 +0200 Subject: [PATCH 4/5] Remove get all points endpoint --- core/cat/routes/memory.py | 51 ---------- .../tests/routes/memory/test_memory_points.py | 93 +++++-------------- 2 files changed, 25 insertions(+), 119 deletions(-) diff --git a/core/cat/routes/memory.py b/core/cat/routes/memory.py index 6f44affd..68af8e19 100644 --- a/core/cat/routes/memory.py +++ b/core/cat/routes/memory.py @@ -359,54 +359,3 @@ async def get_collections_points( "next_offset":next_offset } - -# GET all the points from all the collections -@router.get("/collections/points") -async def get_all_points( - request: Request, - limit:int=Query( - default=100, - description="How many points to return" - ), - stray=Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)) -) -> Dict: - """Retrieve all the points from all the collections - - - Example - ---------- - ``` - # get all the points no limit, by default is 100 - res = requests.get( - f"http://localhost:1865/memory/collections/points", - ) - json = res.json() - - for collection in json: - points = json[collection]["points"] - print(f"Collection {collection}") - - for point in points: - payload = point["payload"] - vector = point["vector"] - print(payload) - print(vector) - ``` - - """ - - # check if collection exists - result = {} - - collections = list(stray.memory.vectors.collections.keys()) - for collection in collections: - #for each collection fetch all the points and next offset - - memory_collection = stray.memory.vectors.collections[collection] - - points, _ = memory_collection.get_all_points_with_offset(limit=limit) - result[collection] = { - "points":points - } - - return result \ No newline at end of file diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index 28887237..de49fe49 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -162,22 +162,6 @@ def test_create_memory_point(client, patch_time_now, collection): assert memory["metadata"] == expected_metadata -# utility function that validates a list of points against an expected points payload -def _check_points(points, expected_points_payload): - # check length - assert len(points) == len(expected_points_payload) - # check all the points contains id and vector - for point in points: - assert "id" in point - assert "vector" in point - - # check points payload - points_payloads = [p["payload"] for p in points] - # sort the list and compare payload - points_payloads.sort(key=lambda p:p["page_content"]) - expected_points_payload.sort(key=lambda p:p["page_content"]) - assert points_payloads == expected_points_payload - @pytest.mark.parametrize("collection", ["episodic", "declarative"]) def test_get_collection_points(client, patch_time_now, collection): @@ -209,7 +193,19 @@ def test_get_collection_points(client, patch_time_now, collection): "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} } for p in new_points ] - _check_points(points, expected_payloads) + + assert len(points) == len(new_points) + # check all the points contains id and vector + for point in points: + assert "id" in point + assert "vector" in point + + # check points payload + points_payloads = [p["payload"] for p in points] + # sort the list and compare payload + points_payloads.sort(key=lambda p:p["page_content"]) + expected_payloads.sort(key=lambda p:p["page_content"]) + assert points_payloads == expected_payloads @@ -253,57 +249,18 @@ def test_get_collection_points_offset(client, patch_time_now, collection): "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} } for p in new_points ] - _check_points(all_points, expected_payloads) - - -def test_get_all_points(client,patch_time_now): - - # create 50 points for episodic - new_points_episodic = [{"content": f"MIAO {i}!","metadata": {"custom_key": f"custom_key_{i}_episodic"}} for i in range(50) ] - - # create 100 points for declarative - new_points_declarative = [{"content": f"MIAO {i}!","metadata": {"custom_key": f"custom_key_{i}_declarative"}} for i in range(100) ] - - for point in new_points_episodic: - res = client.post( - "/memory/collections/episodic/points", json=point - ) - assert res.status_code == 200 - - for point in new_points_declarative: - res = client.post( - "/memory/collections/declarative/points", json=point - ) - assert res.status_code == 200 - - # get the points from all the collection with default limit (100 points) - res = client.get( - "/memory/collections/points", - ) - assert res.status_code == 200 - json = res.json() - - assert "episodic" in json - assert "declarative" in json - - #check episodic points - episodic_points = json["episodic"]["points"] - # create the expected payloads for all the points - expected_episodic_payloads = [ - {"page_content":p["content"], - "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} - } for p in new_points_episodic - ] - _check_points(episodic_points, expected_episodic_payloads) - # check declarative points - declarative_points = json["declarative"]["points"] - # create the expected payloads for all the points - expected_declarative_payloads = [ - {"page_content":p["content"], - "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} - } for p in new_points_declarative - ] - _check_points(declarative_points, expected_declarative_payloads) + assert len(all_points) == len(new_points) + # check all the points contains id and vector + for point in all_points: + assert "id" in point + assert "vector" in point + + # check points payload + points_payloads = [p["payload"] for p in all_points] + # sort the list and compare payload + points_payloads.sort(key=lambda p:p["page_content"]) + expected_payloads.sort(key=lambda p:p["page_content"]) + assert points_payloads == expected_payloads From 7d68e1b9b5a2ad30959dd1bee864d61e02967c22 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 25 Sep 2024 23:30:48 +0200 Subject: [PATCH 5/5] review PR --- core/cat/looking_glass/cheshire_cat.py | 2 +- core/cat/memory/vector_memory_collection.py | 24 ++++------- core/cat/routes/memory.py | 20 ++++++++-- core/tests/looking_glass/test_cheshire_cat.py | 2 +- .../tests/routes/memory/test_memory_points.py | 40 +++++++++++++++---- 5 files changed, 58 insertions(+), 30 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index b34e47aa..06b336bb 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -334,7 +334,7 @@ def build_active_procedures_hashes(self, active_procedures): def embed_procedures(self): # Retrieve from vectorDB all procedural embeddings - embedded_procedures = self.memory.vectors.procedural.get_all_points() + embedded_procedures, _ = self.memory.vectors.procedural.get_all_points() embedded_procedures_hashes = self.build_embedded_procedures_hashes( embedded_procedures ) diff --git a/core/cat/memory/vector_memory_collection.py b/core/cat/memory/vector_memory_collection.py index de5136b1..0f2a651c 100644 --- a/core/cat/memory/vector_memory_collection.py +++ b/core/cat/memory/vector_memory_collection.py @@ -258,22 +258,14 @@ def recall_memories_from_embedding( return langchain_documents_from_points - # retrieve all the points in the collection - def get_all_points(self): + # retrieve all the points in the collection with an optional offset and limit. + def get_all_points( + self, + limit: int = 10000, + offset: str | None = None + ): + # retrieving the points - all_points, _ = self.client.scroll( - collection_name=self.collection_name, - with_vectors=True, - limit=10000, # yeah, good for now dear :* - ) - - return all_points - - # Retrieve a set of points with an optional offset and limit. - def get_all_points_with_offset(self, limit:int=10000, offset:str=None): - # Retrieve the points and the next offset. - # To retrieve the first page set offset equal to None - all_points, next_page_offset = self.client.scroll( collection_name=self.collection_name, with_vectors=True, @@ -281,7 +273,7 @@ def get_all_points_with_offset(self, limit:int=10000, offset:str=None): limit=limit # Limit the number of points retrieved to the specified limit. ) - return (all_points, next_page_offset) + return all_points, next_page_offset def db_is_remote(self): return isinstance(self.client._client, QdrantRemote) diff --git a/core/cat/routes/memory.py b/core/cat/routes/memory.py index 68af8e19..a9367387 100644 --- a/core/cat/routes/memory.py +++ b/core/cat/routes/memory.py @@ -340,11 +340,23 @@ async def get_collections_points( ``` """ + # do not allow procedural memory reads via network + if collection_id == "procedural": + raise HTTPException( + status_code=400, + detail={ + "error": "Procedural memory is not readable via API" + } + ) + # check if collection exists collections = list(stray.memory.vectors.collections.keys()) if collection_id not in collections: raise HTTPException( - status_code=400, detail={"error": f"Collection does not exist. Avaliable collections: {collections}"} + status_code=400, + detail={ + "error": "Collection does not exist." + } ) # if offset is empty string set to null @@ -352,10 +364,10 @@ async def get_collections_points( offset = None memory_collection = stray.memory.vectors.collections[collection_id] - points, next_offset = memory_collection.get_all_points_with_offset(limit=limit,offset=offset) + points, next_offset = memory_collection.get_all_points(limit=limit, offset=offset) return { - "points":points, - "next_offset":next_offset + "points": points, + "next_offset": next_offset } diff --git a/core/tests/looking_glass/test_cheshire_cat.py b/core/tests/looking_glass/test_cheshire_cat.py index d3a245c3..00a9e30c 100644 --- a/core/tests/looking_glass/test_cheshire_cat.py +++ b/core/tests/looking_glass/test_cheshire_cat.py @@ -59,7 +59,7 @@ def test_default_embedder_loaded(cheshire_cat): def test_procedures_embedded(cheshire_cat): # get embedded tools - procedures = cheshire_cat.memory.vectors.procedural.get_all_points() + procedures, _ = cheshire_cat.memory.vectors.procedural.get_all_points() assert len(procedures) == 3 for p in procedures: diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index de49fe49..3b0ea71a 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -161,7 +161,21 @@ def test_create_memory_point(client, patch_time_now, collection): assert memory["page_content"] == content assert memory["metadata"] == expected_metadata +def test_get_collection_points_wrong_collection(client): + + # unexisting collection + res = client.get( + f"/memory/collections/unexistent/points", + ) + assert res.status_code == 400 + assert "Collection does not exist" in res.json()["detail"]["error"] + # reserved procedural collection + res = client.get( + "/memory/collections/procedural/points", + ) + assert res.status_code == 400 + assert "Procedural memory is not readable via API" in res.json()["detail"]["error"] @pytest.mark.parametrize("collection", ["episodic", "declarative"]) def test_get_collection_points(client, patch_time_now, collection): @@ -189,8 +203,13 @@ def test_get_collection_points(client, patch_time_now, collection): assert offset is None # the result should contains all the points so no offset expected_payloads = [ - {"page_content":p["content"], - "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} + { + "page_content": p["content"], + "metadata": { + "when":FAKE_TIMESTAMP, + "source": "user", + **p["metadata"] + } } for p in new_points ] @@ -203,8 +222,8 @@ def test_get_collection_points(client, patch_time_now, collection): # check points payload points_payloads = [p["payload"] for p in points] # sort the list and compare payload - points_payloads.sort(key=lambda p:p["page_content"]) - expected_payloads.sort(key=lambda p:p["page_content"]) + points_payloads.sort(key=lambda p: p["page_content"]) + expected_payloads.sort(key=lambda p: p["page_content"]) assert points_payloads == expected_payloads @@ -245,8 +264,13 @@ def test_get_collection_points_offset(client, patch_time_now, collection): # create the expected payloads for all the points expected_payloads = [ - {"page_content":p["content"], - "metadata":{"when":FAKE_TIMESTAMP,"source": "user", **p["metadata"]} + { + "page_content": p["content"], + "metadata": { + "when":FAKE_TIMESTAMP, + "source": "user", + **p["metadata"] + } } for p in new_points ] @@ -259,8 +283,8 @@ def test_get_collection_points_offset(client, patch_time_now, collection): # check points payload points_payloads = [p["payload"] for p in all_points] # sort the list and compare payload - points_payloads.sort(key=lambda p:p["page_content"]) - expected_payloads.sort(key=lambda p:p["page_content"]) + points_payloads.sort(key=lambda p: p["page_content"]) + expected_payloads.sort(key=lambda p: p["page_content"]) assert points_payloads == expected_payloads