Skip to content

Commit

Permalink
Merge pull request #515 from zAlweNy26/user_id_header
Browse files Browse the repository at this point in the history
Added `user_id` in header
  • Loading branch information
pieroit authored Oct 23, 2023
2 parents 1ca78eb + 0236330 commit e197ea8
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 29 deletions.
8 changes: 8 additions & 0 deletions core/cat/api_auth.py → core/cat/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,11 @@ def check_api_key(request: Request, api_key: str = Security(api_key_header)) ->
status_code=403,
detail={"error": "Invalid API Key"}
)


def check_user_id(request: Request) -> str:
user_id = request.headers.get("user_id")
if user_id:
return user_id
else:
return "user"
4 changes: 2 additions & 2 deletions core/cat/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from cat.log import log
from cat.routes import base, settings, llm, embedder, memory, plugins, upload, websocket
from cat.routes.static import public, admin, static
from cat.api_auth import check_api_key
from cat.headers import check_api_key
from cat.routes.openapi import get_openapi_configuration_function
from cat.looking_glass.cheshire_cat import CheshireCat

Expand Down Expand Up @@ -64,7 +64,7 @@ def custom_generate_unique_id(route: APIRoute):
cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(memory.router, tags=["Memory"], prefix="/memory", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(upload.router, tags=["Rabbit Hole"], prefix="/rabbithole", dependencies=[Depends(check_api_key)])
cheshire_cat_api.include_router(websocket.router, tags=["Websocket"])
cheshire_cat_api.include_router(websocket.router, tags=["WebSocket"])

# mount static files
# this cannot be done via fastapi.APIrouter:
Expand Down
9 changes: 5 additions & 4 deletions core/cat/routes/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict
from fastapi import Query, Request, APIRouter, HTTPException
from cat.headers import check_user_id
from fastapi import Query, Request, APIRouter, HTTPException, Depends

router = APIRouter()

Expand All @@ -10,7 +11,7 @@ async def recall_memories_from_text(
request: Request,
text: str = Query(description="Find memories similar to this text."),
k: int = Query(default=100, description="How many memories to return."),
user_id: str = Query(default="user", description="User id."),
user_id = Depends(check_user_id)
) -> Dict:
"""Search k memories similar to given text."""

Expand Down Expand Up @@ -201,7 +202,7 @@ async def wipe_memory_points_by_metadata(
@router.delete("/conversation_history/")
async def wipe_conversation_history(
request: Request,
user_id: str = Query(default="user", description="User id."),
user_id = Depends(check_user_id),
) -> Dict:
"""Delete the specified user's conversation history from working memory"""

Expand All @@ -219,7 +220,7 @@ async def wipe_conversation_history(
@router.get("/conversation_history/")
async def get_conversation_history(
request: Request,
user_id: str = Query(default="user", description="User id."),
user_id = Depends(check_user_id),
) -> Dict:
"""Get the specified user's conversation history from working memory"""

Expand Down
22 changes: 11 additions & 11 deletions core/cat/routes/static/auth_static.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from fastapi.staticfiles import StaticFiles
from fastapi import Request
from cat.api_auth import check_api_key

class AuthStatic(StaticFiles):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

async def __call__(self, scope, receive, send) -> None:
reqeust = Request(scope, receive=receive)
check_api_key(reqeust.headers.get("access_token"))
from fastapi.staticfiles import StaticFiles
from fastapi import Request
from cat.headers import check_api_key

class AuthStatic(StaticFiles):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

async def __call__(self, scope, receive, send) -> None:
request = Request(scope, receive=receive)
check_api_key(request.headers.get("access_token"))
await super().__call__(scope, receive, send)
21 changes: 9 additions & 12 deletions core/tests/routes/memory/test_memory_by_user.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@

from tests.utils import send_websocket_message

# episodic memories are saved having the correct user
def test_episodic_memory_by_user(client):

# send websocket message from user A
# send websocket message from user C
send_websocket_message({
"text": "I am user A"
}, client, user_id="A")
"text": "I am user C",
}, client, user_id="C")

# episodic recall (no user)
params = {
Expand All @@ -21,24 +20,22 @@ def test_episodic_memory_by_user(client):

# episodic recall (memories from non existing user)
params = {
"text": "I am user",
"user_id": "H"
"text": "I am user not existing"
}
response = client.get(f"/memory/recall/", params=params)
response = client.get(f"/memory/recall/", params=params, headers={"user_id": "not_existing"})
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == 0

# episodic recall (memories from user A)
# episodic recall (memories from user C)
params = {
"text": "I am user",
"user_id": "A"
"text": "I am user C"
}
response = client.get(f"/memory/recall/", params=params)
response = client.get(f"/memory/recall/", params=params, headers={"user_id": "C"})
json = response.json()
assert response.status_code == 200
episodic_memories = json["vectors"]["collections"]["episodic"]
assert len(episodic_memories) == 1
assert episodic_memories[0]["metadata"]["source"] == "A"
assert episodic_memories[0]["metadata"]["source"] == "C"

0 comments on commit e197ea8

Please sign in to comment.