Merge branch 'dave90-issue_889' into develop
pieroit committed Sep 25, 2024
2 parents e6e9cfb + 7d68e1b commit 17a424f
Showing 5 changed files with 233 additions and 7 deletions.
2 changes: 1 addition & 1 deletion core/cat/looking_glass/cheshire_cat.py
@@ -334,7 +334,7 @@ def build_active_procedures_hashes(self, active_procedures):

    def embed_procedures(self):
        # Retrieve from vectorDB all procedural embeddings
-       embedded_procedures = self.memory.vectors.procedural.get_all_points()
+       embedded_procedures, _ = self.memory.vectors.procedural.get_all_points()
        embedded_procedures_hashes = self.build_embedded_procedures_hashes(
            embedded_procedures
        )
16 changes: 11 additions & 5 deletions core/cat/memory/vector_memory_collection.py
@@ -258,16 +258,22 @@ def recall_memories_from_embedding(

        return langchain_documents_from_points

-   # retrieve all the points in the collection
-   def get_all_points(self):
+   # retrieve all the points in the collection with an optional offset and limit.
+   def get_all_points(
+       self,
+       limit: int = 10000,
+       offset: str | None = None
+   ):

        # retrieving the points
-       all_points, _ = self.client.scroll(
+       all_points, next_page_offset = self.client.scroll(
            collection_name=self.collection_name,
            with_vectors=True,
-           limit=10000, # yeah, good for now dear :*
+           offset=offset, # Start from the given offset, or the beginning if None.
+           limit=limit # Limit the number of points retrieved to the specified limit.
        )

-       return all_points
+       return all_points, next_page_offset

    def db_is_remote(self):
        return isinstance(self.client._client, QdrantRemote)
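With the new `(points, next_page_offset)` return value, callers that want the whole collection have to loop until the offset comes back as `None`. A minimal sketch of such a loop, assuming a `VectorMemoryCollection`-like object is available as `collection` (the helper name and batch size are illustrative, not part of this commit):

```python
# Hypothetical helper: page through a collection using the new
# limit/offset contract of get_all_points.
def iter_all_points(collection, batch_size=1000):
    offset = None  # start from the beginning of the collection
    while True:
        points, next_offset = collection.get_all_points(limit=batch_size, offset=offset)
        yield from points
        if next_offset is None:  # the underlying scroll returns None when there are no more pages
            return
        offset = next_offset
```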
94 changes: 94 additions & 0 deletions core/cat/routes/memory.py
@@ -277,3 +277,97 @@ async def get_conversation_history(
    """Get the specified user's conversation history from working memory"""

    return {"history": stray.working_memory.history}


# GET all the points from a single collection
@router.get("/collections/{collection_id}/points")
async def get_collections_points(
    request: Request,
    collection_id: str,
    limit: int = Query(
        default=100,
        description="How many points to return"
    ),
    offset: str = Query(
        default=None,
        description="If provided (or not empty string) - skip points with ids less than given `offset`"
    ),
    stray=Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
    """Retrieve all the points from a single collection

    Example
    ----------
    ```
    collection = "declarative"
    res = requests.get(
        f"http://localhost:1865/memory/collections/{collection}/points",
    )
    json = res.json()
    points = json["points"]

    for point in points:
        payload = point["payload"]
        vector = point["vector"]
        print(payload)
        print(vector)
    ```

    Example using offset
    ----------
    ```
    # get all the points with limit 10
    limit = 10
    next_offset = ""
    collection = "declarative"

    while True:
        res = requests.get(
            f"http://localhost:1865/memory/collections/{collection}/points?limit={limit}&offset={next_offset}",
        )
        json = res.json()
        points = json["points"]
        next_offset = json["next_offset"]

        for point in points:
            payload = point["payload"]
            vector = point["vector"]
            print(payload)
            print(vector)

        if next_offset is None:
            break
    ```
    """

    # do not allow procedural memory reads via network
    if collection_id == "procedural":
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Procedural memory is not readable via API"
            }
        )

    # check if collection exists
    collections = list(stray.memory.vectors.collections.keys())
    if collection_id not in collections:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Collection does not exist."
            }
        )

    # if offset is empty string set to null
    if offset == "":
        offset = None

    memory_collection = stray.memory.vectors.collections[collection_id]
    points, next_offset = memory_collection.get_all_points(limit=limit, offset=offset)

    return {
        "points": points,
        "next_offset": next_offset
    }
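Judging from the route's docstring and the tests below, each item in `points` exposes the point `id`, its `vector`, and a `payload` holding `page_content` and `metadata`. A hedged sketch of the response shape (the concrete values are invented for illustration, not taken from a real run):

```python
# Illustrative shape of the JSON returned by
# GET /memory/collections/{collection_id}/points (values are made up).
example_response = {
    "points": [
        {
            "id": "a1b2c3d4-0000-0000-0000-000000000000",  # point id assigned by the vector DB
            "vector": [0.01, -0.23, 0.44],                  # embedding vector (truncated here)
            "payload": {
                "page_content": "MIAO 0!",
                "metadata": {"when": 1700000000.0, "source": "user", "custom_key": "custom_key_0"},
            },
        },
    ],
    "next_offset": None,  # None means there are no further pages to request
}
```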

2 changes: 1 addition & 1 deletion core/tests/looking_glass/test_cheshire_cat.py
@@ -59,7 +59,7 @@ def test_default_embedder_loaded(cheshire_cat):

def test_procedures_embedded(cheshire_cat):
    # get embedded tools
-   procedures = cheshire_cat.memory.vectors.procedural.get_all_points()
+   procedures, _ = cheshire_cat.memory.vectors.procedural.get_all_points()
    assert len(procedures) == 3

    for p in procedures:
126 changes: 126 additions & 0 deletions core/tests/routes/memory/test_memory_points.py
@@ -161,4 +161,130 @@ def test_create_memory_point(client, patch_time_now, collection):
    assert memory["page_content"] == content
    assert memory["metadata"] == expected_metadata


def test_get_collection_points_wrong_collection(client):

    # unexisting collection
    res = client.get(
        f"/memory/collections/unexistent/points",
    )
    assert res.status_code == 400
    assert "Collection does not exist" in res.json()["detail"]["error"]

    # reserved procedural collection
    res = client.get(
        "/memory/collections/procedural/points",
    )
    assert res.status_code == 400
    assert "Procedural memory is not readable via API" in res.json()["detail"]["error"]


@pytest.mark.parametrize("collection", ["episodic", "declarative"])
def test_get_collection_points(client, patch_time_now, collection):
    # create 100 points
    n_points = 100
    new_points = [{"content": f"MIAO {i}!", "metadata": {"custom_key": f"custom_key_{i}"}} for i in range(n_points)]

    # Add points
    for req_json in new_points:
        res = client.post(
            f"/memory/collections/{collection}/points", json=req_json
        )
        assert res.status_code == 200

    # get all the points with no limit, by default it is 100
    res = client.get(
        f"/memory/collections/{collection}/points",
    )
    assert res.status_code == 200
    json = res.json()

    points = json["points"]
    offset = json["next_offset"]

    assert offset is None  # the result should contain all the points, so no offset

    expected_payloads = [
        {
            "page_content": p["content"],
            "metadata": {
                "when": FAKE_TIMESTAMP,
                "source": "user",
                **p["metadata"]
            }
        } for p in new_points
    ]

    assert len(points) == len(new_points)
    # check all the points contain id and vector
    for point in points:
        assert "id" in point
        assert "vector" in point

    # check points payload
    points_payloads = [p["payload"] for p in points]
    # sort the lists and compare payloads
    points_payloads.sort(key=lambda p: p["page_content"])
    expected_payloads.sort(key=lambda p: p["page_content"])
    assert points_payloads == expected_payloads


@pytest.mark.parametrize("collection", ["episodic", "declarative"])
def test_get_collection_points_offset(client, patch_time_now, collection):
    # create 200 points
    n_points = 200
    new_points = [{"content": f"MIAO {i}!", "metadata": {"custom_key": f"custom_key_{i}"}} for i in range(n_points)]

    # Add points
    for req_json in new_points:
        res = client.post(
            f"/memory/collections/{collection}/points", json=req_json
        )
        assert res.status_code == 200

    # get all the points with limit 10
    limit = 10
    next_offset = ""
    all_points = []

    while True:
        res = client.get(
            f"/memory/collections/{collection}/points?limit={limit}&offset={next_offset}",
        )
        assert res.status_code == 200
        json = res.json()
        points = json["points"]
        next_offset = json["next_offset"]
        assert len(points) == limit

        for point in points:
            all_points.append(point)

        if next_offset is None:  # break if no new data
            break

    # create the expected payloads for all the points
    expected_payloads = [
        {
            "page_content": p["content"],
            "metadata": {
                "when": FAKE_TIMESTAMP,
                "source": "user",
                **p["metadata"]
            }
        } for p in new_points
    ]

    assert len(all_points) == len(new_points)
    # check all the points contain id and vector
    for point in all_points:
        assert "id" in point
        assert "vector" in point

    # check points payload
    points_payloads = [p["payload"] for p in all_points]
    # sort the lists and compare payloads
    points_payloads.sort(key=lambda p: p["page_content"])
    expected_payloads.sort(key=lambda p: p["page_content"])
    assert points_payloads == expected_payloads

