From 021b4f3d67761db3117916c9b6365cb3956caa9d Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Mon, 23 Oct 2023 19:31:52 +0200 Subject: [PATCH] cat as singleton ready to go --- core/cat/looking_glass/cheshire_cat.py | 5 ++--- core/cat/main.py | 4 ++-- core/tests/routes/memory/test_memory_recall.py | 1 - 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index bd644a10..b3f301af 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -40,12 +40,13 @@ class CheshireCat(): """ + # CheshireCat is a singleton, this is the instance _instance = None + # get instance or create as the constructor is called def __new__(cls): if not cls._instance: cls._instance = super().__new__(cls) - cls._instance = CheshireCat() return cls._instance def __init__(self): @@ -544,5 +545,3 @@ def get_static_path(): """Allows the Cat expose the static files path.""" log.warning("This method will be removed, import cat.utils tu usit instead.") return utils.get_static_path() - -cat = CheshireCat() \ No newline at end of file diff --git a/core/cat/main.py b/core/cat/main.py index e075c4df..c26ef731 100644 --- a/core/cat/main.py +++ b/core/cat/main.py @@ -14,7 +14,7 @@ from cat.routes.static import public, admin, static from cat.api_auth import check_api_key from cat.routes.openapi import get_openapi_configuration_function -from cat.looking_glass.cheshire_cat import cat +from cat.looking_glass.cheshire_cat import CheshireCat @asynccontextmanager @@ -26,7 +26,7 @@ async def lifespan(app: FastAPI): # - Not using midlleware because I can't make it work with both http and websocket; # - Not using Depends because it only supports callables (not instances) # - Starlette allows this: https://www.starlette.io/applications/#storing-state-on-the-app-instance - app.state.ccat = cat + app.state.ccat = CheshireCat() # startup message with admin, public and swagger addresses log.welcome() diff --git a/core/tests/routes/memory/test_memory_recall.py b/core/tests/routes/memory/test_memory_recall.py index 35026034..14040355 100644 --- a/core/tests/routes/memory/test_memory_recall.py +++ b/core/tests/routes/memory/test_memory_recall.py @@ -54,7 +54,6 @@ def test_memory_recall_success(client): episodic_memories = json["vectors"]["collections"]["episodic"] assert len(episodic_memories) == num_messages # all 3 retrieved - # search with query and k def test_memory_recall_with_k_success(client):