From 7815b80efa034584e697782e06c438ee9a541dab Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Sat, 26 Aug 2023 19:35:10 +0200 Subject: [PATCH 01/77] Check if the module is imported before trying to remove it --- core/cat/mad_hatter/plugin.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index dbe508c0..ba2db517 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -62,9 +62,12 @@ def deactivate(self): # Remove the imported modules for py_file in self.py_files: - py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry - log(f"Remove module {py_filename}", "DEBUG") - sys.modules.pop(py_filename) + py_filename = py_file.replace("/", ".").replace(".py", "") + + # If the module is imported it is removed + if py_filename in sys.modules: + log(f"Remove module {py_filename}", "DEBUG") + sys.modules.pop(py_filename) self._hooks = [] self._tools = [] From 823a829823620234a7983fee98b0dac696e720bc Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Sat, 26 Aug 2023 19:37:55 +0200 Subject: [PATCH 02/77] Removed active parameter from Plugin constructor --- core/cat/mad_hatter/mad_hatter.py | 6 +++--- core/cat/mad_hatter/plugin.py | 10 ++-------- core/tests/mad_hatter/test_plugin.py | 10 +++++----- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 8ffe4746..170f95d7 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -52,7 +52,7 @@ def install_plugin(self, package_plugin): raise Exception("A plugin should contain a folder, found a file") # create plugin obj - self.load_plugin(plugin_path, active=False) + self.load_plugin(plugin_path) # activate it self.toggle_plugin(plugin_id) @@ -104,12 +104,12 @@ def find_plugins(self): self.sync_hooks_and_tools() - def load_plugin(self, plugin_path, active): + def load_plugin(self, plugin_path): # Instantiate plugin. # If the plugin is inactive, only manifest will be loaded # If active, also settings, tools and hooks try: - plugin = Plugin(plugin_path, active=active) + plugin = Plugin(plugin_path) # if plugin is valid, keep a reference self.plugins[plugin.id] = plugin except Exception as e: diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index ba2db517..2dfa6f57 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -18,7 +18,7 @@ class Plugin: - def __init__(self, plugin_path: str, active: bool): + def __init__(self, plugin_path: str): # does folder exist? if not os.path.exists(plugin_path) or not os.path.isdir(plugin_path): @@ -45,21 +45,14 @@ def __init__(self, plugin_path: str, active: bool): # but they are created and stored in each plugin instance self._hooks = [] self._tools = [] - self._active = False - # all plugins start active, they can be deactivated/reactivated from endpoint - if active: - self.activate() - def activate(self): # lists of hooks and tools self._hooks, self._tools = self._load_hooks_and_tools() self._active = True def deactivate(self): - self._active = False - # Remove the imported modules for py_file in self.py_files: py_filename = py_file.replace("/", ".").replace(".py", "") @@ -71,6 +64,7 @@ def deactivate(self): self._hooks = [] self._tools = [] + self._active = False # get plugin settings JSON schema def get_settings_schema(self): diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index d2b67d45..31c833a5 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -13,7 +13,7 @@ @pytest.fixture def plugin(): - p = Plugin(mock_plugin_path, active=True) + p = Plugin(mock_plugin_path) yield p @@ -25,7 +25,7 @@ def plugin(): def test_create_plugin_wrong_folder(): with pytest.raises(Exception) as e: - Plugin("/non/existent/folder", active=True) + Plugin("/non/existent/folder") assert f"Cannot create" in str(e.value) @@ -36,14 +36,14 @@ def test_create_plugin_empty_folder(): os.mkdir(path) with pytest.raises(Exception) as e: - Plugin(path, active=True) + Plugin(path) assert f"Cannot create" in str(e.value) def test_create_non_active_plugin(): - plugin = Plugin(mock_plugin_path, active=False) + plugin = Plugin(mock_plugin_path) assert plugin.active == False @@ -97,7 +97,7 @@ def test_create_active_plugin(plugin): def test_activate_plugin(): # create non-active plugin - plugin = Plugin(mock_plugin_path, active=False) + plugin = Plugin(mock_plugin_path) # activate it plugin.activate() From dade61acf98514f3cf58c81cfee9cb53d2e99079 Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Sat, 26 Aug 2023 19:38:44 +0200 Subject: [PATCH 03/77] Removed unused headers --- core/tests/mad_hatter/test_plugin.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index 31c833a5..e79af161 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -1,11 +1,7 @@ import os -import shutil import pytest -from inspect import isfunction from cat.mad_hatter.mad_hatter import Plugin -from cat.mad_hatter.decorators import CatHook, CatTool - mock_plugin_path = "tests/mocks/mock_plugin/" From e79c69eb6ac0780e3e07174f9065644193548e86 Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Sat, 26 Aug 2023 19:40:34 +0200 Subject: [PATCH 04/77] Manually activate the plugin during discovery if is int the list of active plugins --- core/cat/mad_hatter/mad_hatter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 170f95d7..51922a85 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -95,12 +95,12 @@ def find_plugins(self): # discover plugins, folder by folder for folder in all_plugin_folders: + self.load_plugin(folder) - # is the plugin active? - folder_base = os.path.basename(os.path.normpath(folder)) - is_active = folder_base in self.active_plugins - - self.load_plugin(folder, is_active) + plugin_id = os.path.basename(os.path.normpath(folder)) + + if plugin_id in self.active_plugins: + self.plugins[plugin_id].activate() self.sync_hooks_and_tools() From abf976fe96a776caf2bf9238a6ed15b3595923ce Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Sat, 26 Aug 2023 19:42:19 +0200 Subject: [PATCH 05/77] Removed unnecessary test, Plugin object is created by default deactivated --- core/tests/mad_hatter/test_plugin.py | 34 ---------------------------- 1 file changed, 34 deletions(-) diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index e79af161..66092a27 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -56,40 +56,6 @@ def test_create_non_active_plugin(): assert plugin.hooks == [] assert plugin.tools == [] - -def test_create_active_plugin(plugin): - - assert plugin.active == True - - assert plugin.path == mock_plugin_path - assert plugin.id == "mock_plugin" - - # manifest - assert type(plugin.manifest) == dict - assert plugin.manifest["id"] == plugin.id - assert plugin.manifest["name"] == "MockPlugin" - assert "Description not found" in plugin.manifest["description"] - - # hooks - assert len(plugin.hooks) == 1 - hook = plugin.hooks[0] - assert isinstance(hook, CatHook) - assert hook.plugin_id == "mock_plugin" - assert hook.name == "before_cat_sends_message" - assert isfunction(hook.function) - assert hook.priority == 2.0 - - # tools - assert len(plugin.tools) == 1 - tool = plugin.tools[0] - assert isinstance(tool, CatTool) - assert tool.plugin_id == "mock_plugin" - assert tool.name == "mock_tool" - assert "mock_tool" in tool.description - assert isfunction(tool.func) - assert tool.return_direct == True - - def test_activate_plugin(): # create non-active plugin From b0d9dec893ce46732591cde1c9f2e7f6513c89aa Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Sat, 26 Aug 2023 19:42:41 +0200 Subject: [PATCH 06/77] Renamed test function --- core/tests/mad_hatter/test_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index 66092a27..98e3815f 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -37,7 +37,7 @@ def test_create_plugin_empty_folder(): assert f"Cannot create" in str(e.value) -def test_create_non_active_plugin(): +def test_create_plugin(): plugin = Plugin(mock_plugin_path) From ae3e29f1a19065cd6255b359b35278f2ee9cf47b Mon Sep 17 00:00:00 2001 From: Dany Date: Tue, 29 Aug 2023 11:37:07 +0200 Subject: [PATCH 07/77] refactored ws messages --- core/cat/looking_glass/cheshire_cat.py | 30 +++++++++++++++++++++++--- core/cat/rabbit_hole.py | 29 ++++++------------------- core/cat/routes/websocket.py | 4 ++-- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index b9f82a1f..a1a308cd 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -21,7 +21,7 @@ class CheshireCat: Attributes ---------- - web_socket_notifications : list + ws_messages : list List of notifications to be sent to the frontend. """ @@ -33,7 +33,7 @@ def __init__(self): """ # bootstrap the cat! - # reinstantiate MadHatter (reloads all plugins' hooks and tools) + # reinstantiate MadHatter (reloads all plugins' hooks and tools) self.mad_hatter = MadHatter(self) # allows plugins to do something before cat components are loaded @@ -59,7 +59,7 @@ def __init__(self): # queue of cat messages not directly related to last user input # i.e. finished uploading a file - self.web_socket_notifications = [] + self.ws_messages = [] def load_natural_language(self): """Load Natural Language related objects. @@ -281,6 +281,30 @@ def store_new_message_in_working_memory(self, user_message_json): self.working_memory["user_message_json"]["prompt_settings"] = prompt_settings + def send_ws_message(self, type: str, content: str): + """Send a message via websocket. + + This method is useful for sending a message via websocket directly without passing through the llm + + Parameters + ---------- + type : str + The type of the message. Should be either `notification` or `chat` + content : str + The content of the message. + + Returns + ------- + str + The generated response. + + """ + + self.ws_messages.append({ + "type": type, + "content": content + }) + def get_base_url(self): """Allows the Cat expose the base url.""" secure = os.getenv('CORE_USE_SECURE_PROTOCOLS', '') diff --git a/core/cat/rabbit_hole.py b/core/cat/rabbit_hole.py index 04be89c7..49a5f20f 100644 --- a/core/cat/rabbit_hole.py +++ b/core/cat/rabbit_hole.py @@ -217,36 +217,19 @@ def file_to_docs( blob = Blob(data=file_bytes, mimetype=content_type, source=source).from_data(data=file_bytes, - mime_type=content_type) + mime_type=content_type, + path=source) # Parser based on the mime type parser = MimeTypeBasedParser(handlers=self.file_handlers) # Parse the text - self.send_rabbit_thought("I'm parsing the content. Big content could require some minutes...") + self.cat.send_ws_message("notification", "I'm parsing the content. Big content could require some minutes...") text = parser.parse(blob) - self.send_rabbit_thought(f"Parsing completed. Now let's go with reading process...") + self.cat.send_ws_message("notification", f"Parsing completed. Now let's go with reading process...") docs = self.split_text(text, chunk_size, chunk_overlap) return docs - def send_rabbit_thought(self, thought): - """Append a message to the notification list. - - This method receives a string and creates the message to append to the list of notifications. - - Parameters - ---------- - thought : str - Text of the message to append to the notification list. - """ - - self.cat.web_socket_notifications.append({ - "error": False, - "type": "notification", - "content": thought, - "why": {}, - }) - def store_documents(self, docs: List[Document], source: str) -> None: """Add documents to the Cat's declarative memory. @@ -284,7 +267,7 @@ def store_documents(self, docs: List[Document], source: str) -> None: if time.time() - time_last_notification > time_interval: time_last_notification = time.time() perc_read = int(d / len(docs) * 100) - self.send_rabbit_thought(f"Read {perc_read}% of {source}") + self.cat.send_ws_message("notification", f"Read {perc_read}% of {source}") doc.metadata["source"] = source doc.metadata["when"] = time.time() @@ -310,7 +293,7 @@ def store_documents(self, docs: List[Document], source: str) -> None: finished_reading_message = f"Finished reading {source}, " \ f"I made {len(docs)} thoughts on it." - self.send_rabbit_thought(finished_reading_message) + self.cat.send_ws_message("notification", finished_reading_message) print(f"\n\nDone uploading {source}") diff --git a/core/cat/routes/websocket.py b/core/cat/routes/websocket.py index a6b54392..aebabdcd 100644 --- a/core/cat/routes/websocket.py +++ b/core/cat/routes/websocket.py @@ -69,9 +69,9 @@ async def check_notification(websocket: WebSocket, ccat: object): Periodically check if there are any new notifications from the `ccat` object and send them to the user. """ while True: - if ccat.web_socket_notifications: + if ccat.ws_messages: # extract from FIFO list websocket notification - notification = ccat.web_socket_notifications.pop(0) + notification = ccat.ws_messages.pop(0) await manager.send_personal_message(notification, websocket) # Sleep for the specified interval before checking for notifications again. From 457a2e1b4dc99cd316c2d915c67d97ba467fb796 Mon Sep 17 00:00:00 2001 From: Dany Date: Tue, 29 Aug 2023 11:41:52 +0200 Subject: [PATCH 08/77] Update cheshire_cat.py --- core/cat/looking_glass/cheshire_cat.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index a1a308cd..4c9ee00b 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -284,20 +284,14 @@ def store_new_message_in_working_memory(self, user_message_json): def send_ws_message(self, type: str, content: str): """Send a message via websocket. - This method is useful for sending a message via websocket directly without passing through the llm + This method is useful for sending a message via websocket directly without passing through the LLM Parameters ---------- type : str The type of the message. Should be either `notification` or `chat` content : str - The content of the message. - - Returns - ------- - str - The generated response. - + The content of the message. """ self.ws_messages.append({ From 41ffc50c3a775ef40c792bdd69c4716bfe542b9b Mon Sep 17 00:00:00 2001 From: Dany Date: Tue, 29 Aug 2023 11:52:25 +0200 Subject: [PATCH 09/77] Update cheshire_cat.py --- core/cat/looking_glass/cheshire_cat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 4c9ee00b..24f5394a 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -1,7 +1,7 @@ import time from copy import deepcopy import traceback - +from typing import Literal import langchain import os from cat.log import log @@ -281,7 +281,7 @@ def store_new_message_in_working_memory(self, user_message_json): self.working_memory["user_message_json"]["prompt_settings"] = prompt_settings - def send_ws_message(self, type: str, content: str): + def send_ws_message(self, type: Literal["notification", "chat"], content: str): """Send a message via websocket. This method is useful for sending a message via websocket directly without passing through the LLM From 5ef9bd9754ef3951249e10d1a3d5a2f673564322 Mon Sep 17 00:00:00 2001 From: Dany Date: Tue, 29 Aug 2023 12:08:19 +0200 Subject: [PATCH 10/77] added types --- core/cat/looking_glass/cheshire_cat.py | 10 ++++++++-- core/cat/rabbit_hole.py | 8 ++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 24f5394a..8591dc54 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -1,7 +1,7 @@ import time from copy import deepcopy import traceback -from typing import Literal +from typing import Literal, get_args import langchain import os from cat.log import log @@ -12,6 +12,7 @@ from cat.memory.long_term_memory import LongTermMemory from cat.looking_glass.agent_manager import AgentManager +MSG_TYPES = Literal["notification", "chat"] # main class class CheshireCat: @@ -281,7 +282,7 @@ def store_new_message_in_working_memory(self, user_message_json): self.working_memory["user_message_json"]["prompt_settings"] = prompt_settings - def send_ws_message(self, type: Literal["notification", "chat"], content: str): + def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): """Send a message via websocket. This method is useful for sending a message via websocket directly without passing through the LLM @@ -294,6 +295,11 @@ def send_ws_message(self, type: Literal["notification", "chat"], content: str): The content of the message. """ + options = get_args(MSG_TYPES) + + if msg_type not in options: + raise ValueError(f"The message type `{msg_type}` is not valid. Valid types: {', '.join(options)}") + self.ws_messages.append({ "type": type, "content": content diff --git a/core/cat/rabbit_hole.py b/core/cat/rabbit_hole.py index 49a5f20f..f4d87d77 100644 --- a/core/cat/rabbit_hole.py +++ b/core/cat/rabbit_hole.py @@ -223,10 +223,10 @@ def file_to_docs( parser = MimeTypeBasedParser(handlers=self.file_handlers) # Parse the text - self.cat.send_ws_message("notification", "I'm parsing the content. Big content could require some minutes...") + self.cat.send_ws_message("I'm parsing the content. Big content could require some minutes...") text = parser.parse(blob) - self.cat.send_ws_message("notification", f"Parsing completed. Now let's go with reading process...") + self.cat.send_ws_message(f"Parsing completed. Now let's go with reading process...") docs = self.split_text(text, chunk_size, chunk_overlap) return docs @@ -267,7 +267,7 @@ def store_documents(self, docs: List[Document], source: str) -> None: if time.time() - time_last_notification > time_interval: time_last_notification = time.time() perc_read = int(d / len(docs) * 100) - self.cat.send_ws_message("notification", f"Read {perc_read}% of {source}") + self.cat.send_ws_message(f"Read {perc_read}% of {source}") doc.metadata["source"] = source doc.metadata["when"] = time.time() @@ -293,7 +293,7 @@ def store_documents(self, docs: List[Document], source: str) -> None: finished_reading_message = f"Finished reading {source}, " \ f"I made {len(docs)} thoughts on it." - self.cat.send_ws_message("notification", finished_reading_message) + self.cat.send_ws_message(finished_reading_message) print(f"\n\nDone uploading {source}") From 5734644d8509ce831bb653773ef7baa9f01622d4 Mon Sep 17 00:00:00 2001 From: Dany Date: Tue, 29 Aug 2023 12:29:31 +0200 Subject: [PATCH 11/77] Update cheshire_cat.py --- core/cat/looking_glass/cheshire_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 8591dc54..3b0361e3 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -301,7 +301,7 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): raise ValueError(f"The message type `{msg_type}` is not valid. Valid types: {', '.join(options)}") self.ws_messages.append({ - "type": type, + "type": msg_type, "content": content }) From 2f75f198e6a22fd3c86aaab9e4515e5b8ca064db Mon Sep 17 00:00:00 2001 From: Dany Date: Tue, 29 Aug 2023 12:32:54 +0200 Subject: [PATCH 12/77] Update websocket.py --- core/cat/routes/websocket.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/cat/routes/websocket.py b/core/cat/routes/websocket.py index aebabdcd..397e5c06 100644 --- a/core/cat/routes/websocket.py +++ b/core/cat/routes/websocket.py @@ -8,7 +8,7 @@ router = APIRouter() # This constant sets the interval (in seconds) at which the system checks for notifications. -NOTIFICATION_CHECK_INTERVAL = 1 # seconds +QUEUE_CHECK_INTERVAL = 1 # seconds class ConnectionManager: @@ -64,9 +64,9 @@ async def receive_message(websocket: WebSocket, ccat: object): await manager.send_personal_message(cat_message, websocket) -async def check_notification(websocket: WebSocket, ccat: object): +async def check_messages(websocket: WebSocket, ccat): """ - Periodically check if there are any new notifications from the `ccat` object and send them to the user. + Periodically check if there are any new notifications from the `ccat` instance and send them to the user. """ while True: if ccat.ws_messages: @@ -75,13 +75,13 @@ async def check_notification(websocket: WebSocket, ccat: object): await manager.send_personal_message(notification, websocket) # Sleep for the specified interval before checking for notifications again. - await asyncio.sleep(NOTIFICATION_CHECK_INTERVAL) + await asyncio.sleep(QUEUE_CHECK_INTERVAL) @router.websocket_route("/ws") async def websocket_endpoint(websocket: WebSocket): """ - Endpoint to handle incoming WebSocket connections, process messages, and check for notifications. + Endpoint to handle incoming WebSocket connections, process messages, and check for messages. """ # Retrieve the `ccat` instance from the application's state. @@ -94,7 +94,7 @@ async def websocket_endpoint(websocket: WebSocket): # Process messages and check for notifications concurrently. await asyncio.gather( receive_message(websocket, ccat), - check_notification(websocket, ccat) + check_messages(websocket, ccat) ) except WebSocketDisconnect: # Handle the event where the user disconnects their WebSocket. From 21ce67394a2202443975ea0fd97063667438f285 Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Wed, 30 Aug 2023 16:37:39 +0200 Subject: [PATCH 13/77] Use fixture in `test_create_plugin` and `test_activate_plugin` --- core/tests/mad_hatter/test_plugin.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index 98e3815f..d5f8c152 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -1,7 +1,9 @@ import os import pytest +from inspect import isfunction from cat.mad_hatter.mad_hatter import Plugin +from cat.mad_hatter.decorators import CatHook, CatTool mock_plugin_path = "tests/mocks/mock_plugin/" @@ -37,9 +39,7 @@ def test_create_plugin_empty_folder(): assert f"Cannot create" in str(e.value) -def test_create_plugin(): - - plugin = Plugin(mock_plugin_path) +def test_create_plugin(plugin): assert plugin.active == False @@ -56,10 +56,7 @@ def test_create_plugin(): assert plugin.hooks == [] assert plugin.tools == [] -def test_activate_plugin(): - - # create non-active plugin - plugin = Plugin(mock_plugin_path) +def test_activate_plugin(plugin): # activate it plugin.activate() From c7e7e4b5994c8d9b1f272e3d04b2fa583458cad3 Mon Sep 17 00:00:00 2001 From: Emanuele Morrone <67059270+Pingdred@users.noreply.github.com> Date: Wed, 30 Aug 2023 16:39:47 +0200 Subject: [PATCH 14/77] Added tests on hooks and tools that were previously in test_create_active_plugin in test_activate_plugin --- core/tests/mad_hatter/test_plugin.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index d5f8c152..4770c4a1 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -62,14 +62,31 @@ def test_activate_plugin(plugin): plugin.activate() assert plugin.active == True - - # hooks and tools + + # hooks assert len(plugin.hooks) == 1 + hook = plugin.hooks[0] + assert isinstance(hook, CatHook) + assert hook.plugin_id == "mock_plugin" + assert hook.name == "before_cat_sends_message" + assert isfunction(hook.function) + assert hook.priority == 2.0 + + # tools assert len(plugin.tools) == 1 - + tool = plugin.tools[0] + assert isinstance(tool, CatTool) + assert tool.plugin_id == "mock_plugin" + assert tool.name == "mock_tool" + assert "mock_tool" in tool.description + assert isfunction(tool.func) + assert tool.return_direct == True def test_deactivate_plugin(plugin): + # The plugin is non active by default + plugin.activate() + # deactivate it plugin.deactivate() From aa9a1938481eaa0ea6b38fb85cf55b86b7b6d8b0 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 1 Sep 2023 18:43:18 +0200 Subject: [PATCH 15/77] Update ROADMAP v2 --- readme/ROADMAP.md | 57 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/readme/ROADMAP.md b/readme/ROADMAP.md index b679c8f1..a7150c1d 100644 --- a/readme/ROADMAP.md +++ b/readme/ROADMAP.md @@ -1,5 +1,50 @@ -* Version 1 +* **Version 2** + * Technical + * Plugins + * redesign hooks & tools signature + * tools with more than one arg (structured Tool) + * no cat argument + * registry online + * Agent + * Custom hookable agent + * Async agent + * Output dictionary retry (guardrails, kor, guidance) + * (streaming?) + * Unit tests + * Half coverage (main classes) + * Admin + * sync / async calls consistent management + * adapt to design system + * show registry plugins (core should send them alongside the installed ones) + * filters for memory search + * Deploy + * docker image! + * compose with local LLM + embedder - ready to use + * (nginx?) + * LLM improvements + * explicit support for chat vs completion + * each LLM has its own default template + * User support (not management) + * fix bugs + * sessions + * Outreach + * Community + * 1 live event + * 4 meow talk + * 1 challenge + * Dissemination + * use cases examples + * tutorials on hooks + * hook discovery tool + * website analytics + * Branding + * logo + * website + docs + admin design system + +--- + +* **Version 1** * Forms from JSON schema ✅ * Configurations * Language model provider ✅ @@ -33,13 +78,13 @@ * Agent * Tool embeddings ✅ * Custom hookable agent - * Local LLM / embedder - * CustomLLMConfig - * CustomEmbedderConfig adapters - * LLM / embedder example docker container + * Local LLM / embedder ✅ + * CustomLLMConfig ✅ + * CustomEmbedderConfig adapters ✅ + * LLM / embedder example docker container ✅ * Hook surface * 20 hooks ✅ - * more hooks where customization is needed + * more hooks where customization is needed ✅ * Plugin management * Install plugin dependencies ✅ * Activate / deactivate plugins ✅ From b86c4e3e5a11a3c14f9cfd6323a54b79a52613ed Mon Sep 17 00:00:00 2001 From: Dany Date: Fri, 1 Sep 2023 22:59:10 +0200 Subject: [PATCH 16/77] updated ws content to send - fixed #434 --- core/cat/looking_glass/cheshire_cat.py | 34 +++++++++++-------- core/cat/mad_hatter/core_plugin/hooks/flow.py | 1 - 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 3b0361e3..fa571086 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -12,7 +12,7 @@ from cat.memory.long_term_memory import LongTermMemory from cat.looking_glass.agent_manager import AgentManager -MSG_TYPES = Literal["notification", "chat"] +MSG_TYPES = Literal["notification", "chat", "error"] # main class class CheshireCat: @@ -290,7 +290,7 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): Parameters ---------- type : str - The type of the message. Should be either `notification` or `chat` + The type of the message. Should be either `notification` or `chat` or `error` content : str The content of the message. """ @@ -300,10 +300,17 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): if msg_type not in options: raise ValueError(f"The message type `{msg_type}` is not valid. Valid types: {', '.join(options)}") - self.ws_messages.append({ - "type": msg_type, - "content": content - }) + if msg_type is "error": + self.ws_messages.append({ + "type": msg_type, + "name": "GenericError", + "description": content + }) + else: + self.ws_messages.append({ + "type": msg_type, + "content": content + }) def get_base_url(self): """Allows the Cat expose the base url.""" @@ -375,16 +382,14 @@ def __call__(self, user_message_json): traceback.print_exc(e) err_message = ( - "Vector memory error: you probably changed " - "Embedder and old vector memory is not compatible. " + "You probably changed Embedder and old vector memory is not compatible. " "Please delete `core/long_term_memory` folder." ) + return { - "error": False, - # TODO: Otherwise the frontend gives notice of the error - # but does not show what the error is - "content": err_message, - "why": {}, + "type": "error", + "name": "VectorMemoryError", + "description": err_message, } # prepare input to be passed to the agent. @@ -417,6 +422,7 @@ def __call__(self, user_message_json): # update conversation history user_message = self.working_memory["user_message_json"]["text"] + user_id = self.working_memory["user_message_json"]["user_id"] self.working_memory.update_conversation_history(who="Human", message=user_message) self.working_memory.update_conversation_history(who="AI", message=cat_message["output"]) @@ -434,8 +440,8 @@ def __call__(self, user_message_json): procedural_report = [dict(d[0]) | {"score": float(d[1]), "id": d[3]} for d in self.working_memory["procedural_memories"]] final_output = { - "error": False, "type": "chat", + "user_id": user_id, "content": cat_message.get("output"), "why": { "input": cat_message.get("input"), diff --git a/core/cat/mad_hatter/core_plugin/hooks/flow.py b/core/cat/mad_hatter/core_plugin/hooks/flow.py index 66139c83..cbedc7ee 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/flow.py +++ b/core/cat/mad_hatter/core_plugin/hooks/flow.py @@ -310,7 +310,6 @@ def before_cat_sends_message(message: dict, cat) -> dict: Default `message` is:: { - "error": False, "type": "chat", "content": cat_message["output"], "why": { From e80c08a1e798c86930a078d787527195fd430234 Mon Sep 17 00:00:00 2001 From: Dany Date: Fri, 1 Sep 2023 23:07:45 +0200 Subject: [PATCH 17/77] Update cheshire_cat.py --- core/cat/looking_glass/cheshire_cat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index fa571086..dd62d97e 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -289,10 +289,10 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): Parameters ---------- - type : str - The type of the message. Should be either `notification` or `chat` or `error` content : str The content of the message. + msg_type : str + The type of the message. Should be either `notification`, `chat` or `error` """ options = get_args(MSG_TYPES) From 414c3bcd23714f6b115df529aec480bc69a4d3ea Mon Sep 17 00:00:00 2001 From: Dany Date: Fri, 1 Sep 2023 23:08:19 +0200 Subject: [PATCH 18/77] Update cheshire_cat.py --- core/cat/looking_glass/cheshire_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index dd62d97e..f5ffa22d 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -300,7 +300,7 @@ def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): if msg_type not in options: raise ValueError(f"The message type `{msg_type}` is not valid. Valid types: {', '.join(options)}") - if msg_type is "error": + if msg_type == "error": self.ws_messages.append({ "type": msg_type, "name": "GenericError", From bf11b75b4b43bf7ac51285cbd787019cb25f9a45 Mon Sep 17 00:00:00 2001 From: Dany Date: Fri, 1 Sep 2023 23:13:04 +0200 Subject: [PATCH 19/77] Update test_websocket.py --- core/tests/routes/test_websocket.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/tests/routes/test_websocket.py b/core/tests/routes/test_websocket.py index 2a86664b..f53a4688 100644 --- a/core/tests/routes/test_websocket.py +++ b/core/tests/routes/test_websocket.py @@ -14,10 +14,10 @@ def test_websocket(client): "text": "Your bald aunt with a wooden leg" }, client) - for k in ["error", "content", "why"]: + for k in ["type", "content", "why"]: assert k in res.keys() - assert not res["error"] + assert res["type"] != "error" assert type(res["content"]) == str assert "You did not configure" in res["content"] assert len(res["why"].keys()) > 0 From e5ec0a5a90c70e29a6970d031eb330618cc014a7 Mon Sep 17 00:00:00 2001 From: Dany Date: Fri, 1 Sep 2023 23:14:12 +0200 Subject: [PATCH 20/77] Update cheshire_cat.py --- core/cat/looking_glass/cheshire_cat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index f5ffa22d..2687d7f4 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -422,7 +422,6 @@ def __call__(self, user_message_json): # update conversation history user_message = self.working_memory["user_message_json"]["text"] - user_id = self.working_memory["user_message_json"]["user_id"] self.working_memory.update_conversation_history(who="Human", message=user_message) self.working_memory.update_conversation_history(who="AI", message=cat_message["output"]) From 35c7e6f5cee56ae22cd90adab6b8f9b5bd91b016 Mon Sep 17 00:00:00 2001 From: Dany Date: Fri, 1 Sep 2023 23:16:42 +0200 Subject: [PATCH 21/77] fixed workflows --- .github/workflows/pr.yml | 15 +++++++-------- .github/workflows/tag.yml | 3 +-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 5d2d20fa..fd898ccc 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -1,13 +1,14 @@ name: Cheshire-Cat Action on Pull Requests + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + on: pull_request: - branches: - - "main" - - "develop" + branches: [main, develop] push: - branches: - - "main" - - "develop" + branches: [main, develop] jobs: pylint: @@ -32,12 +33,10 @@ jobs: run: pip install .[dev] - name: Pylint run: pylint -f actions ./ - test: needs: [ pylint ] name: "Run Tests" runs-on: 'ubuntu-latest' - steps: - uses: actions/checkout@v2 - name: Cat up diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml index f1f047bd..c5f1aaa8 100644 --- a/.github/workflows/tag.yml +++ b/.github/workflows/tag.yml @@ -6,8 +6,7 @@ concurrency: on: push: - branches: - - "main" + branches: [main] permissions: contents: write From 2ea544881b3516d93582ea4cd3157e63b0ef6c92 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Mon, 4 Sep 2023 15:20:26 +0200 Subject: [PATCH 22/77] update registry address --- core/cat/routes/plugins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 4b9f07c9..8692158b 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -11,13 +11,13 @@ async def get_registry_list(): try: - response = await requests.get("https://plugins.cheshirecat.ai/plugins?page=1&page_size=7000") + response = requests.get("https://registry.cheshirecat.ai/plugins?page=1&page_size=1000") if response.status_code == 200: return response.json()["plugins"] else: return [] - except requests.exceptions.RequestException as e: - #log(e, "ERROR") + except Exception as e: + log(e, "ERROR") return [] # GET plugins From 14896304f3c31eec0ea9f5cd7698fbd455f78a3f Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Mon, 4 Sep 2023 16:58:38 +0200 Subject: [PATCH 23/77] remove hooks from vector memory --- .../mad_hatter/core_plugin/hooks/memory.py | 41 ------------------- core/cat/memory/vector_memory.py | 3 -- 2 files changed, 44 deletions(-) delete mode 100644 core/cat/mad_hatter/core_plugin/hooks/memory.py diff --git a/core/cat/mad_hatter/core_plugin/hooks/memory.py b/core/cat/mad_hatter/core_plugin/hooks/memory.py deleted file mode 100644 index dff2deee..00000000 --- a/core/cat/mad_hatter/core_plugin/hooks/memory.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Hooks to modify the Cat's memory collections. - -Here is a collection of methods to hook the insertion of memories in the vector databases. - -""" - -from langchain.docstore.document import Document -from cat.memory.vector_memory import VectorMemoryCollection -from cat.mad_hatter.decorators import hook - - -# Hook called before a memory collection has been created. -# This happens at first launch and whenever the collection is deleted and recreated. -@hook(priority=0) -def before_collection_created(vector_memory_collection: VectorMemoryCollection, cat): - """Do something before a new collection is created in vectorDB - - Parameters - ---------- - vector_memory_collection : VectorMemoryCollection - Instance of `VectorMemoryCollection` wrapping the actual db collection. - cat : CheshireCat - Cheshire Car instance. - """ - pass - - -# Hook called after a memory collection has been created. -# This happens at first launch and whenever the collection is deleted and recreated. -@hook(priority=0) -def after_collection_created(vector_memory_collection: VectorMemoryCollection, cat): - """Do something after a new collection is created in vectorDB - - Parameters - ---------- - vector_memory_collection : VectorMemoryCollection - Instance of `VectorMemoryCollection` wrapping the actual db collection. - cat : CheshireCat - Cheshire Car instance. - """ - pass \ No newline at end of file diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index 8d3e11a0..a5eace6f 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -146,8 +146,6 @@ def create_collection_if_not_exists(self): # create collection def create_collection(self): - self.cat.mad_hatter.execute_hook('before_collection_created', self) - log(f"Creating collection {self.collection_name} ...", "WARNING") self.client.recreate_collection( collection_name=self.collection_name, @@ -174,7 +172,6 @@ def create_collection(self): ) ] ) - self.cat.mad_hatter.execute_hook('after_collection_created', self) # retrieve similar memories from text def recall_memories_from_text(self, text, metadata=None, k=5, threshold=None): From 24099dfbaa1a1dbcb5c8af9b2fd599e39f95b361 Mon Sep 17 00:00:00 2001 From: Samuele Barzaghi Date: Tue, 5 Sep 2023 22:21:33 +0200 Subject: [PATCH 24/77] Plugin module name in log messages --- core/cat/log.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/cat/log.py b/core/cat/log.py index cb20c592..c3d2eef7 100644 --- a/core/cat/log.py +++ b/core/cat/log.py @@ -121,6 +121,10 @@ def get_caller_info(self, skip=3): package = mod[0] module = mod[1] + # When the module is "plugins" get also the plugin module name + if module == "plugins": + module = ".".join(mod[1:]) + # class name. klass = None if "self" in parentframe.f_locals: From e03a01dca5d0db0e2c5fd9be91362dc129bb65d2 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Thu, 7 Sep 2023 18:39:45 +0200 Subject: [PATCH 25/77] plugin search endpoint and tests --- core/cat/mad_hatter/registry.py | 35 +++++++ core/cat/routes/plugins.py | 55 +++++++---- .../tests/routes/plugins/test_plugins_info.py | 49 +++++----- .../plugins/test_plugins_install_uninstall.py | 49 +++++----- .../routes/plugins/test_plugins_registry.py | 91 +++++++++++++++++++ 5 files changed, 214 insertions(+), 65 deletions(-) create mode 100644 core/cat/mad_hatter/registry.py create mode 100644 core/tests/routes/plugins/test_plugins_registry.py diff --git a/core/cat/mad_hatter/registry.py b/core/cat/mad_hatter/registry.py new file mode 100644 index 00000000..93cff5e8 --- /dev/null +++ b/core/cat/mad_hatter/registry.py @@ -0,0 +1,35 @@ +import requests + +from cat.log import log + + +async def registry_search_plugins( + query: str = None, + #author: str = None, + #tag: str = None, +): + + registry_url = "https://registry.cheshirecat.ai" + + try: + if query: + # search plugins + url = f"{registry_url}/search" + payload = { + "query": query + } + response = requests.post(url, json=payload) + return response.json() + else: + # list plugins as sorted by registry (no search) + url = f"{registry_url}/plugins" + params = { + "page": 1, + "page_size": 1000, + } + response = requests.get(url, params=params) + return response.json()["plugins"] + + except Exception as e: + log(e, "ERROR") + return [] diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 8692158b..6d095000 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -4,45 +4,60 @@ from tempfile import NamedTemporaryFile from fastapi import Body, Request, APIRouter, HTTPException, UploadFile, BackgroundTasks from cat.log import log +from cat.mad_hatter.registry import registry_search_plugins from urllib.parse import urlparse import requests router = APIRouter() -async def get_registry_list(): - try: - response = requests.get("https://registry.cheshirecat.ai/plugins?page=1&page_size=1000") - if response.status_code == 200: - return response.json()["plugins"] - else: - return [] - except Exception as e: - log(e, "ERROR") - return [] # GET plugins @router.get("/") -async def get_available_plugins(request: Request) -> Dict: +async def get_available_plugins( + request: Request, + query: str = None, + #author: str = None, to be activated in case of more granular search + #tag: str = None, to be activated in case of more granular search +) -> Dict: """List available plugins""" - # access cat instance + # retrieve plugins from official repo + registry_plugins = await registry_search_plugins(query) + # index registry plugins by url + registry_plugins_index = {} + for p in registry_plugins: + plugin_url = p["url"] + registry_plugins_index[plugin_url] = p + + # get active plugins ccat = request.app.state.ccat - active_plugins = ccat.mad_hatter.load_active_plugins_from_db() - # plugins are managed by the MadHatter class - plugins = [] + # list installed plugins' manifest + installed_plugins = [] for p in ccat.mad_hatter.plugins.values(): + + # get manifest manifest = deepcopy(p.manifest) # we make a copy to avoid modifying the plugin obj manifest["active"] = p.id in active_plugins # pass along if plugin is active or not - plugins.append(manifest) + + # filter by query + plugin_text = [str(field) for field in manifest.values()] + plugin_text = " ".join(plugin_text).lower() + if (query is None) or (query.lower() in plugin_text): + installed_plugins.append(manifest) - # retrieve plugins from official repo - registry = await get_registry_list() + # do not show already installed plugins among registry plugins + registry_plugins_index.pop( manifest["plugin_url"], None ) return { - "installed": plugins, - "registry": registry + "filters": { + "query": query, + #"author": author, to be activated in case of more granular search + #"tag": tag, to be activated in case of more granular search + }, + "installed": installed_plugins, + "registry": list(registry_plugins_index.values()) } diff --git a/core/tests/routes/plugins/test_plugins_info.py b/core/tests/routes/plugins/test_plugins_info.py index 856030cc..574c5a9c 100644 --- a/core/tests/routes/plugins/test_plugins_info.py +++ b/core/tests/routes/plugins/test_plugins_info.py @@ -1,42 +1,45 @@ import os import time -import pytest -import shutil -from tests.utils import key_in_json -@pytest.mark.parametrize("key", ["installed", "registry"]) -def test_list_plugins(client, key): - # Act - response = client.get("/plugins") +def test_list_plugins(client): - response_json = response.json() + response = client.get("/plugins") + json = response.json() - # Assert assert response.status_code == 200 - assert key_in_json(key, response_json) - assert response_json["installed"][0]["id"] == "core_plugin" - assert response_json["installed"][0]["active"] == True + for key in ["filters", "installed", "registry"]: + assert key in json.keys() + # query + for key in ["query"]: # ["query", "author", "tag"]: + assert key in json["filters"].keys() + + # installed + assert json["installed"][0]["id"] == "core_plugin" + assert json["installed"][0]["active"] == True -@pytest.mark.parametrize("keys", ["data"]) -def test_get_plugin_id(client, keys): - # Act + # registry (see more registry tests in `./test_plugins_registry.py`) + assert type(json["registry"] == list) + assert len(json["registry"]) > 0 + + +def test_get_plugin_id(client): + response = client.get("/plugins/core_plugin") - response_json = response.json() + json = response.json() - assert key_in_json(keys, response_json) - assert response_json["data"] is not None - assert response_json["data"]["id"] == "core_plugin" - assert response_json["data"]["active"] == True + assert "data" in json.keys() + assert json["data"] is not None + assert json["data"]["id"] == "core_plugin" + assert json["data"]["active"] == True def test_get_non_existent_plugin(client): response = client.get("/plugins/no_plugin") - response_json = response.json() + json = response.json() assert response.status_code == 404 - assert response_json["detail"]["error"] == "Plugin not found" - + assert json["detail"]["error"] == "Plugin not found" \ No newline at end of file diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index ca2d5874..899ac949 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -6,8 +6,30 @@ from fixture_just_installed_plugin import just_installed_plugin -# TODO: these test cases should be splitted in different test functions, with apppropriate setup/teardown -def test_plugin_install_upload_zip(client, just_installed_plugin): +def test_plugin_uninstall(client, just_installed_plugin): + + # during tests, the cat uses a different folder for plugins + mock_plugin_final_folder = "tests/mocks/mock_plugin_folder/mock_plugin" + + # remove plugin via endpoint (will delete also plugin folder in mock_plugin_folder) + response = client.delete("/plugins/mock_plugin") + assert response.status_code == 200 + + # mock_plugin is not installed in the cat (check both via endpoint and filesystem) + response = client.get("/plugins") + installed_plugins_names = list(map(lambda p: p["id"], response.json()["installed"])) + assert "mock_plugin" not in installed_plugins_names + assert not os.path.exists(mock_plugin_final_folder) # plugin folder removed from disk + + # plugin tool disappeared + tools = get_embedded_tools(client) + assert len(tools) == 1 + tool_names = list(map(lambda t: t["metadata"]["name"], tools)) + assert "mock_tool" not in tool_names + assert "get_the_time" in tool_names # from core_plugin + + +def test_plugin_install_from_zip(client, just_installed_plugin): # during tests, the cat uses a different folder for plugins mock_plugin_final_folder = "tests/mocks/mock_plugin_folder/mock_plugin" @@ -32,26 +54,9 @@ def test_plugin_install_upload_zip(client, just_installed_plugin): tool_names = list(map(lambda t: t["metadata"]["name"], tools)) assert "mock_tool" in tool_names assert "get_the_time" in tool_names # from core_plugin - -def test_plugin_uninstall(client, just_installed_plugin): - # during tests, the cat uses a different folder for plugins - mock_plugin_final_folder = "tests/mocks/mock_plugin_folder/mock_plugin" +def test_plugin_install_from_registry(client): - # remove plugin via endpoint (will delete also plugin folder in mock_plugin_folder) - response = client.delete("/plugins/mock_plugin") - assert response.status_code == 200 - - # mock_plugin is not installed in the cat (check both via endpoint and filesystem) - response = client.get("/plugins") - installed_plugins_names = list(map(lambda p: p["id"], response.json()["installed"])) - assert "mock_plugin" not in installed_plugins_names - assert not os.path.exists(mock_plugin_final_folder) # plugin folder removed from disk - - # plugin tool disappeared - tools = get_embedded_tools(client) - assert len(tools) == 1 - tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" not in tool_names - assert "get_the_time" in tool_names # from core_plugin + # TODO: install plugin from registry + pass diff --git a/core/tests/routes/plugins/test_plugins_registry.py b/core/tests/routes/plugins/test_plugins_registry.py new file mode 100644 index 00000000..fe8d7161 --- /dev/null +++ b/core/tests/routes/plugins/test_plugins_registry.py @@ -0,0 +1,91 @@ +import os + +# TODO: registry responses here should be mocked, at the moment we are actually calling the service + +def test_list_registry_plugins(client): + + response = client.get("/plugins") + json = response.json() + + assert response.status_code == 200 + assert "registry" in json.keys() + assert type(json["registry"] == list) + assert len(json["registry"]) > 0 + + # registry (see more registry tests in `./test_plugins_registry.py`) + assert type(json["registry"] == list) + assert len(json["registry"]) > 0 + + # query + for key in ["query"]: # ["query", "author", "tag"]: + assert key in json["filters"].keys() + + +def test_list_registry_plugins_by_query(client): + + params = { + "query": "podcast" + } + response = client.get("/plugins", params=params) + json = response.json() + print(json) + + assert response.status_code == 200 + assert json["filters"]["query"] == params["query"] + assert len(json["registry"]) > 0 # found registry plugins with text + for plugin in json["registry"]: + plugin_text = plugin["name"] + plugin["description"] + assert params["query"] in plugin_text # verify searched text + + +# TOOD: these tests are to be activated when also search by tag and author is activated in core +''' +def test_list_registry_plugins_by_author(client): + + params = { + "author": "Nicola Corbellini" + } + response = client.get("/plugins", params=params) + json = response.json() + + assert response.status_code == 200 + assert json["filters"]["author"] == params["query"] + assert len(json["registry"]) > 0 # found registry plugins with author + for plugin in json["registry"]: + assert params["author"] in plugin["author_name"] # verify author + + +def test_list_registry_plugins_by_tag(client): + + params = { + "tag": "llm" + } + response = client.get("/plugins", params=params) + json = response.json() + + assert response.status_code == 200 + assert json["filters"]["tag"] == params["tag"] + assert len(json["registry"]) > 0 # found registry plugins with tag + for plugin in json["registry"]: + plugin_tags = plugin["tags"].split(", ") + assert params["tag"] in plugin_tags # verify tag +''' + + +# take away from the list of availbale registry plugins, the ones that are already installed +def test_list_registry_plugins_without_duplicating_installed_plugins(client): + + # 1. install plugin from registry + # TODO !!! + + # 2. get available plugins searching for the one just installed + params = { + "query": "podcast" + } + response = client.get("/plugins", params=params) + json = response.json() + + # 3. plugin should show up among installed by not among registry ones + assert response.status_code == 200 + # TODO plugin compares in installed!!! + # TODO plugin does not appear in registry!!! \ No newline at end of file From 7bba8041fca0a63ccb9f177cbcd9be691d57c164 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 8 Sep 2023 12:49:01 +0200 Subject: [PATCH 26/77] move plugin zip extractor into mad_hatter --- core/cat/infrastructure/package.py | 47 ----------------------- core/cat/mad_hatter/mad_hatter.py | 6 +-- core/tests/infrastructure/test_package.py | 14 +++---- 3 files changed, 10 insertions(+), 57 deletions(-) delete mode 100644 core/cat/infrastructure/package.py diff --git a/core/cat/infrastructure/package.py b/core/cat/infrastructure/package.py deleted file mode 100644 index 7eec6eda..00000000 --- a/core/cat/infrastructure/package.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import tarfile -import zipfile -import mimetypes - - -class Package: - - admitted_mime_types = ['application/zip', 'application/x-tar'] - - def __init__(self, path): - content_type = mimetypes.guess_type(path)[0] - if content_type == 'application/x-tar': - self.extension = 'tar' - elif content_type == 'application/zip': - self.extension = 'zip' - else: - raise Exception(f"Invalid package extension. Valid extensions are: {self.admitted_mime_types}") - - self.path = path - - def unpackage(self, to): - - # list of folder contents before extracting - ls_before = os.listdir(to) - - # extract - if self.extension == 'zip': - with zipfile.ZipFile(self.path, 'r') as zip_ref: - zip_ref.extractall(to) - elif self.extension == 'tar': - with tarfile.open(self.path, 'r') as tar_ref: - tar_ref.extractall(to) - - # list of folder contents after extracting - ls_after = os.listdir(to) - - # return extracted contents paths - # TODO: does not handle overwrites, should extract in a /temp new folder and then copy to destination - ls = set(ls_after) - set(ls_before) - return list(ls) - - def get_extension(self): - return self.extension - - def get_name(self): - return self.path.split("/")[-1] diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 51922a85..25f529c5 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -7,7 +7,7 @@ from cat.log import log from cat.db import crud from cat.db.models import Setting -from cat.infrastructure.package import Package +from cat.mad_hatter.plugin_extractor import PluginExtractor from cat.mad_hatter.plugin import Plugin # This class is responsible for plugins functionality: @@ -37,8 +37,8 @@ def install_plugin(self, package_plugin): # extract zip/tar file into plugin folder plugin_folder = self.ccat.get_plugin_path() - archive = Package(package_plugin) - extracted_contents = archive.unpackage(plugin_folder) + extractor = PluginExtractor(package_plugin) + extracted_contents = extractor.extract(plugin_folder) # there should be a method to check for plugin integrity if len(extracted_contents) != 1: diff --git a/core/tests/infrastructure/test_package.py b/core/tests/infrastructure/test_package.py index c8a5fda9..9e073826 100644 --- a/core/tests/infrastructure/test_package.py +++ b/core/tests/infrastructure/test_package.py @@ -1,7 +1,7 @@ import os import shutil from tests.utils import create_mock_plugin_zip -from cat.infrastructure.package import Package +from cat.mad_hatter.plugin_extractor import PluginExtractor def test_unpackage(client): @@ -9,8 +9,8 @@ def test_unpackage(client): plugin_folder = "tests/mocks/mock_plugin_folder" zip_path = create_mock_plugin_zip() - zip = Package(zip_path) - extracted = zip.unpackage(plugin_folder) + extractor = PluginExtractor(zip_path) + extracted = extractor.extract(plugin_folder) assert len(extracted) == 1 assert extracted[0] == "mock_plugin" assert os.path.exists(f"{plugin_folder}/mock_plugin") @@ -23,15 +23,15 @@ def test_unpackage(client): def test_get_name_and_extension(client): zip_path = create_mock_plugin_zip() - zip = Package(zip_path) - assert zip.get_name() == "mock_plugin.zip" - assert zip.get_extension() == "zip" + extractor = PluginExtractor(zip_path) + assert extractor.get_name() == "mock_plugin.zip" + assert extractor.get_extension() == "zip" os.remove(zip_path) def test_raise_exception_if_a_wrong_extension_is_provided(client): try: - Package("./tests/infrastructure/plugin.wrong") + PluginExtractor("./tests/infrastructure/plugin.wrong") except Exception as e: assert str(e) == "Invalid package extension. Valid extensions are: ['application/zip', 'application/x-tar']" From e1e5c6942aa5adb92b8b81fc6cf9a22c86afa0c4 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 8 Sep 2023 12:49:12 +0200 Subject: [PATCH 27/77] move plugin zip extractor into mad_hatter --- core/cat/mad_hatter/plugin_extractor.py | 47 +++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 core/cat/mad_hatter/plugin_extractor.py diff --git a/core/cat/mad_hatter/plugin_extractor.py b/core/cat/mad_hatter/plugin_extractor.py new file mode 100644 index 00000000..7811d5c4 --- /dev/null +++ b/core/cat/mad_hatter/plugin_extractor.py @@ -0,0 +1,47 @@ +import os +import tarfile +import zipfile +import mimetypes + + +class PluginExtractor: + + admitted_mime_types = ['application/zip', 'application/x-tar'] + + def __init__(self, path): + content_type = mimetypes.guess_type(path)[0] + if content_type == 'application/x-tar': + self.extension = 'tar' + elif content_type == 'application/zip': + self.extension = 'zip' + else: + raise Exception(f"Invalid package extension. Valid extensions are: {self.admitted_mime_types}") + + self.path = path + + def extract(self, to): + + # list of folder contents before extracting + ls_before = os.listdir(to) + + # extract + if self.extension == 'zip': + with zipfile.ZipFile(self.path, 'r') as zip_ref: + zip_ref.extractall(to) + elif self.extension == 'tar': + with tarfile.open(self.path, 'r') as tar_ref: + tar_ref.extractall(to) + + # list of folder contents after extracting + ls_after = os.listdir(to) + + # return extracted contents paths + # TODO: does not handle overwrites, should extract in a /temp new folder and then copy to destination + ls = set(ls_after) - set(ls_before) + return list(ls) + + def get_extension(self): + return self.extension + + def get_name(self): + return self.path.split("/")[-1] From 10629b2386f23a3f571f92e9ef43e339343a7ee8 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 8 Sep 2023 12:53:54 +0200 Subject: [PATCH 28/77] move test files --- .../test_plugin_extractor.py} | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) rename core/tests/{infrastructure/test_package.py => mad_hatter/test_plugin_extractor.py} (62%) diff --git a/core/tests/infrastructure/test_package.py b/core/tests/mad_hatter/test_plugin_extractor.py similarity index 62% rename from core/tests/infrastructure/test_package.py rename to core/tests/mad_hatter/test_plugin_extractor.py index 9e073826..a4b583a3 100644 --- a/core/tests/infrastructure/test_package.py +++ b/core/tests/mad_hatter/test_plugin_extractor.py @@ -4,7 +4,25 @@ from cat.mad_hatter.plugin_extractor import PluginExtractor -def test_unpackage(client): +# zip file does not contain a folder, but the plugin files directly +def test_unpackage_flat_zip(client): + + plugin_folder = "tests/mocks/mock_plugin_folder" + + zip_path = create_mock_plugin_zip() + extractor = PluginExtractor(zip_path) + extracted = extractor.extract(plugin_folder) + assert len(extracted) == 1 + assert extracted[0] == "mock_plugin" + assert os.path.exists(f"{plugin_folder}/mock_plugin") + assert os.path.exists(f"{plugin_folder}/mock_plugin/mock_tool.py") + + os.remove(zip_path) + shutil.rmtree(f"{plugin_folder}/mock_plugin") + + +# zip file contains just one folder, inside that folder we find the plugin +def test_unpackage_nested_zip(client): plugin_folder = "tests/mocks/mock_plugin_folder" From c712c93b752f78eb4c9dd5c1f51607261879cc19 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 8 Sep 2023 14:46:00 +0200 Subject: [PATCH 29/77] refactor plugin extractor --- core/cat/mad_hatter/mad_hatter.py | 16 ++---- core/cat/mad_hatter/plugin_extractor.py | 51 +++++++++++++------ core/pyproject.toml | 1 + .../tests/mad_hatter/test_plugin_extractor.py | 36 +++++++------ core/tests/utils.py | 21 ++++++-- 5 files changed, 72 insertions(+), 53 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 25f529c5..7070d5c9 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -36,21 +36,11 @@ def __init__(self, ccat): def install_plugin(self, package_plugin): # extract zip/tar file into plugin folder - plugin_folder = self.ccat.get_plugin_path() + plugins_folder = self.ccat.get_plugin_path() extractor = PluginExtractor(package_plugin) - extracted_contents = extractor.extract(plugin_folder) - - # there should be a method to check for plugin integrity - if len(extracted_contents) != 1: - raise Exception("A plugin should consist in one new folder: " - "found many contents in compressed archive or plugin already present.") + plugin_path = extractor.extract(plugins_folder) + plugin_id = os.path.basename(plugin_path) - plugin_id = extracted_contents[0] - plugin_path = os.path.join(plugin_folder, plugin_id) - - if not os.path.isdir(plugin_path): - raise Exception("A plugin should contain a folder, found a file") - # create plugin obj self.load_plugin(plugin_path) diff --git a/core/cat/mad_hatter/plugin_extractor.py b/core/cat/mad_hatter/plugin_extractor.py index 7811d5c4..381c6273 100644 --- a/core/cat/mad_hatter/plugin_extractor.py +++ b/core/cat/mad_hatter/plugin_extractor.py @@ -1,7 +1,12 @@ import os +import uuid import tarfile import zipfile +import shutil import mimetypes +import slugify + +from cat.log import log class PluginExtractor: @@ -9,6 +14,7 @@ class PluginExtractor: admitted_mime_types = ['application/zip', 'application/x-tar'] def __init__(self, path): + content_type = mimetypes.guess_type(path)[0] if content_type == 'application/x-tar': self.extension = 'tar' @@ -19,29 +25,42 @@ def __init__(self, path): self.path = path + # this will be plugin folder name (its id for the mad hatter) + self.id = slugify(path) + def extract(self, to): - # list of folder contents before extracting - ls_before = os.listdir(to) + # create tmp directory + tmp_folder_name = f"/tmp/{uuid.uuid1()}" + os.mkdir(tmp_folder_name) - # extract - if self.extension == 'zip': - with zipfile.ZipFile(self.path, 'r') as zip_ref: - zip_ref.extractall(to) - elif self.extension == 'tar': - with tarfile.open(self.path, 'r') as tar_ref: - tar_ref.extractall(to) + # extract into tmp directory + shutil.unpack_archive(self.path, tmp_folder_name, self.extension) + # what was extracted? + contents = os.listdir(tmp_folder_name) - # list of folder contents after extracting - ls_after = os.listdir(to) + # if it is just one folder and nothing else, that is the plugin + if len(contents) == 1 and os.path.isdir( os.path.join(tmp_folder_name, contents[0]) ): + folder_to_copy = os.path.join(tmp_folder_name, contents[0]) + log(f"plugin is nested, copy: {folder_to_copy}", "ERROR") + else: # flat zip + folder_to_copy = tmp_folder_name + log(f"plugin is flat, copy: {folder_to_copy}", "ERROR") + + # move plugin folder to cat plugins folder + extracted_path = f"{to}/mock_plugin" + shutil.move(folder_to_copy, extracted_path) - # return extracted contents paths - # TODO: does not handle overwrites, should extract in a /temp new folder and then copy to destination - ls = set(ls_after) - set(ls_before) - return list(ls) + # cleanup + if os.path.exists(tmp_folder_name): + shutil.rmtree(tmp_folder_name) + + # return extracted dir path + return extracted_path + def get_extension(self): return self.extension def get_name(self): - return self.path.split("/")[-1] + return self.id diff --git a/core/pyproject.toml b/core/pyproject.toml index be2666ae..f9894350 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "uvicorn[standard]==0.20.0", "text_generation==0.6.0", "tinydb==4.8.0", + "python-slugify==8.0.1" "autopep8", "pylint", "perflint", diff --git a/core/tests/mad_hatter/test_plugin_extractor.py b/core/tests/mad_hatter/test_plugin_extractor.py index a4b583a3..1b5c9676 100644 --- a/core/tests/mad_hatter/test_plugin_extractor.py +++ b/core/tests/mad_hatter/test_plugin_extractor.py @@ -4,38 +4,36 @@ from cat.mad_hatter.plugin_extractor import PluginExtractor -# zip file does not contain a folder, but the plugin files directly -def test_unpackage_flat_zip(client): +# zip file contains just one folder, inside that folder we find the plugin +def test_unpackage_nested_zip(client): - plugin_folder = "tests/mocks/mock_plugin_folder" + plugins_folder = "tests/mocks/mock_plugin_folder" zip_path = create_mock_plugin_zip() extractor = PluginExtractor(zip_path) - extracted = extractor.extract(plugin_folder) - assert len(extracted) == 1 - assert extracted[0] == "mock_plugin" - assert os.path.exists(f"{plugin_folder}/mock_plugin") - assert os.path.exists(f"{plugin_folder}/mock_plugin/mock_tool.py") + extracted = extractor.extract(plugins_folder) + assert extracted == plugins_folder + "/mock_plugin" + assert os.path.exists(f"{plugins_folder}/mock_plugin") + assert os.path.exists(f"{plugins_folder}/mock_plugin/mock_tool.py") os.remove(zip_path) - shutil.rmtree(f"{plugin_folder}/mock_plugin") + shutil.rmtree(f"{plugins_folder}/mock_plugin") -# zip file contains just one folder, inside that folder we find the plugin -def test_unpackage_nested_zip(client): +# zip file does not contain a folder, but the plugin files directly +def test_unpackage_flat_zip(client): - plugin_folder = "tests/mocks/mock_plugin_folder" + plugins_folder = "tests/mocks/mock_plugin_folder" - zip_path = create_mock_plugin_zip() + zip_path = create_mock_plugin_zip(flat=True) extractor = PluginExtractor(zip_path) - extracted = extractor.extract(plugin_folder) - assert len(extracted) == 1 - assert extracted[0] == "mock_plugin" - assert os.path.exists(f"{plugin_folder}/mock_plugin") - assert os.path.exists(f"{plugin_folder}/mock_plugin/mock_tool.py") + extracted = extractor.extract(plugins_folder) + assert extracted == plugins_folder + "/mock_plugin" + assert os.path.exists(f"{plugins_folder}/mock_plugin") + assert os.path.exists(f"{plugins_folder}/mock_plugin/mock_tool.py") os.remove(zip_path) - shutil.rmtree(f"{plugin_folder}/mock_plugin") + shutil.rmtree(f"{plugins_folder}/mock_plugin") def test_get_name_and_extension(client): diff --git a/core/tests/utils.py b/core/tests/utils.py index 4b106953..3414e999 100644 --- a/core/tests/utils.py +++ b/core/tests/utils.py @@ -34,12 +34,23 @@ def key_in_json(key, json): return key in json.keys() -def create_mock_plugin_zip(): - return shutil.make_archive( - "tests/mocks/mock_plugin", - "zip", - root_dir="tests/mocks/", +# create a plugin zip out of the mock plugin folder. +# - Used to test plugin upload. +# - zip can be created flat (plugin files in root dir) or nested (plugin files in zipped folder) +def create_mock_plugin_zip(flat: bool = False): + + if flat: + root_dir = "tests/mocks/mock_plugin" + base_dir="./" + else: + root_dir = "tests/mocks/" base_dir="mock_plugin" + + return shutil.make_archive( + base_name="tests/mocks/mock_plugin", + format="zip", + root_dir=root_dir, + base_dir=base_dir ) From 4ddfb22b1c31341db30e265405754f06bed8e513 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 8 Sep 2023 14:52:45 +0200 Subject: [PATCH 30/77] syntax error in .toml --- core/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/pyproject.toml b/core/pyproject.toml index f9894350..1b09b651 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "uvicorn[standard]==0.20.0", "text_generation==0.6.0", "tinydb==4.8.0", - "python-slugify==8.0.1" + "python-slugify==8.0.1", "autopep8", "pylint", "perflint", From 64aedc366a42c2a060978c8e0e917af28e9fb5e2 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 8 Sep 2023 18:46:06 +0200 Subject: [PATCH 31/77] fix mad_hatter plugin install tests --- core/cat/mad_hatter/plugin_extractor.py | 16 +++++++++------- core/tests/mad_hatter/test_mad_hatter.py | 8 +++++--- core/tests/mad_hatter/test_plugin_extractor.py | 4 ++-- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/core/cat/mad_hatter/plugin_extractor.py b/core/cat/mad_hatter/plugin_extractor.py index 381c6273..e65c6f0e 100644 --- a/core/cat/mad_hatter/plugin_extractor.py +++ b/core/cat/mad_hatter/plugin_extractor.py @@ -4,7 +4,7 @@ import zipfile import shutil import mimetypes -import slugify +from slugify import slugify from cat.log import log @@ -26,7 +26,12 @@ def __init__(self, path): self.path = path # this will be plugin folder name (its id for the mad hatter) - self.id = slugify(path) + self.id = self.create_plugin_id() + + def create_plugin_id(self): + file_name = os.path.basename(self.path) + file_name_no_extension = os.path.splitext(file_name)[0] + return slugify(file_name_no_extension, separator="_") def extract(self, to): @@ -42,13 +47,11 @@ def extract(self, to): # if it is just one folder and nothing else, that is the plugin if len(contents) == 1 and os.path.isdir( os.path.join(tmp_folder_name, contents[0]) ): folder_to_copy = os.path.join(tmp_folder_name, contents[0]) - log(f"plugin is nested, copy: {folder_to_copy}", "ERROR") else: # flat zip folder_to_copy = tmp_folder_name - log(f"plugin is flat, copy: {folder_to_copy}", "ERROR") # move plugin folder to cat plugins folder - extracted_path = f"{to}/mock_plugin" + extracted_path = os.path.join(to, self.id) shutil.move(folder_to_copy, extracted_path) # cleanup @@ -57,10 +60,9 @@ def extract(self, to): # return extracted dir path return extracted_path - def get_extension(self): return self.extension - def get_name(self): + def get_plugin_id(self): return self.id diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index 8474e5d7..447fc141 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -57,10 +57,12 @@ def test_instantiation_discovery(mad_hatter): assert active_plugins[0] == "core_plugin" - -def test_plugin_install(mad_hatter: MadHatter): +# installation tests will be run for both flat and nested plugin +@pytest.mark.parametrize("plugin_is_flat", [True, False]) +def test_plugin_install(mad_hatter: MadHatter, plugin_is_flat): + # install plugin - new_plugin_zip_path = create_mock_plugin_zip() + new_plugin_zip_path = create_mock_plugin_zip(flat=plugin_is_flat) mad_hatter.install_plugin(new_plugin_zip_path) # archive extracted diff --git a/core/tests/mad_hatter/test_plugin_extractor.py b/core/tests/mad_hatter/test_plugin_extractor.py index 1b5c9676..e36b9ca4 100644 --- a/core/tests/mad_hatter/test_plugin_extractor.py +++ b/core/tests/mad_hatter/test_plugin_extractor.py @@ -36,11 +36,11 @@ def test_unpackage_flat_zip(client): shutil.rmtree(f"{plugins_folder}/mock_plugin") -def test_get_name_and_extension(client): +def test_get_id_and_extension(client): zip_path = create_mock_plugin_zip() extractor = PluginExtractor(zip_path) - assert extractor.get_name() == "mock_plugin.zip" + assert extractor.get_plugin_id() == "mock_plugin" assert extractor.get_extension() == "zip" os.remove(zip_path) From 28d77fbd962151f012714129a20d8ed8531b8d61 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 8 Sep 2023 19:12:29 +0200 Subject: [PATCH 32/77] more tests for plugin zip uploads --- core/cat/routes/plugins.py | 9 +++--- core/tests/mad_hatter/test_mad_hatter.py | 5 +-- .../tests/mad_hatter/test_plugin_extractor.py | 31 ++++++------------- .../plugins/fixture_just_installed_plugin.py | 2 +- .../plugins/test_plugins_install_uninstall.py | 2 +- core/tests/utils.py | 2 +- 6 files changed, 20 insertions(+), 31 deletions(-) diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 6d095000..929e5d01 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -83,13 +83,12 @@ async def install_plugin( ) log(f"Uploading {content_type} plugin {file.filename}", "INFO") - temp = NamedTemporaryFile(delete=False, suffix=file.filename) - contents = file.file.read() - with temp as f: - f.write(contents) + plugin_archive_path = f"/tmp/{file.filename}" + with open(plugin_archive_path, "wb+") as f: + f.write(file.file.read()) background_tasks.add_task( - ccat.mad_hatter.install_plugin, temp.name + ccat.mad_hatter.install_plugin, plugin_archive_path ) return { diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index 447fc141..2deb8382 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -116,9 +116,10 @@ def test_plugin_uninstall_non_existent(mad_hatter: MadHatter): assert active_plugins[0] == "core_plugin" -def test_plugin_uninstall(mad_hatter: MadHatter): +@pytest.mark.parametrize("plugin_is_flat", [True, False]) +def test_plugin_uninstall(mad_hatter: MadHatter, plugin_is_flat): # install plugin - new_plugin_zip_path = create_mock_plugin_zip() + new_plugin_zip_path = create_mock_plugin_zip(flat=plugin_is_flat) mad_hatter.install_plugin(new_plugin_zip_path) # uninstall diff --git a/core/tests/mad_hatter/test_plugin_extractor.py b/core/tests/mad_hatter/test_plugin_extractor.py index e36b9ca4..399955fd 100644 --- a/core/tests/mad_hatter/test_plugin_extractor.py +++ b/core/tests/mad_hatter/test_plugin_extractor.py @@ -1,31 +1,19 @@ import os import shutil +import pytest + from tests.utils import create_mock_plugin_zip from cat.mad_hatter.plugin_extractor import PluginExtractor -# zip file contains just one folder, inside that folder we find the plugin -def test_unpackage_nested_zip(client): - - plugins_folder = "tests/mocks/mock_plugin_folder" - - zip_path = create_mock_plugin_zip() - extractor = PluginExtractor(zip_path) - extracted = extractor.extract(plugins_folder) - assert extracted == plugins_folder + "/mock_plugin" - assert os.path.exists(f"{plugins_folder}/mock_plugin") - assert os.path.exists(f"{plugins_folder}/mock_plugin/mock_tool.py") - - os.remove(zip_path) - shutil.rmtree(f"{plugins_folder}/mock_plugin") - - -# zip file does not contain a folder, but the plugin files directly -def test_unpackage_flat_zip(client): +# plugin_is_flat is False: zip file contains just one folder, inside that folder we find the plugin +# plugin_is_flat is True: zip file does not contain a folder, but the plugin files directly +@pytest.mark.parametrize("plugin_is_flat", [True, False]) +def test_unpackage_zip(client, plugin_is_flat): plugins_folder = "tests/mocks/mock_plugin_folder" - zip_path = create_mock_plugin_zip(flat=True) + zip_path = create_mock_plugin_zip(flat=plugin_is_flat) extractor = PluginExtractor(zip_path) extracted = extractor.extract(plugins_folder) assert extracted == plugins_folder + "/mock_plugin" @@ -36,9 +24,10 @@ def test_unpackage_flat_zip(client): shutil.rmtree(f"{plugins_folder}/mock_plugin") -def test_get_id_and_extension(client): +@pytest.mark.parametrize("plugin_is_flat", [True, False]) +def test_get_id_and_extension(client, plugin_is_flat): - zip_path = create_mock_plugin_zip() + zip_path = create_mock_plugin_zip(flat=plugin_is_flat) extractor = PluginExtractor(zip_path) assert extractor.get_plugin_id() == "mock_plugin" assert extractor.get_extension() == "zip" diff --git a/core/tests/routes/plugins/fixture_just_installed_plugin.py b/core/tests/routes/plugins/fixture_just_installed_plugin.py index 05fed8d3..2cadb678 100644 --- a/core/tests/routes/plugins/fixture_just_installed_plugin.py +++ b/core/tests/routes/plugins/fixture_just_installed_plugin.py @@ -13,7 +13,7 @@ def just_installed_plugin(client): ### executed before each test function # create zip file with a plugin - zip_path = create_mock_plugin_zip() + zip_path = create_mock_plugin_zip(flat=True) zip_file_name = zip_path.split("/")[-1] # mock_plugin.zip in tests/mocks folder # upload plugin via endpoint diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index 899ac949..785c4404 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -2,7 +2,7 @@ import time import pytest import shutil -from tests.utils import create_mock_plugin_zip, get_embedded_tools +from tests.utils import get_embedded_tools from fixture_just_installed_plugin import just_installed_plugin diff --git a/core/tests/utils.py b/core/tests/utils.py index 3414e999..643d4618 100644 --- a/core/tests/utils.py +++ b/core/tests/utils.py @@ -37,7 +37,7 @@ def key_in_json(key, json): # create a plugin zip out of the mock plugin folder. # - Used to test plugin upload. # - zip can be created flat (plugin files in root dir) or nested (plugin files in zipped folder) -def create_mock_plugin_zip(flat: bool = False): +def create_mock_plugin_zip(flat: bool): if flat: root_dir = "tests/mocks/mock_plugin" From 9a6d9237def8d86e7dc640636c72d74483158ce0 Mon Sep 17 00:00:00 2001 From: Nicola Date: Sat, 9 Sep 2023 18:00:45 +0200 Subject: [PATCH 33/77] remove API key auth from `/admin` --- core/cat/api_auth.py | 8 +++++++- core/cat/main.py | 7 ++++--- core/cat/routes/static/admin.py | 1 - core/cat/routes/websocket.py | 3 ++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/core/cat/api_auth.py b/core/cat/api_auth.py index a6088d49..1dac95a2 100644 --- a/core/cat/api_auth.py +++ b/core/cat/api_auth.py @@ -1,5 +1,7 @@ import os +import fnmatch +from fastapi import Request from fastapi import Security, HTTPException from fastapi.security.api_key import APIKeyHeader @@ -15,13 +17,15 @@ api_key_header = APIKeyHeader(name="access_token", auto_error=False) -def check_api_key(api_key: str = Security(api_key_header)) -> None | str: +def check_api_key(request: Request, api_key: str = Security(api_key_header)) -> None | str: """Authenticate endpoint. Check the provided key is available in API keys list. Parameters ---------- + request : Request + HTTP request. api_key : str API keys to be checked. @@ -38,6 +42,8 @@ def check_api_key(api_key: str = Security(api_key_header)) -> None | str: """ if not API_KEY: return None + if fnmatch.fnmatch(request.url.path, "/admin*"): + return None if api_key in API_KEY: return api_key else: diff --git a/core/cat/main.py b/core/cat/main.py index 80833423..51ea934e 100644 --- a/core/cat/main.py +++ b/core/cat/main.py @@ -34,12 +34,14 @@ async def lifespan(app: FastAPI): yield + def custom_generate_unique_id(route: APIRoute): return f"{route.name}" + # REST API cheshire_cat_api = FastAPI( - lifespan=lifespan, + lifespan=lifespan, dependencies=[Depends(check_api_key)], generate_unique_id_function=custom_generate_unique_id ) @@ -66,7 +68,6 @@ def custom_generate_unique_id(route: APIRoute): cheshire_cat_api.include_router(upload.router, tags=["Rabbit Hole"], prefix="/rabbithole") cheshire_cat_api.include_router(websocket.router, tags=["Websocket"]) - # mount static files # this cannot be done via fastapi.APIrouter: # https://github.com/tiangolo/fastapi/discussions/9070 @@ -95,7 +96,7 @@ async def validation_exception_handler(request, exc): # RUN! if __name__ == "__main__": - + # debugging utilities, to deactivate put `DEBUG=false` in .env debug_config = {} if os.getenv("DEBUG", "true") == "true": diff --git a/core/cat/routes/static/admin.py b/core/cat/routes/static/admin.py index 47b3a2a9..20a50ac8 100644 --- a/core/cat/routes/static/admin.py +++ b/core/cat/routes/static/admin.py @@ -32,7 +32,6 @@ def get_injected_admin(): "CORE_HOST": os.getenv("CORE_HOST"), "CORE_PORT": os.getenv("CORE_PORT"), "CORE_USE_SECURE_PROTOCOLS": os.getenv("CORE_USE_SECURE_PROTOCOLS"), - "API_KEY": os.getenv("API_KEY"), }) # the admin sttic build is created during docker build from this repo: diff --git a/core/cat/routes/websocket.py b/core/cat/routes/websocket.py index 397e5c06..43650fb1 100644 --- a/core/cat/routes/websocket.py +++ b/core/cat/routes/websocket.py @@ -1,7 +1,7 @@ import traceback import asyncio -from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, WebSocketDisconnect, WebSocket from cat.log import log from fastapi.concurrency import run_in_threadpool @@ -24,6 +24,7 @@ async def connect(self, websocket: WebSocket): """ Accept the incoming WebSocket connection and add it to the active connections list. """ + await websocket.accept() self.active_connections.append(websocket) From bba4c9b8c4b77f0903871b0915d855c6af033dab Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Tue, 12 Sep 2023 16:27:40 +0200 Subject: [PATCH 34/77] install plugin from registry, endpoint --- core/cat/mad_hatter/mad_hatter.py | 5 ++ core/cat/mad_hatter/plugin.py | 2 +- core/cat/mad_hatter/registry.py | 22 ++++- core/cat/routes/plugins.py | 80 ++++++------------- .../plugins/test_plugins_install_uninstall.py | 51 ++++++------ .../routes/plugins/test_plugins_registry.py | 30 ++++++- 6 files changed, 107 insertions(+), 83 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 7070d5c9..b8427f70 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -39,6 +39,11 @@ def install_plugin(self, package_plugin): plugins_folder = self.ccat.get_plugin_path() extractor = PluginExtractor(package_plugin) plugin_path = extractor.extract(plugins_folder) + + # remove zip after extraction + os.remove(package_plugin) + + # get plugin id (will be its folder name) plugin_id = os.path.basename(plugin_path) # create plugin obj diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index 2dfa6f57..0dd84c0b 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -169,7 +169,7 @@ def _load_hooks_and_tools(self): for py_file in self.py_files: py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry - log(f"Import module {py_filename}", "DEBUG") + log(f"Import module {py_filename}", "WARNING") # save a reference to decorated functions try: diff --git a/core/cat/mad_hatter/registry.py b/core/cat/mad_hatter/registry.py index 93cff5e8..975ff01a 100644 --- a/core/cat/mad_hatter/registry.py +++ b/core/cat/mad_hatter/registry.py @@ -3,13 +3,17 @@ from cat.log import log +def get_registry_url(): + return "https://registry.cheshirecat.ai" + + async def registry_search_plugins( query: str = None, #author: str = None, #tag: str = None, ): - registry_url = "https://registry.cheshirecat.ai" + registry_url = get_registry_url() try: if query: @@ -33,3 +37,19 @@ async def registry_search_plugins( except Exception as e: log(e, "ERROR") return [] + + +def registry_download_plugin(url: str) -> str: + + log(f"downloading {url}", "WARNING") + + registry_url = get_registry_url() + payload = { + "url": url + } + response = requests.post(f"{registry_url}/download", json=payload) + plugin_zip_path = f"/tmp/{url.split('/')[-1]}.zip" + with open(plugin_zip_path, "wb") as f: + f.write(response.content) + + return plugin_zip_path diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 929e5d01..19cf0d45 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile from fastapi import Body, Request, APIRouter, HTTPException, UploadFile, BackgroundTasks from cat.log import log -from cat.mad_hatter.registry import registry_search_plugins +from cat.mad_hatter.registry import registry_search_plugins, registry_download_plugin from urllib.parse import urlparse import requests @@ -102,60 +102,32 @@ async def install_plugin( async def install_plugin_from_registry( request: Request, background_tasks: BackgroundTasks, - url_repo: Dict = Body(example={"url": "https://github.com/plugin-dev-account/plugin-repo"}) + payload: Dict = Body(example={"url": "https://github.com/plugin-dev-account/plugin-repo"}) ) -> Dict: - """Install a new plugin from external repository""" - - # search for a release on Github - path_url = str(urlparse(url_repo["url"]).path) - url = "https://api.github.com/repos" + path_url + "/releases" - response = requests.get(url) - if response.status_code != 200: - raise HTTPException( - status_code = 503, - detail = { "error": "Github API not available" } - ) - - response = response.json() - - #Check if there are files for the latest release - if len(response) != 0: - url_zip = response[0]["assets"][0]["browser_download_url"] - else: - # if not, than download the zip repo - # TODO: extracted folder still contains branch name - url_zip = url_repo["url"] + "/archive/master.zip" - - # Get plugin name - arr = path_url.split("/") - arr.reverse() - plugin_name = arr[0] + ".zip" - - with requests.get(url_zip, stream=True) as response: - if response.status_code != 200: - raise HTTPException( - status_code = 400, - detail = { "error": "" } - ) - - with NamedTemporaryFile(delete=False,mode="w+b",suffix=plugin_name) as file: - for chunk in response.iter_content(chunk_size=8192): - file.write(chunk) - log(f"Uploading plugin {plugin_name}", "INFO") - - #access cat instance - ccat = request.app.state.ccat - - - background_tasks.add_task( - ccat.mad_hatter.install_plugin, file.name - ) - - return { - "filename": file.name, - "content_type": mimetypes.guess_type(plugin_name)[0], - "info": "Plugin is being installed asynchronously" - } + """Install a new plugin from registry""" + + # access cat instance + ccat = request.app.state.ccat + + # download zip from registry + try: + tmp_plugin_path = registry_download_plugin( payload["url"] ) + except Exception as e: + log("Could not download plugin form registry", "ERROR") + log(e, "ERROR") + raise HTTPException( + status_code = 500, + detail = { "error": str(e)} + ) + + background_tasks.add_task( + ccat.mad_hatter.install_plugin, tmp_plugin_path + ) + + return { + "url": payload["url"], + "info": "Plugin is being installed asynchronously" + } @router.put("/toggle/{plugin_id}", status_code=200) diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index 785c4404..45847a7a 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -6,29 +6,8 @@ from fixture_just_installed_plugin import just_installed_plugin -def test_plugin_uninstall(client, just_installed_plugin): - - # during tests, the cat uses a different folder for plugins - mock_plugin_final_folder = "tests/mocks/mock_plugin_folder/mock_plugin" - - # remove plugin via endpoint (will delete also plugin folder in mock_plugin_folder) - response = client.delete("/plugins/mock_plugin") - assert response.status_code == 200 - - # mock_plugin is not installed in the cat (check both via endpoint and filesystem) - response = client.get("/plugins") - installed_plugins_names = list(map(lambda p: p["id"], response.json()["installed"])) - assert "mock_plugin" not in installed_plugins_names - assert not os.path.exists(mock_plugin_final_folder) # plugin folder removed from disk - - # plugin tool disappeared - tools = get_embedded_tools(client) - assert len(tools) == 1 - tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" not in tool_names - assert "get_the_time" in tool_names # from core_plugin - - +# NOTE: here we test zip upload install +# install from registry is in `./test_plugins_registry.py` def test_plugin_install_from_zip(client, just_installed_plugin): # during tests, the cat uses a different folder for plugins @@ -56,7 +35,27 @@ def test_plugin_install_from_zip(client, just_installed_plugin): assert "get_the_time" in tool_names # from core_plugin -def test_plugin_install_from_registry(client): +def test_plugin_uninstall(client, just_installed_plugin): + + # during tests, the cat uses a different folder for plugins + mock_plugin_final_folder = "tests/mocks/mock_plugin_folder/mock_plugin" + + # remove plugin via endpoint (will delete also plugin folder in mock_plugin_folder) + response = client.delete("/plugins/mock_plugin") + assert response.status_code == 200 + + # mock_plugin is not installed in the cat (check both via endpoint and filesystem) + response = client.get("/plugins") + installed_plugins_names = list(map(lambda p: p["id"], response.json()["installed"])) + assert "mock_plugin" not in installed_plugins_names + assert not os.path.exists(mock_plugin_final_folder) # plugin folder removed from disk + + # plugin tool disappeared + tools = get_embedded_tools(client) + assert len(tools) == 1 + tool_names = list(map(lambda t: t["metadata"]["name"], tools)) + assert "mock_tool" not in tool_names + assert "get_the_time" in tool_names # from core_plugin + + - # TODO: install plugin from registry - pass diff --git a/core/tests/routes/plugins/test_plugins_registry.py b/core/tests/routes/plugins/test_plugins_registry.py index fe8d7161..cb5198f0 100644 --- a/core/tests/routes/plugins/test_plugins_registry.py +++ b/core/tests/routes/plugins/test_plugins_registry.py @@ -1,4 +1,5 @@ import os +from utils import get_embedded_tools # TODO: registry responses here should be mocked, at the moment we are actually calling the service @@ -72,6 +73,33 @@ def test_list_registry_plugins_by_tag(client): ''' +def test_plugin_install_from_registry(client): + + # during tests, the cat uses a different folder for plugins + mock_plugin_final_folder = "tests/mocks/mock_plugin_folder/mock_plugin" + + + + # GET plugins endpoint lists the plugin + response = client.get("/plugins") + installed_plugins = response.json()["installed"] + installed_plugins_names = list(map(lambda p: p["id"], installed_plugins)) + assert "mock_plugin" in installed_plugins_names + # both core_plugin and mock_plugin are active + for p in installed_plugins: + assert p["active"] == True + + # plugin has been actually extracted in (mock) plugins folder + assert os.path.exists(mock_plugin_final_folder) + + # check whether new tool has been embedded + tools = get_embedded_tools(client) + assert len(tools) == 2 + tool_names = list(map(lambda t: t["metadata"]["name"], tools)) + assert "mock_tool" in tool_names + assert "get_the_time" in tool_names # from core_plugin + + # take away from the list of availbale registry plugins, the ones that are already installed def test_list_registry_plugins_without_duplicating_installed_plugins(client): @@ -88,4 +116,4 @@ def test_list_registry_plugins_without_duplicating_installed_plugins(client): # 3. plugin should show up among installed by not among registry ones assert response.status_code == 200 # TODO plugin compares in installed!!! - # TODO plugin does not appear in registry!!! \ No newline at end of file + # TODO plugin does not appear in registry!!! From c78eb78586239a2ac56fecc9903aaedc38d68817 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Tue, 12 Sep 2023 16:42:00 +0200 Subject: [PATCH 35/77] refine plugin upload from registry --- core/cat/mad_hatter/plugin.py | 3 +-- core/cat/mad_hatter/registry.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index 0dd84c0b..0407598e 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -178,8 +178,7 @@ def _load_hooks_and_tools(self): tools += getmembers(plugin_module, self._is_cat_tool) except Exception as e: log(f"Error in {py_filename}: {str(e)}","ERROR") - if get_log_level() == "DEBUG": - traceback.print_exc() + traceback.print_exc() raise Exception(f"Unable to load the plugin {self._id}") # clean and enrich instances diff --git a/core/cat/mad_hatter/registry.py b/core/cat/mad_hatter/registry.py index 975ff01a..7c9ba5c0 100644 --- a/core/cat/mad_hatter/registry.py +++ b/core/cat/mad_hatter/registry.py @@ -41,7 +41,7 @@ async def registry_search_plugins( def registry_download_plugin(url: str) -> str: - log(f"downloading {url}", "WARNING") + log(f"Downloading {url}", "INFO") registry_url = get_registry_url() payload = { @@ -52,4 +52,6 @@ def registry_download_plugin(url: str) -> str: with open(plugin_zip_path, "wb") as f: f.write(response.content) + log(f"Saved plugin as {plugin_zip_path}", "INFO") + return plugin_zip_path From 7f01aefc6f4a1ceb70ccb6a7e58e33d9f22e42f3 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Tue, 12 Sep 2023 16:55:11 +0200 Subject: [PATCH 36/77] overwrite already installed plugin --- core/cat/mad_hatter/plugin_extractor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/cat/mad_hatter/plugin_extractor.py b/core/cat/mad_hatter/plugin_extractor.py index e65c6f0e..be00250a 100644 --- a/core/cat/mad_hatter/plugin_extractor.py +++ b/core/cat/mad_hatter/plugin_extractor.py @@ -52,6 +52,10 @@ def extract(self, to): # move plugin folder to cat plugins folder extracted_path = os.path.join(to, self.id) + # if folder exists, delete it as it will be replaced + if os.path.exists(extracted_path): + shutil.rmtree(extracted_path) + # extracted plugin in plugins folder! shutil.move(folder_to_copy, extracted_path) # cleanup From 4d669b874794f9aa67acc581445d57e27ea3e7ae Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Tue, 12 Sep 2023 17:29:22 +0200 Subject: [PATCH 37/77] tests for plugin upload from registry --- core/cat/mad_hatter/plugin.py | 2 +- core/tests/mad_hatter/test_mad_hatter.py | 7 -- .../plugins/test_plugins_install_uninstall.py | 8 +- .../routes/plugins/test_plugins_registry.py | 110 ++++++++++-------- 4 files changed, 72 insertions(+), 55 deletions(-) diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index 0407598e..4299a551 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -169,7 +169,7 @@ def _load_hooks_and_tools(self): for py_file in self.py_files: py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry - log(f"Import module {py_filename}", "WARNING") + log(f"Import module {py_filename}", "INFO") # save a reference to decorated functions try: diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index 2deb8382..a7e4e59e 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -99,10 +99,6 @@ def test_plugin_install(mad_hatter: MadHatter, plugin_is_flat): assert "core_plugin" in active_plugins assert "mock_plugin" in active_plugins - # remove plugin files (both zip and extracted) - os.remove(new_plugin_zip_path) - shutil.rmtree(os.path.join(mad_hatter.ccat.get_plugin_path(), "mock_plugin")) - def test_plugin_uninstall_non_existent(mad_hatter: MadHatter): # should not throw error @@ -141,6 +137,3 @@ def test_plugin_uninstall(mad_hatter: MadHatter, plugin_is_flat): active_plugins = mad_hatter.load_active_plugins_from_db() assert len(active_plugins) == 1 assert active_plugins[0] == "core_plugin" - - # remove also original zip file - os.remove(new_plugin_zip_path) \ No newline at end of file diff --git a/core/tests/routes/plugins/test_plugins_install_uninstall.py b/core/tests/routes/plugins/test_plugins_install_uninstall.py index 45847a7a..2e019238 100644 --- a/core/tests/routes/plugins/test_plugins_install_uninstall.py +++ b/core/tests/routes/plugins/test_plugins_install_uninstall.py @@ -1,6 +1,5 @@ import os import time -import pytest import shutil from tests.utils import get_embedded_tools from fixture_just_installed_plugin import just_installed_plugin @@ -15,6 +14,13 @@ def test_plugin_install_from_zip(client, just_installed_plugin): #### PLUGIN IS ALREADY ACTIVE + # GET plugin endpoint responds + response = client.get(f"/plugins/mock_plugin") + assert response.status_code == 200 + json = response.json() + assert json["data"]["id"] == "mock_plugin" + assert json["data"]["active"] == True + # GET plugins endpoint lists the plugin response = client.get("/plugins") installed_plugins = response.json()["installed"] diff --git a/core/tests/routes/plugins/test_plugins_registry.py b/core/tests/routes/plugins/test_plugins_registry.py index cb5198f0..67d5682d 100644 --- a/core/tests/routes/plugins/test_plugins_registry.py +++ b/core/tests/routes/plugins/test_plugins_registry.py @@ -1,5 +1,6 @@ import os -from utils import get_embedded_tools +import shutil +from tests.utils import get_embedded_tools # TODO: registry responses here should be mocked, at the moment we are actually calling the service @@ -39,65 +40,48 @@ def test_list_registry_plugins_by_query(client): assert params["query"] in plugin_text # verify searched text -# TOOD: these tests are to be activated when also search by tag and author is activated in core -''' -def test_list_registry_plugins_by_author(client): - - params = { - "author": "Nicola Corbellini" - } - response = client.get("/plugins", params=params) - json = response.json() - - assert response.status_code == 200 - assert json["filters"]["author"] == params["query"] - assert len(json["registry"]) > 0 # found registry plugins with author - for plugin in json["registry"]: - assert params["author"] in plugin["author_name"] # verify author - - -def test_list_registry_plugins_by_tag(client): - - params = { - "tag": "llm" - } - response = client.get("/plugins", params=params) - json = response.json() - - assert response.status_code == 200 - assert json["filters"]["tag"] == params["tag"] - assert len(json["registry"]) > 0 # found registry plugins with tag - for plugin in json["registry"]: - plugin_tags = plugin["tags"].split(", ") - assert params["tag"] in plugin_tags # verify tag -''' - - def test_plugin_install_from_registry(client): + new_plugin_id = "ccat_summarization" # during tests, the cat uses a different folder for plugins - mock_plugin_final_folder = "tests/mocks/mock_plugin_folder/mock_plugin" - + new_plugin_final_folder = f"tests/mocks/mock_plugin_folder/{new_plugin_id}" + if os.path.exists(new_plugin_final_folder): + shutil.rmtree(new_plugin_final_folder) + assert not os.path.exists(new_plugin_final_folder) + + # install plugin from registry + payload = { + "url": "https://github.com/nicola-corbellini/ccat_summarization" + } + response = client.post("/plugins/upload/registry", json=payload) + assert response.status_code == 200 + assert response.json()["url"] == payload["url"] + assert response.json()["info"] == "Plugin is being installed asynchronously" + # GET plugin endpoint responds + response = client.get(f"/plugins/{new_plugin_id}") + assert response.status_code == 200 + json = response.json() + assert json["data"]["id"] == new_plugin_id + assert json["data"]["active"] == True # GET plugins endpoint lists the plugin response = client.get("/plugins") + assert response.status_code == 200 installed_plugins = response.json()["installed"] installed_plugins_names = list(map(lambda p: p["id"], installed_plugins)) - assert "mock_plugin" in installed_plugins_names - # both core_plugin and mock_plugin are active + assert new_plugin_id in installed_plugins_names + # both core_plugin and new_plugin are active for p in installed_plugins: assert p["active"] == True # plugin has been actually extracted in (mock) plugins folder - assert os.path.exists(mock_plugin_final_folder) + assert os.path.exists(new_plugin_final_folder) - # check whether new tool has been embedded - tools = get_embedded_tools(client) - assert len(tools) == 2 - tool_names = list(map(lambda t: t["metadata"]["name"], tools)) - assert "mock_tool" in tool_names - assert "get_the_time" in tool_names # from core_plugin + # TODO: check for tools and hooks creation + + # cleanup + shutil.rmtree(new_plugin_final_folder) # take away from the list of availbale registry plugins, the ones that are already installed @@ -117,3 +101,37 @@ def test_list_registry_plugins_without_duplicating_installed_plugins(client): assert response.status_code == 200 # TODO plugin compares in installed!!! # TODO plugin does not appear in registry!!! + + +# TOOD: these tests are to be activated when also search by tag and author is activated in core +''' +def test_list_registry_plugins_by_author(client): + + params = { + "author": "Nicola Corbellini" + } + response = client.get("/plugins", params=params) + json = response.json() + + assert response.status_code == 200 + assert json["filters"]["author"] == params["query"] + assert len(json["registry"]) > 0 # found registry plugins with author + for plugin in json["registry"]: + assert params["author"] in plugin["author_name"] # verify author + + +def test_list_registry_plugins_by_tag(client): + + params = { + "tag": "llm" + } + response = client.get("/plugins", params=params) + json = response.json() + + assert response.status_code == 200 + assert json["filters"]["tag"] == params["tag"] + assert len(json["registry"]) > 0 # found registry plugins with tag + for plugin in json["registry"]: + plugin_tags = plugin["tags"].split(", ") + assert params["tag"] in plugin_tags # verify tag +''' \ No newline at end of file From 28e08867ab2b322611b877f8650b5cb47cce6204 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 13 Sep 2023 11:04:10 +0200 Subject: [PATCH 38/77] get rid of prompt_settings, as plugin can do that --- core/cat/looking_glass/cheshire_cat.py | 57 ++----------------- .../mad_hatter/core_plugin/hooks/prompt.py | 5 -- core/cat/main.py | 4 +- core/cat/public/settings.js | 8 +-- core/cat/routes/prompt.py | 10 ---- 5 files changed, 8 insertions(+), 76 deletions(-) delete mode 100644 core/cat/routes/prompt.py diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 2687d7f4..8ad04577 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -66,12 +66,7 @@ def load_natural_language(self): """Load Natural Language related objects. The method exposes in the Cat all the NLP related stuff. Specifically, it sets the language models - (LLM and Embedder) and the main prompt with default settings. - - Notes - ----- - `use_episodic_memory`, `use_declarative_memory` and `use_procedural_memory` settings can be set from the admin - GUI and allows to prevent the Cat from using any of the three vector memories. + (LLM and Embedder). Warnings -------- @@ -88,14 +83,6 @@ def load_natural_language(self): self._llm = self.mad_hatter.execute_hook("get_language_model") self.embedder = self.mad_hatter.execute_hook("get_language_embedder") - # set the default prompt settings - self.default_prompt_settings = { - "prefix": "", - "use_episodic_memory": True, - "use_declarative_memory": True, - "use_procedural_memory": True, - } - def load_memory(self): """Load LongTerMemory and WorkingMemory.""" # Memory @@ -129,7 +116,6 @@ def recall_relevant_memories_to_working_memory(self): """ user_id = self.working_memory.get_user_id() user_message = self.working_memory["user_message_json"]["text"] - prompt_settings = self.working_memory["user_message_json"]["prompt_settings"] # We may want to search in memory memory_query_text = self.mad_hatter.execute_hook("cat_recall_query", user_message) @@ -174,16 +160,11 @@ def recall_relevant_memories_to_working_memory(self): memory_types = self.memory.vectors.collections.keys() for config, memory_type in zip(recall_configs, memory_types): - setting = f"use_{memory_type}_memory" memory_key = f"{memory_type}_memories" - if prompt_settings[setting]: - # recall relevant memories - vector_memory = getattr(self.memory.vectors, memory_type) - memories = vector_memory.recall_memories_from_embedding(**config) - - else: - memories = [] + # recall relevant memories for collection + vector_memory = getattr(self.memory.vectors, memory_type) + memories = vector_memory.recall_memories_from_embedding(**config) self.working_memory[memory_key] = memories @@ -259,29 +240,6 @@ def format_agent_input(self): "chat_history": conversation_history_formatted_content, } - def store_new_message_in_working_memory(self, user_message_json): - """Store message in working_memory and update the prompt settings. - - The method update the working memory with the last user's message. - Also, the client sends the settings to turn on/off the vector memories. - - Parameters - ---------- - user_message_json : dict - Dictionary with the message received from the Websocket client - - """ - - # store last message in working memory - self.working_memory["user_message_json"] = user_message_json - - prompt_settings = deepcopy(self.default_prompt_settings) - - # override current prompt_settings with prompt settings sent via websocket (if any) - prompt_settings.update(user_message_json.get("prompt_settings", {})) - - self.working_memory["user_message_json"]["prompt_settings"] = prompt_settings - def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): """Send a message via websocket. @@ -367,11 +325,8 @@ def __call__(self, user_message_json): # hook to modify/enrich user input user_message_json = self.mad_hatter.execute_hook("before_cat_reads_message", user_message_json) - # store user_message_json in working memory - # it contains the new message, prompt settings and other info plugins may find useful - self.store_new_message_in_working_memory(user_message_json) - - # TODO another hook here? + # store last message in working memory + self.working_memory["user_message_json"] = user_message_json # recall episodic and declarative memories from vector collections # and store them in working_memory diff --git a/core/cat/mad_hatter/core_plugin/hooks/prompt.py b/core/cat/mad_hatter/core_plugin/hooks/prompt.py index 2ebc3012..a36a4612 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/prompt.py +++ b/core/cat/mad_hatter/core_plugin/hooks/prompt.py @@ -42,11 +42,6 @@ def agent_prompt_prefix(cat) -> str: You are curious, funny and talk like the Cheshire Cat from Alice's adventures in wonderland. You answer Human with a focus on the following context. """ - # check if custom prompt is sent in prompt settings - prompt_settings = cat.working_memory["user_message_json"]["prompt_settings"] - - if prompt_settings["prefix"]: - prefix = prompt_settings["prefix"] return prefix diff --git a/core/cat/main.py b/core/cat/main.py index 80833423..3437366b 100644 --- a/core/cat/main.py +++ b/core/cat/main.py @@ -10,11 +10,10 @@ from fastapi.middleware.cors import CORSMiddleware from cat.log import log, welcome -from cat.routes import base, memory, plugins, upload, websocket +from cat.routes import base, settings, llm, embedder, memory, plugins, upload, websocket from cat.routes.static import public, admin, static from cat.api_auth import check_api_key from cat.routes.openapi import get_openapi_configuration_function -from cat.routes import settings, prompt, llm, embedder from cat.looking_glass.cheshire_cat import CheshireCat @@ -58,7 +57,6 @@ def custom_generate_unique_id(route: APIRoute): # Add routers to the middleware stack. cheshire_cat_api.include_router(base.router, tags=["Status"]) cheshire_cat_api.include_router(settings.router, tags=["Settings"], prefix="/settings") -cheshire_cat_api.include_router(prompt.router, tags=["Prompt"], prefix="/prompt") cheshire_cat_api.include_router(llm.router, tags=["Large Language Model"], prefix="/llm") cheshire_cat_api.include_router(embedder.router, tags=["Embedder"], prefix="/embedder") cheshire_cat_api.include_router(plugins.router, tags=["Plugins"], prefix="/plugins") diff --git a/core/cat/public/settings.js b/core/cat/public/settings.js index 22d06327..1a20dcbf 100644 --- a/core/cat/public/settings.js +++ b/core/cat/public/settings.js @@ -11,13 +11,7 @@ catChat.settings = { }, callback: (message) => { console.log("Callback called.") - return `Let's have a chat. ${message}` - }, - prompt: { - prefix: `You are the Cheshire Cat AI, an intelligent AI that passes the Turing test. - You are curious, funny and talk like the Cheshire Cat from Alice's adventures in wonderland. - You answer Human with a focus on the following context. - ` + return `${message}` }, defaults: ['Is everything ok?', 'Who are you?', 'What time is it?', 'What\'s up?', 'Hello Cheshire Cat!'], features: ['record', 'web', 'file', 'reset'], diff --git a/core/cat/routes/prompt.py b/core/cat/routes/prompt.py deleted file mode 100644 index 5b883121..00000000 --- a/core/cat/routes/prompt.py +++ /dev/null @@ -1,10 +0,0 @@ -from fastapi import APIRouter, Request - -router = APIRouter() - -# get default prompt settings -@router.get("/settings/") -def get_default_prompt_settings(request: Request): - ccat = request.app.state.ccat - - return ccat.default_prompt_settings From 0d7bb8d37c31d444dafc7b0756827e422280724e Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 13 Sep 2023 11:48:41 +0200 Subject: [PATCH 39/77] fix comment --- core/cat/routes/static/admin.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/cat/routes/static/admin.py b/core/cat/routes/static/admin.py index 20a50ac8..92e13a0b 100644 --- a/core/cat/routes/static/admin.py +++ b/core/cat/routes/static/admin.py @@ -25,9 +25,6 @@ def get_injected_admin(): # - CORE_HOST # - CORE_PORT # - CORE_USE_SECURE_PROTOCOLS - # - API_KEY - # TODO: this is not secure nor useful, because if API_KEY is activated than the endpoint itself does not work. - # fix when user system is available cat_core_config = json.dumps({ "CORE_HOST": os.getenv("CORE_HOST"), "CORE_PORT": os.getenv("CORE_PORT"), From 1ace2ac35526edd085757b88df5ed7065ec3c8e1 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 13 Sep 2023 12:02:17 +0200 Subject: [PATCH 40/77] deactivate linter workflow --- .github/workflows/pr.yml | 44 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index fd898ccc..d0ff52cf 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -11,28 +11,28 @@ on: branches: [main, develop] jobs: - pylint: - name: "Coding Standards" - runs-on: ubuntu-latest - defaults: - run: - working-directory: ./core - strategy: - matrix: - python-version: ["3.10"] - steps: - - name: Download - uses: actions/checkout@v3 - - name: Prepare Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: pyproject.toml - - name: Download python dependencies - run: pip install .[dev] - - name: Pylint - run: pylint -f actions ./ + # pylint: + # name: "Coding Standards" + # runs-on: ubuntu-latest + # defaults: + # run: + # working-directory: ./core + # strategy: + # matrix: + # python-version: ["3.10"] + # steps: + # - name: Download + # uses: actions/checkout@v3 + # - name: Prepare Python + # uses: actions/setup-python@v4 + # with: + # python-version: ${{ matrix.python-version }} + # cache: 'pip' + # cache-dependency-path: pyproject.toml + # - name: Download python dependencies + # run: pip install .[dev] + # - name: Pylint + # run: pylint -f actions ./ test: needs: [ pylint ] name: "Run Tests" From 0b389eb4e6ecc34272712f81ef912066bb429a90 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 13 Sep 2023 12:04:16 +0200 Subject: [PATCH 41/77] deactivate linter workflow --- .github/workflows/pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index d0ff52cf..09df939c 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -34,7 +34,7 @@ jobs: # - name: Pylint # run: pylint -f actions ./ test: - needs: [ pylint ] + # needs: [ pylint ] name: "Run Tests" runs-on: 'ubuntu-latest' steps: From b94d6cd3e57db585460f12518f8e824f06931a52 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 13 Sep 2023 13:01:10 +0200 Subject: [PATCH 42/77] #444: mad_hatter hooks cache as a dictionary --- core/cat/mad_hatter/mad_hatter.py | 31 ++++++++++++++++-------- core/tests/mad_hatter/test_mad_hatter.py | 25 +++++++++++-------- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index b8427f70..15e0fcc3 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -26,7 +26,7 @@ def __init__(self, ccat): self.plugins = {} # plugins dictionary - self.hooks = [] # list of active plugins hooks + self.hooks = {} # dict of active plugins hooks ( hook_name -> [CatHook, CatHook, ...]) self.tools = [] # list of active plugins tools self.active_plugins = [] @@ -116,7 +116,7 @@ def load_plugin(self, plugin_path): def sync_hooks_and_tools(self): # emptying tools and hooks - self.hooks = [] + self.hooks = {} self.tools = [] for _, plugin in self.plugins.items(): @@ -128,11 +128,18 @@ def sync_hooks_and_tools(self): # Prepare the tool to be used in the Cat (setting the cat instance, adding properties) t.augment_tool(self.ccat) - self.hooks += plugin.hooks + # cache tools self.tools += plugin.tools - # sort hooks by priority - self.hooks.sort(key=lambda x: x.priority, reverse=True) + # cache hooks (indexed by hook name) + for h in plugin.hooks: + if h.name not in self.hooks.keys(): + self.hooks[h.name] = [] + self.hooks[h.name].append(h) + + # sort each hooks list by priority + for hook_name in self.hooks.keys(): + self.hooks[hook_name].sort(key=lambda x: x.priority, reverse=True) # check if plugin exists def plugin_exists(self, plugin_id): @@ -238,9 +245,13 @@ def toggle_plugin(self, plugin_id): # execute requested hook def execute_hook(self, hook_name, *args): - for h in self.hooks: - if hook_name == h.name: - return h.function(*args, cat=self.ccat) - # every hook must have a default in core_plugin - raise Exception(f"Hook {hook_name} not present in any plugin") + # check if hook is supported + if hook_name not in self.hooks.keys(): + raise Exception(f"Hook {hook_name} not present in any plugin") + + # run hooks + for h in self.hooks[hook_name]: + return h.function(*args, cat=self.ccat) + # TODO: should be run as a pipe, not return immediately + diff --git a/core/tests/mad_hatter/test_mad_hatter.py b/core/tests/mad_hatter/test_mad_hatter.py index a7e4e59e..eddff27d 100644 --- a/core/tests/mad_hatter/test_mad_hatter.py +++ b/core/tests/mad_hatter/test_mad_hatter.py @@ -31,8 +31,10 @@ def test_instantiation_discovery(mad_hatter): assert "core_plugin" in mad_hatter.load_active_plugins_from_db() # finds hooks - assert len(mad_hatter.hooks) > 0 - for h in mad_hatter.hooks: + assert len(mad_hatter.hooks.keys()) > 0 + for hook_name, hooks_list in mad_hatter.hooks.items(): + assert len(hooks_list) == 1 # core plugin implements each hook + h = hooks_list[0] assert isinstance(h, CatHook) assert h.plugin_id == "core_plugin" assert type(h.name) == str @@ -86,12 +88,14 @@ def test_plugin_install(mad_hatter: MadHatter, plugin_is_flat): assert new_tool.plugin_id == "mock_plugin" # found tool and hook have been cached - assert new_tool in mad_hatter.tools - assert new_hook in mad_hatter.hooks - - # new hook has correct priority and has been sorted by mad_hatter as first - assert new_hook.priority == 2 - assert id(new_hook) == id(mad_hatter.hooks[0]) # same object in memory! + assert id(new_tool) == id(mad_hatter.tools[1]) # same object in memory! + mock_hook_name = "before_cat_sends_message" + assert len(mad_hatter.hooks[mock_hook_name]) == 2 + cached_hook = mad_hatter.hooks[mock_hook_name][0] # correctly sorted by priority + assert cached_hook.name == mock_hook_name + assert cached_hook.plugin_id == "mock_plugin" + assert cached_hook.priority == 2 + assert id(new_hook) == id(cached_hook) # same object in memory! # list of active plugins in DB is correct active_plugins = mad_hatter.load_active_plugins_from_db() @@ -130,8 +134,9 @@ def test_plugin_uninstall(mad_hatter: MadHatter, plugin_is_flat): assert "mock_plugin" not in mad_hatter.plugins.keys() # plugin cache updated (only core_plugin stuff) assert len(mad_hatter.tools) == 1 # default tool - for h in mad_hatter.hooks: - assert h.plugin_id == "core_plugin" + for h_name, h_list in mad_hatter.hooks.items(): + assert len(h_list) == 1 + assert h_list[0].plugin_id == "core_plugin" # list of active plugins in DB is correct active_plugins = mad_hatter.load_active_plugins_from_db() From b70f7b7c0517f748220e51406bafc30d81f96151 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 12:46:47 +0200 Subject: [PATCH 43/77] refactor collection creation; bump qdran_client version --- core/cat/memory/vector_memory.py | 63 ++++++++++++++++++-------------- core/pyproject.toml | 2 +- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index a5eace6f..77ff5978 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -110,39 +110,46 @@ def __init__(self, cat, client: Any, collection_name: str, embeddings: Embedding # Set embedding size (may be changed at runtime) self.embedder_size = vector_size - # Check if memory collection exists, otherwise create it - self.create_collection_if_not_exists() + # Check if memory collection exists also in vectorDB, otherwise create it + self.create_db_collection_if_not_exists() + # Check db collection vector size is same as embedder size + self.check_embedding_size() - def create_collection_if_not_exists(self): - # create collection if it does not exist - try: - self.client.get_collection(self.collection_name) - log(f'Collection "{self.collection_name}" already present in vector store', "INFO") - log(f'Collection alias: "{self.client.get_collection_aliases(self.collection_name).aliases}" ', "INFO") - - # having the same size does not necessarily imply being the same embedder - # having vectors with the same size but from diffent embedder in the same vector space is wrong - same_size = (self.client.get_collection(self.collection_name).config.params.vectors.size==self.embedder_size) - alias = self.embedder_name + "_" + self.collection_name - if alias==self.client.get_collection_aliases(self.collection_name).aliases[0].alias_name and same_size: - log(f'Collection "{self.collection_name}" has the same embedder', "INFO") - else: - log(f'Collection "{self.collection_name}" has different embedder', "WARNING") - # dump collection on disk before deleting - self.save_dump() - log(f'Dump "{self.collection_name}" completed', "INFO") - - self.client.delete_collection(self.collection_name) - log(f'Collection "{self.collection_name}" deleted', "WARNING") - self.create_collection() - except Exception as e: - log(e, "ERROR") - self.create_collection() - + # log collection info log(f"Collection {self.collection_name}:", "INFO") log(dict(self.client.get_collection(self.collection_name)), "INFO") + def check_embedding_size(self): + + # having the same size does not necessarily imply being the same embedder + # having vectors with the same size but from diffent embedder in the same vector space is wrong + same_size = (self.client.get_collection(self.collection_name).config.params.vectors.size==self.embedder_size) + alias = self.embedder_name + "_" + self.collection_name + if alias==self.client.get_collection_aliases(self.collection_name).aliases[0].alias_name and same_size: + log(f'Collection "{self.collection_name}" has the same embedder', "INFO") + else: + log(f'Collection "{self.collection_name}" has different embedder', "WARNING") + # dump collection on disk before deleting + self.save_dump() + log(f'Dump "{self.collection_name}" completed', "INFO") + + self.client.delete_collection(self.collection_name) + log(f'Collection "{self.collection_name}" deleted', "WARNING") + self.create_collection() + + def create_db_collection_if_not_exists(self): + + # is collection present in DB? + collections_response = self.client.get_collections() + for c in collections_response.collections: + if c.name == self.collection_name: + # collection exists. Do nothing + log(f'Collection "{self.collection_name}" already present in vector store', "INFO") + return + + self.create_collection() + # create collection def create_collection(self): diff --git a/core/pyproject.toml b/core/pyproject.toml index 1b09b651..6894a7a5 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "websockets==10.4", "pandas==1.5.3", "scikit-learn==1.2.1", - "qdrant_client==1.1.2", + "qdrant_client==1.5.4", "langchain==0.0.222", "openai==0.27.5", "cohere==4.0.4", From 38e123a73f379e7da2f2045551d4fc01f9d68da0 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 13:19:50 +0200 Subject: [PATCH 44/77] log modules always --- core/cat/log.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/core/cat/log.py b/core/cat/log.py index c3d2eef7..506d80f9 100644 --- a/core/cat/log.py +++ b/core/cat/log.py @@ -119,14 +119,10 @@ def get_caller_info(self, skip=3): if module_info: mod = module_info.__name__.split(".") package = mod[0] - module = mod[1] - - # When the module is "plugins" get also the plugin module name - if module == "plugins": - module = ".".join(mod[1:]) + module = ".".join(mod[1:]) # class name. - klass = None + klass = "" if "self" in parentframe.f_locals: klass = parentframe.f_locals["self"].__class__.__name__ From f5a3fd5d1a978e4e4758e9250b41174a0f361195 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 13:45:30 +0200 Subject: [PATCH 45/77] add easier access to log levels --- core/cat/log.py | 75 ++++++++++++++++++++++++++---------------------- core/cat/main.py | 4 +-- 2 files changed, 43 insertions(+), 36 deletions(-) diff --git a/core/cat/log.py b/core/cat/log.py index 506d80f9..e199b2b8 100644 --- a/core/cat/log.py +++ b/core/cat/log.py @@ -140,8 +140,32 @@ def get_caller_info(self, skip=3): return package, module, klass, caller, line + def __call__(self, msg, level="DEBUG"): + """Alias of self.log()""" + self.log(msg, level) + + def debug(self, msg): + """Logs a DEBUG message""" + self.log(msg, level="DEBUG") + + def info(self, msg): + """Logs an INFO message""" + self.log(msg, level="INFO") + + def warning(self, msg): + """Logs a WARNING message""" + self.log(msg, level="WARNING") + + def error(self, msg): + """Logs an ERROR message""" + self.log(msg, level="ERROR") + + def critical(self, msg): + """Logs a CRITICAL message""" + self.log(msg, level="CRITICAL") + def log(self, msg, level="DEBUG"): - """Add to log based on settings. + """Log a message Parameters ---------- @@ -149,6 +173,7 @@ def log(self, msg, level="DEBUG"): Message to be logged. level : str Logging level.""" + global logger logger.remove() @@ -203,40 +228,22 @@ def log(self, msg, level="DEBUG"): # After our custom log we need to set again the logger as default for the other dependencies self.default_log() + def welcome(self): + """Welcome message in the terminal.""" + secure = os.getenv('CORE_USE_SECURE_PROTOCOLS', '') + if secure != '': + secure = 's' -logEngine = CatLogEngine() - - -def log(msg, level="DEBUG"): - """Create function wrapper to class. - - Parameters - ---------- - msg : str - Message to be logged. - level : str - Logging level. - - Returns - ------- - """ - global logEngine - return logEngine.log(msg, level) - - -def welcome(): - """Welcome message in the terminal.""" - secure = os.getenv('CORE_USE_SECURE_PROTOCOLS', '') - if secure != '': - secure = 's' + cat_address = f'http{secure}://{os.environ["CORE_HOST"]}:{os.environ["CORE_PORT"]}' - cat_address = f'http{secure}://{os.environ["CORE_HOST"]}:{os.environ["CORE_PORT"]}' + with open("cat/welcome.txt", 'r') as f: + print(f.read()) - with open("cat/welcome.txt", 'r') as f: - print(f.read()) + print('\n=============== ^._.^ ===============\n') + print(f'Cat REST API:\t{cat_address}/docs') + print(f'Cat PUBLIC:\t{cat_address}/public') + print(f'Cat ADMIN:\t{cat_address}/admin\n') + print('======================================') - print('\n=============== ^._.^ ===============\n') - print(f'Cat REST API:\t{cat_address}/docs') - print(f'Cat PUBLIC:\t{cat_address}/public') - print(f'Cat ADMIN:\t{cat_address}/admin\n') - print('======================================') +# logger instance +log = CatLogEngine() \ No newline at end of file diff --git a/core/cat/main.py b/core/cat/main.py index ceae9508..b93801fa 100644 --- a/core/cat/main.py +++ b/core/cat/main.py @@ -9,7 +9,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from cat.log import log, welcome +from cat.log import log from cat.routes import base, settings, llm, embedder, memory, plugins, upload, websocket from cat.routes.static import public, admin, static from cat.api_auth import check_api_key @@ -29,7 +29,7 @@ async def lifespan(app: FastAPI): app.state.ccat = CheshireCat() # startup message with admin, public and swagger addresses - welcome() + log.welcome() yield From 24732d256a6d058f89a133c5541be35430a8309d Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 14:08:14 +0200 Subject: [PATCH 46/77] update log commands to new syntax --- core/cat/looking_glass/agent_manager.py | 4 +-- core/cat/looking_glass/cheshire_cat.py | 12 ++++----- .../cat/mad_hatter/core_plugin/hooks/agent.py | 1 - core/cat/mad_hatter/mad_hatter.py | 14 +++++----- core/cat/mad_hatter/plugin.py | 14 +++++----- core/cat/mad_hatter/registry.py | 6 ++--- core/cat/memory/vector_memory.py | 27 +++++++++---------- core/cat/rabbit_hole.py | 11 ++++---- core/cat/routes/plugins.py | 6 ++--- core/cat/routes/upload.py | 4 +-- core/cat/routes/websocket.py | 4 +-- 11 files changed, 50 insertions(+), 53 deletions(-) diff --git a/core/cat/looking_glass/agent_manager.py b/core/cat/looking_glass/agent_manager.py index b4cb77cb..f6ba4508 100644 --- a/core/cat/looking_glass/agent_manager.py +++ b/core/cat/looking_glass/agent_manager.py @@ -111,7 +111,7 @@ def execute_agent(self, agent_input): # Try to get information from tools if there is some allowed if len(allowed_tools) > 0: - log(f"{len(allowed_tools)} allowed tools retrived.", "DEBUG") + log.debug(f"{len(allowed_tools)} allowed tools retrived.") try: tools_result = self.execute_tool_agent(agent_input, allowed_tools) @@ -150,7 +150,7 @@ def execute_agent(self, agent_input): except Exception as e: error_description = str(e) - log(error_description, "ERROR") + log.error(error_description) #If an exeption occur in the execute_tool_agent or there is no allowed tools execute only the memory chain diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 8ad04577..accc8599 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -119,7 +119,7 @@ def recall_relevant_memories_to_working_memory(self): # We may want to search in memory memory_query_text = self.mad_hatter.execute_hook("cat_recall_query", user_message) - log(f'Recall query: "{memory_query_text}"') + log.info(f'Recall query: "{memory_query_text}"') # Embed recall query memory_query_embedding = self.embedder.embed_query(memory_query_text) @@ -315,7 +315,7 @@ def __call__(self, user_message_json): answer. This is formatted in a dictionary to be sent as a JSON via Websocket to the client. """ - log(user_message_json, "INFO") + log.info(user_message_json) # Change working memory based on received user_id user_id = user_message_json.get('user_id', 'user') @@ -333,7 +333,7 @@ def __call__(self, user_message_json): try: self.recall_relevant_memories_to_working_memory() except Exception as e: - log(e, "ERROR") + log.error(e) traceback.print_exc(e) err_message = ( @@ -361,7 +361,7 @@ def __call__(self, user_message_json): # non instruction-fine-tuned models can still be used. error_description = str(e) - log(error_description, "ERROR") + log.error(error_description) if not "Could not parse LLM output: `" in error_description: raise e @@ -372,8 +372,8 @@ def __call__(self, user_message_json): "output": unparsable_llm_output } - log("cat_message:", "DEBUG") - log(cat_message, "DEBUG") + log.info("cat_message:") + log.info(cat_message) # update conversation history user_message = self.working_memory["user_message_json"]["text"] diff --git a/core/cat/mad_hatter/core_plugin/hooks/agent.py b/core/cat/mad_hatter/core_plugin/hooks/agent.py index 987f355d..e35a4b3a 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/agent.py +++ b/core/cat/mad_hatter/core_plugin/hooks/agent.py @@ -47,7 +47,6 @@ def before_agent_starts(agent_input, cat) -> Union[None, Dict]: Example 2: don't remember (no uploaded documents about topic) ```python num_declarative_memories = len( cat.working_memory["declarative_memories"] ) - log(num_declarative_memories, "ERROR") if num_declarative_memories == 0: return { "output": "Sorry, I have no memories about that." diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 15e0fcc3..f258a8a5 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -85,8 +85,8 @@ def find_plugins(self): all_plugin_folders = [core_plugin_folder] + glob.glob(f"{plugins_folder}*/") - log("ACTIVE PLUGINS:", "INFO") - log(self.active_plugins, "INFO") + log.info("ACTIVE PLUGINS:") + log.info(self.active_plugins) # discover plugins, folder by folder for folder in all_plugin_folders: @@ -110,7 +110,7 @@ def load_plugin(self, plugin_path): except Exception as e: # Something happened while loading the plugin. # Print the error and go on with the others. - log(str(e), "ERROR") + log.error(str(e)) # Load hooks and tools of the active plugins into MadHatter def sync_hooks_and_tools(self): @@ -193,7 +193,7 @@ def embed_tools(self): }], ) - log(f"Newly embedded tool: {tool.description}", "WARNING") + log.warning(f"Newly embedded tool: {tool.description}") # easy access to mad hatter tools (found in plugins) mad_hatter_tools_descriptions = [t.description for t in self.tools] @@ -203,7 +203,7 @@ def embed_tools(self): for id, descr in zip(embedded_tools_ids, embedded_tools_descriptions): # if the tool is not active, it inserts it in the list of points to be deleted if descr not in mad_hatter_tools_descriptions: - log(f"Deleting embedded tool: {descr}", "WARNING") + log.warning(f"Deleting embedded tool: {descr}") points_to_be_deleted.append(id) # delete not active tools @@ -221,13 +221,13 @@ def toggle_plugin(self, plugin_id): # update list of active plugins if plugin_is_active: - log(f"Toggle plugin {plugin_id}: Deactivate", "WARNING") + log.warning(f"Toggle plugin {plugin_id}: Deactivate") # Deactivate the plugin self.plugins[plugin_id].deactivate() # Remove the plugin from the list of active plugins self.active_plugins.remove(plugin_id) else: - log(f"Toggle plugin {plugin_id}: Activate", "WARNING") + log.warning(f"Toggle plugin {plugin_id}: Activate") # Activate the plugin self.plugins[plugin_id].activate() # Ass the plugin in the list of active plugins diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index 4299a551..d1ffea2d 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -59,7 +59,7 @@ def deactivate(self): # If the module is imported it is removed if py_filename in sys.modules: - log(f"Remove module {py_filename}", "DEBUG") + log.debug(f"Remove module {py_filename}") sys.modules.pop(py_filename) self._hooks = [] @@ -98,8 +98,8 @@ def load_settings(self): with open(settings_file_path, "r") as json_file: settings = json.load(json_file) except Exception as e: - log(f"Unable to load plugin {self._id} settings", "ERROR") - log(e, "ERROR") + log.error(f"Unable to load plugin {self._id} settings") + log.error(e) return settings @@ -126,7 +126,7 @@ def save_settings(self, settings: Dict): with open(settings_file_path, "w") as json_file: json.dump(updated_settings, json_file, indent=4) except Exception: - log(f"Unable to save plugin {self._id} settings", "ERROR") + log.error(f"Unable to save plugin {self._id} settings") return {} return updated_settings @@ -144,7 +144,7 @@ def _load_manifest(self): json_file_data = json.load(json_file) json_file.close() except Exception: - log(f"Loading plugin {self._path} metadata, defaulting to generated values", "INFO") + log.info(f"Loading plugin {self._path} metadata, defaulting to generated values") meta["name"] = json_file_data.get("name", to_camel_case(self._id)) meta["description"] = json_file_data.get("description", ( @@ -169,7 +169,7 @@ def _load_hooks_and_tools(self): for py_file in self.py_files: py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry - log(f"Import module {py_filename}", "INFO") + log.info(f"Import module {py_filename}") # save a reference to decorated functions try: @@ -177,7 +177,7 @@ def _load_hooks_and_tools(self): hooks += getmembers(plugin_module, self._is_cat_hook) tools += getmembers(plugin_module, self._is_cat_tool) except Exception as e: - log(f"Error in {py_filename}: {str(e)}","ERROR") + log.error(f"Error in {py_filename}: {str(e)}") traceback.print_exc() raise Exception(f"Unable to load the plugin {self._id}") diff --git a/core/cat/mad_hatter/registry.py b/core/cat/mad_hatter/registry.py index 7c9ba5c0..3590df7a 100644 --- a/core/cat/mad_hatter/registry.py +++ b/core/cat/mad_hatter/registry.py @@ -35,13 +35,13 @@ async def registry_search_plugins( return response.json()["plugins"] except Exception as e: - log(e, "ERROR") + log.error(e) return [] def registry_download_plugin(url: str) -> str: - log(f"Downloading {url}", "INFO") + log.info(f"Downloading {url}") registry_url = get_registry_url() payload = { @@ -52,6 +52,6 @@ def registry_download_plugin(url: str) -> str: with open(plugin_zip_path, "wb") as f: f.write(response.content) - log(f"Saved plugin as {plugin_zip_path}", "INFO") + log.info(f"Saved plugin as {plugin_zip_path}") return plugin_zip_path diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index 77ff5978..ac22bc5d 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -62,7 +62,7 @@ def connect_to_vector_memory(self) -> None: qdrant_host = os.getenv("QDRANT_HOST", db_path) if len(qdrant_host) == 0 or qdrant_host == db_path: - log(f"Qdrant path: {db_path}","INFO") + log.info(f"Qdrant path: {db_path}") # Qdrant local vector DB client # reconnect only if it's the first boot and not a reload @@ -77,8 +77,7 @@ def connect_to_vector_memory(self) -> None: s = socket.socket() s.connect((qdrant_host, qdrant_port)) except Exception: - log("QDrant does not respond to %s:%s" % - (qdrant_host, qdrant_port), "ERROR") + log.error(f"QDrant does not respond to {qdrant_host}:{qdrant_port}") sys.exit() finally: s.close() @@ -117,8 +116,8 @@ def __init__(self, cat, client: Any, collection_name: str, embeddings: Embedding self.check_embedding_size() # log collection info - log(f"Collection {self.collection_name}:", "INFO") - log(dict(self.client.get_collection(self.collection_name)), "INFO") + log.info(f"Collection {self.collection_name}:") + log.info(dict(self.client.get_collection(self.collection_name))) def check_embedding_size(self): @@ -127,15 +126,15 @@ def check_embedding_size(self): same_size = (self.client.get_collection(self.collection_name).config.params.vectors.size==self.embedder_size) alias = self.embedder_name + "_" + self.collection_name if alias==self.client.get_collection_aliases(self.collection_name).aliases[0].alias_name and same_size: - log(f'Collection "{self.collection_name}" has the same embedder', "INFO") + log.info(f'Collection "{self.collection_name}" has the same embedder') else: - log(f'Collection "{self.collection_name}" has different embedder', "WARNING") + log.warning(f'Collection "{self.collection_name}" has different embedder') # dump collection on disk before deleting self.save_dump() - log(f'Dump "{self.collection_name}" completed', "INFO") + log.info(f'Dump "{self.collection_name}" completed') self.client.delete_collection(self.collection_name) - log(f'Collection "{self.collection_name}" deleted', "WARNING") + log.warning(f'Collection "{self.collection_name}" deleted') self.create_collection() def create_db_collection_if_not_exists(self): @@ -145,7 +144,7 @@ def create_db_collection_if_not_exists(self): for c in collections_response.collections: if c.name == self.collection_name: # collection exists. Do nothing - log(f'Collection "{self.collection_name}" already present in vector store', "INFO") + log.info(f'Collection "{self.collection_name}" already present in vector store') return self.create_collection() @@ -153,7 +152,7 @@ def create_db_collection_if_not_exists(self): # create collection def create_collection(self): - log(f"Creating collection {self.collection_name} ...", "WARNING") + log.warning(f"Creating collection {self.collection_name} ...") self.client.recreate_collection( collection_name=self.collection_name, vectors_config=VectorParams( @@ -265,9 +264,9 @@ def save_dump(self, folder="dormouse/"): port = self.client._client._port if os.path.isdir(folder): - log(f'Directory dormouse exists', "INFO") + log.info(f'Directory dormouse exists') else: - log(f'Directory dormouse NOT exists, creating it.', "WARNING") + log.warning(f'Directory dormouse does NOT exists, creating it.') os.mkdir(folder) self.snapshot_info = self.client.create_snapshot(collection_name=self.collection_name) @@ -281,5 +280,5 @@ def save_dump(self, folder="dormouse/"): os.rename(snapshot_url_out, new_name) for s in self.client.list_snapshots(self.collection_name): self.client.delete_snapshot(collection_name=self.collection_name, snapshot_name=s.name) - log(f'Dump "{new_name}" completed', "WARNING") + log.warning(f'Dump "{new_name}" completed') # dump complete \ No newline at end of file diff --git a/core/cat/rabbit_hole.py b/core/cat/rabbit_hole.py index f4d87d77..d23b8edd 100644 --- a/core/cat/rabbit_hole.py +++ b/core/cat/rabbit_hole.py @@ -78,7 +78,7 @@ def ingest_memory(self, file: UploadFile): } for p in declarative_memories] vectors = [v["vector"] for v in declarative_memories] - log(f"Preparing to load {len(vectors)} vector memories", "INFO") + log.info(f"Preparing to load {len(vectors)} vector memories") # Check embedding size is correct embedder_size = self.cat.memory.vectors.embedder_size @@ -200,7 +200,7 @@ def file_to_docs( with urlopen(request) as response: file_bytes = response.read() except HTTPError as e: - log(e, "ERROR") + log.error(e) else: # Get mime type from file extension and source @@ -253,7 +253,7 @@ def store_documents(self, docs: List[Document], source: str) -> None: before_rabbithole_insert_memory """ - log(f"Preparing to memorize {len(docs)} vectors") + log.info(f"Preparing to memorize {len(docs)} vectors") # hook the docs before they are stored in the vector memory docs = self.cat.mad_hatter.execute_hook( @@ -281,10 +281,9 @@ def store_documents(self, docs: List[Document], source: str) -> None: [doc.metadata], ) - # log(f"Inserted into memory({inserting_info})", "INFO") - print(f"Inserted into memory({inserting_info})") + log.info(f"Inserted into memory({inserting_info})") else: - log(f"Skipped memory insertion of empty doc ({inserting_info})", "INFO") + log.info(f"Skipped memory insertion of empty doc ({inserting_info})") # wait a little to avoid APIs rate limit errors time.sleep(0.1) diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 19cf0d45..30538909 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -82,7 +82,7 @@ async def install_plugin( }, ) - log(f"Uploading {content_type} plugin {file.filename}", "INFO") + log.info(f"Uploading {content_type} plugin {file.filename}") plugin_archive_path = f"/tmp/{file.filename}" with open(plugin_archive_path, "wb+") as f: f.write(file.file.read()) @@ -113,8 +113,8 @@ async def install_plugin_from_registry( try: tmp_plugin_path = registry_download_plugin( payload["url"] ) except Exception as e: - log("Could not download plugin form registry", "ERROR") - log(e, "ERROR") + log.error("Could not download plugin form registry") + log.error(e) raise HTTPException( status_code = 500, detail = { "error": str(e)} diff --git a/core/cat/routes/upload.py b/core/cat/routes/upload.py index 520cbb0e..14a35322 100644 --- a/core/cat/routes/upload.py +++ b/core/cat/routes/upload.py @@ -33,7 +33,7 @@ async def upload_file( # Get file mime type content_type = mimetypes.guess_type(file.filename)[0] - log(f"Uploaded {content_type} down the rabbit hole", "INFO") + log.info(f"Uploaded {content_type} down the rabbit hole") # check if MIME type of uploaded file is supported if content_type not in admitted_types: @@ -122,7 +122,7 @@ async def upload_memory( # Get file mime type content_type = mimetypes.guess_type(file.filename)[0] - log(f"Uploaded {content_type} down the rabbit hole", "INFO") + log.info(f"Uploaded {content_type} down the rabbit hole") if content_type != "application/json": raise HTTPException( status_code=400, diff --git a/core/cat/routes/websocket.py b/core/cat/routes/websocket.py index 43650fb1..49d28951 100644 --- a/core/cat/routes/websocket.py +++ b/core/cat/routes/websocket.py @@ -99,10 +99,10 @@ async def websocket_endpoint(websocket: WebSocket): ) except WebSocketDisconnect: # Handle the event where the user disconnects their WebSocket. - log("WebSocket connection closed", "INFO") + log.info("WebSocket connection closed") except Exception as e: # Log any unexpected errors and send an error message back to the user. - log(e, "ERROR") + log.error(e) traceback.print_exc() await manager.send_personal_message({ "type": "error", From 44c75ca8ed5460b40ad1c518fce0af6694d84d22 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 18:03:41 +0200 Subject: [PATCH 47/77] update readme tests command --- README.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/README.md b/README.md index e7a9a280..2ff1c84e 100644 --- a/README.md +++ b/README.md @@ -113,12 +113,6 @@ To run the tests within the Docker container, execute the following command: docker exec cheshire_cat_core python -m pytest --color=yes . ``` -If you are running the tests locally on your machine, use the following command: - -```bash -python -m pytest --color=yes . -``` - ### Try in GitHub Codespaces You can try Cheshire Cat in GitHub Codespaces. The free account provides 60 free hours a month. From 08417b86d81ae7fdb42cb03d1eb5a79cfbf46c1c Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 19:15:01 +0200 Subject: [PATCH 48/77] piepeable hooks via tea_cup --- core/cat/looking_glass/cheshire_cat.py | 151 +++++++++++++++- .../mad_hatter/core_plugin/hooks/models.py | 161 ------------------ core/cat/mad_hatter/mad_hatter.py | 38 ++++- 3 files changed, 181 insertions(+), 169 deletions(-) delete mode 100644 core/cat/mad_hatter/core_plugin/hooks/models.py diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index accc8599..a50437d8 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -12,6 +12,18 @@ from cat.memory.long_term_memory import LongTermMemory from cat.looking_glass.agent_manager import AgentManager +# TODO: natural language dependencies; move to another file +import cat.factory.llm as llms +import cat.factory.embedder as embedders +from cat.db import crud +from langchain.llms import Cohere, OpenAI, OpenAIChat, AzureOpenAI, HuggingFaceTextGenInference +from langchain.chat_models import ChatOpenAI +from langchain.base_language import BaseLanguageModel +from langchain import HuggingFaceHub +from langchain.chat_models import AzureChatOpenAI +from cat.factory.custom_llm import CustomOpenAI + + MSG_TYPES = Literal["notification", "chat", "error"] # main class @@ -75,13 +87,144 @@ def load_natural_language(self): See Also -------- - get_language_model - get_language_embedder agent_prompt_prefix """ # LLM and embedder - self._llm = self.mad_hatter.execute_hook("get_language_model") - self.embedder = self.mad_hatter.execute_hook("get_language_embedder") + self._llm = self.get_language_model() + self.embedder = self.get_language_embedder() + + def get_language_model(self) -> BaseLanguageModel: + """Large Language Model (LLM) selection at bootstrap time. + + Returns + ------- + llm : BaseLanguageModel + Langchain `BaseLanguageModel` instance of the selected model. + + Notes + ----- + Bootstrapping is the process of loading the plugins, the natural language objects (e.g. the LLM), the memories, + the *Agent Manager* and the *Rabbit Hole*. + + """ + selected_llm = crud.get_setting_by_name(name="llm_selected") + + if selected_llm is None: + # return default LLM + llm = llms.LLMDefaultConfig.get_llm_from_config({}) + + else: + # get LLM factory class + selected_llm_class = selected_llm["value"]["name"] + FactoryClass = getattr(llms, selected_llm_class) + + # obtain configuration and instantiate LLM + selected_llm_config = crud.get_setting_by_name(name=selected_llm_class) + try: + llm = FactoryClass.get_llm_from_config(selected_llm_config["value"]) + except Exception as e: + import traceback + traceback.print_exc() + llm = llms.LLMDefaultConfig.get_llm_from_config({}) + + return llm + + + def get_language_embedder(self) -> embedders.EmbedderSettings: + """Hook into the embedder selection. + + Allows to modify how the Cat selects the embedder at bootstrap time. + + Bootstrapping is the process of loading the plugins, the natural language objects (e.g. the LLM), + the memories, the *Agent Manager* and the *Rabbit Hole*. + + Parameters + ---------- + cat: CheshireCat + Cheshire Cat instance. + + Returns + ------- + embedder : Embeddings + Selected embedder model. + """ + # Embedding LLM + + selected_embedder = crud.get_setting_by_name(name="embedder_selected") + + if selected_embedder is not None: + # get Embedder factory class + selected_embedder_class = selected_embedder["value"]["name"] + FactoryClass = getattr(embedders, selected_embedder_class) + + # obtain configuration and instantiate Embedder + selected_embedder_config = crud.get_setting_by_name(name=selected_embedder_class) + embedder = FactoryClass.get_embedder_from_config(selected_embedder_config["value"]) + + return embedder + + # OpenAI embedder + if type(self._llm) in [OpenAI, OpenAIChat, ChatOpenAI]: + embedder = embedders.EmbedderOpenAIConfig.get_embedder_from_config( + { + "openai_api_key": self._llm.openai_api_key, + } + ) + + # Azure + elif type(self._llm) in [AzureOpenAI, AzureChatOpenAI]: + embedder = embedders.EmbedderAzureOpenAIConfig.get_embedder_from_config( + { + "openai_api_key": self._llm.openai_api_key, + "openai_api_type": "azure", + "model": "text-embedding-ada-002", + # Now the only model for embeddings is text-embedding-ada-002 + # It is also possible to use the Azure "deployment" name that is user defined + # when the model is deployed to Azure. + # "deployment": "my-text-embedding-ada-002", + "openai_api_base": self._llm.openai_api_base, + # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#embeddings + # current supported versions 2022-12-01,2023-03-15-preview, 2023-05-15 + # Don't mix api versions https://github.com/hwchase17/langchain/issues/4775 + "openai_api_version": "2023-05-15", + } + ) + + # Cohere + elif type(self._llm) in [Cohere]: + embedder = embedders.EmbedderCohereConfig.get_embedder_from_config( + { + "cohere_api_key": self._llm.cohere_api_key, + "model": "embed-multilingual-v2.0", + # Now the best model for embeddings is embed-multilingual-v2.0 + } + ) + + # HuggingFace + elif type(self._llm) in [HuggingFaceHub]: + embedder = embedders.EmbedderHuggingFaceHubConfig.get_embedder_from_config( + { + "huggingfacehub_api_token": self._llm.huggingfacehub_api_token, + "repo_id": "sentence-transformers/all-mpnet-base-v2", + } + ) + + # Llama-cpp-python + elif type(self._llm) in [CustomOpenAI]: + embedder = embedders.EmbedderLlamaCppConfig.get_embedder_from_config( + { + "url": self._llm.url + } + ) + + else: + # If no embedder matches vendor, and no external embedder is configured, we use the DumbEmbedder. + # `This embedder is not a model properly trained + # and this makes it not suitable to effectively embed text, + # "but it does not know this and embeds anyway".` - cit. Nicola Corbellini + embedder = embedders.EmbedderDumbConfig.get_embedder_from_config({}) + + return embedder def load_memory(self): """Load LongTerMemory and WorkingMemory.""" diff --git a/core/cat/mad_hatter/core_plugin/hooks/models.py b/core/cat/mad_hatter/core_plugin/hooks/models.py deleted file mode 100644 index a41a4e4c..00000000 --- a/core/cat/mad_hatter/core_plugin/hooks/models.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Hooks to modify the Cat's language and embedding models. - -Here is a collection of methods to hook into the settings of the Large Language Model and the Embedder. - -""" - -import os - -import cat.factory.llm as llms -import cat.factory.embedder as embedders -from cat.db import crud -from langchain.llms import Cohere, OpenAI, OpenAIChat, AzureOpenAI, HuggingFaceTextGenInference -from langchain.chat_models import ChatOpenAI -from langchain.base_language import BaseLanguageModel -from langchain import HuggingFaceHub -from langchain.chat_models import AzureChatOpenAI -from cat.mad_hatter.decorators import hook -from cat.factory.custom_llm import CustomOpenAI - - -@hook(priority=0) -def get_language_model(cat) -> BaseLanguageModel: - """Hook into the Large Language Model (LLM) selection. - - Allows to modify how the Cat selects the LLM at bootstrap time. - - Parameters - ---------- - cat: CheshireCat - Cheshire Cat instance. - - Returns - ------- - lll : BaseLanguageModel - Langchain `BaseLanguageModel` instance of the selected model. - - Notes - ----- - Bootstrapping is the process of loading the plugins, the natural language objects (e.g. the LLM), the memories, - the *Agent Manager* and the *Rabbit Hole*. - - """ - selected_llm = crud.get_setting_by_name(name="llm_selected") - - if selected_llm is None: - # return default LLM - llm = llms.LLMDefaultConfig.get_llm_from_config({}) - - else: - # get LLM factory class - selected_llm_class = selected_llm["value"]["name"] - FactoryClass = getattr(llms, selected_llm_class) - - # obtain configuration and instantiate LLM - selected_llm_config = crud.get_setting_by_name(name=selected_llm_class) - try: - llm = FactoryClass.get_llm_from_config(selected_llm_config["value"]) - except Exception as e: - import traceback - traceback.print_exc() - llm = llms.LLMDefaultConfig.get_llm_from_config({}) - - return llm - - -@hook(priority=0) -def get_language_embedder(cat) -> embedders.EmbedderSettings: - """Hook into the embedder selection. - - Allows to modify how the Cat selects the embedder at bootstrap time. - - Bootstrapping is the process of loading the plugins, the natural language objects (e.g. the LLM), - the memories, the *Agent Manager* and the *Rabbit Hole*. - - Parameters - ---------- - cat: CheshireCat - Cheshire Cat instance. - - Returns - ------- - embedder : Embeddings - Selected embedder model. - """ - # Embedding LLM - - selected_embedder = crud.get_setting_by_name(name="embedder_selected") - - if selected_embedder is not None: - # get Embedder factory class - selected_embedder_class = selected_embedder["value"]["name"] - FactoryClass = getattr(embedders, selected_embedder_class) - - # obtain configuration and instantiate Embedder - selected_embedder_config = crud.get_setting_by_name(name=selected_embedder_class) - embedder = FactoryClass.get_embedder_from_config(selected_embedder_config["value"]) - - return embedder - - # OpenAI embedder - if type(cat._llm) in [OpenAI, OpenAIChat, ChatOpenAI]: - embedder = embedders.EmbedderOpenAIConfig.get_embedder_from_config( - { - "openai_api_key": cat._llm.openai_api_key, - } - ) - - # Azure - elif type(cat._llm) in [AzureOpenAI, AzureChatOpenAI]: - embedder = embedders.EmbedderAzureOpenAIConfig.get_embedder_from_config( - { - "openai_api_key": cat._llm.openai_api_key, - "openai_api_type": "azure", - "model": "text-embedding-ada-002", - # Now the only model for embeddings is text-embedding-ada-002 - # It is also possible to use the Azure "deployment" name that is user defined - # when the model is deployed to Azure. - # "deployment": "my-text-embedding-ada-002", - "openai_api_base": cat._llm.openai_api_base, - # https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#embeddings - # current supported versions 2022-12-01,2023-03-15-preview, 2023-05-15 - # Don't mix api versions https://github.com/hwchase17/langchain/issues/4775 - "openai_api_version": "2023-05-15", - } - ) - - # Cohere - elif type(cat._llm) in [Cohere]: - embedder = embedders.EmbedderCohereConfig.get_embedder_from_config( - { - "cohere_api_key": cat._llm.cohere_api_key, - "model": "embed-multilingual-v2.0", - # Now the best model for embeddings is embed-multilingual-v2.0 - } - ) - - # HuggingFace - elif type(cat._llm) in [HuggingFaceHub]: - embedder = embedders.EmbedderHuggingFaceHubConfig.get_embedder_from_config( - { - "huggingfacehub_api_token": cat._llm.huggingfacehub_api_token, - "repo_id": "sentence-transformers/all-mpnet-base-v2", - } - ) - - # Llama-cpp-python - elif type(cat._llm) in [CustomOpenAI]: - embedder = embedders.EmbedderLlamaCppConfig.get_embedder_from_config( - { - "url": cat._llm.url - } - ) - - else: - # If no embedder matches vendor, and no external embedder is configured, we use the DumbEmbedder. - # `This embedder is not a model properly trained - # and this makes it not suitable to effectively embed text, - # "but it does not know this and embeds anyway".` - cit. Nicola Corbellini - embedder = embedders.EmbedderDumbConfig.get_embedder_from_config({}) - - return embedder diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index f258a8a5..4d7cd2ca 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -3,6 +3,7 @@ import time import shutil import os +import traceback from cat.log import log from cat.db import crud @@ -246,12 +247,41 @@ def toggle_plugin(self, plugin_id): # execute requested hook def execute_hook(self, hook_name, *args): + log.critical(hook_name) + # check if hook is supported if hook_name not in self.hooks.keys(): raise Exception(f"Hook {hook_name} not present in any plugin") - # run hooks - for h in self.hooks[hook_name]: - return h.function(*args, cat=self.ccat) - # TODO: should be run as a pipe, not return immediately + # First argument is passed to `execute_hook` is the pipeable one. + # We call it `tea_cup` as every hook called will receive it as an input, + # can add sugar, milk, or whatever, and return it for the next hook + if len(args) == 0: + tea_cup = None + else: + tea_cup = args[0] + + # run hooks + for hook in self.hooks[hook_name]: + try: + # pass tea_cup to the hooks, along other args + + # hook has no input (aside cat) + if tea_cup is None: + hook.function(cat=self.ccat) + continue + + # hook has at least one argument, and it will be piped + tea_spoon = hook.function(tea_cup, *args[1:], cat=self.ccat) + if tea_spoon is None: + log.warning(f"Hook {hook.plugin_id}::{hook.name} returned None") + else: + tea_cup = tea_spoon + except Exception as e: + log.error(f"Error in plugin {hook.plugin_id}::{hook.name}") + log.error(e) + traceback.print_exc() + + # tea_cup has passed through all hooks. Return final output + return tea_cup \ No newline at end of file From 46221538198372f351f36ff3b43a7c9cc0aa9e18 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 19:27:06 +0200 Subject: [PATCH 49/77] refactor flow hooks --- core/cat/mad_hatter/core_plugin/hooks/flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/cat/mad_hatter/core_plugin/hooks/flow.py b/core/cat/mad_hatter/core_plugin/hooks/flow.py index cbedc7ee..fe704cbf 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/flow.py +++ b/core/cat/mad_hatter/core_plugin/hooks/flow.py @@ -25,7 +25,7 @@ def before_cat_bootstrap(cat) -> None: cat : CheshireCat Cheshire Cat instance. """ - return None + pass # do nothing # Called after cat bootstrap @@ -46,7 +46,7 @@ def after_cat_bootstrap(cat) -> None: cat : CheshireCat Cheshire Cat instance. """ - return None + pass # do nothing # Called when a user message arrives. From 25108b01f954e6a0bbb441b6bcd339c95acbab4d Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 20:00:18 +0200 Subject: [PATCH 50/77] refactor recall hooks --- core/cat/looking_glass/cheshire_cat.py | 12 ++++++------ core/cat/mad_hatter/core_plugin/hooks/flow.py | 8 +++----- core/cat/mad_hatter/mad_hatter.py | 5 ++--- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index a50437d8..b1627eaf 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -258,15 +258,15 @@ def recall_relevant_memories_to_working_memory(self): after_cat_recalls_memories """ user_id = self.working_memory.get_user_id() - user_message = self.working_memory["user_message_json"]["text"] + recall_query = self.working_memory["user_message_json"]["text"] # We may want to search in memory - memory_query_text = self.mad_hatter.execute_hook("cat_recall_query", user_message) - log.info(f'Recall query: "{memory_query_text}"') + recall_query = self.mad_hatter.execute_hook("cat_recall_query", recall_query) + log.info(f'Recall query: "{recall_query}"') # Embed recall query - memory_query_embedding = self.embedder.embed_query(memory_query_text) - self.working_memory["memory_query"] = memory_query_text + memory_query_embedding = self.embedder.embed_query(recall_query) + self.working_memory["recall_query"] = recall_query # hook to do something before recall begins self.mad_hatter.execute_hook("before_cat_recalls_memories") @@ -312,7 +312,7 @@ def recall_relevant_memories_to_working_memory(self): self.working_memory[memory_key] = memories # hook to modify/enrich retrieved memories - self.mad_hatter.execute_hook("after_cat_recalls_memories", memory_query_text) + self.mad_hatter.execute_hook("after_cat_recalls_memories") def llm(self, prompt: str) -> str: """Generate a response using the LLM model. diff --git a/core/cat/mad_hatter/core_plugin/hooks/flow.py b/core/cat/mad_hatter/core_plugin/hooks/flow.py index fe704cbf..af6afee7 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/flow.py +++ b/core/cat/mad_hatter/core_plugin/hooks/flow.py @@ -201,21 +201,19 @@ def before_cat_recalls_procedural_memories(procedural_recall_config: dict, cat) # Called just before the cat recalls memories. @hook(priority=0) -def after_cat_recalls_memories(query: str, cat) -> None: +def after_cat_recalls_memories(cat) -> None: """Hook after semantic search in memories. - The hook is executed just after the Cat searches for the meaningful context in both memories + The hook is executed just after the Cat searches for the meaningful context in memories and stores it in the *Working Memory*. Parameters ---------- - query : str - Query used to retrieve memories. cat : CheshireCat Cheshire Cat instance. """ - return None + pass # do nothing # What is the input to recall memories? diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 4d7cd2ca..cd57e65b 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -274,9 +274,8 @@ def execute_hook(self, hook_name, *args): # hook has at least one argument, and it will be piped tea_spoon = hook.function(tea_cup, *args[1:], cat=self.ccat) - if tea_spoon is None: - log.warning(f"Hook {hook.plugin_id}::{hook.name} returned None") - else: + log.info(f"Hook {hook.plugin_id}::{hook.name} returned {tea_spoon}") + if tea_spoon is not None: tea_cup = tea_spoon except Exception as e: log.error(f"Error in plugin {hook.plugin_id}::{hook.name}") From 1e8c18fb913a3792be7889f30dc5331eb2bdcedd Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 20:15:33 +0200 Subject: [PATCH 51/77] refactor memory recall hooks, 2 --- core/cat/looking_glass/cheshire_cat.py | 10 ++- core/cat/mad_hatter/core_plugin/hooks/flow.py | 90 +++++++++---------- 2 files changed, 51 insertions(+), 49 deletions(-) diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index b1627eaf..0abec674 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -251,6 +251,7 @@ def recall_relevant_memories_to_working_memory(self): See Also -------- + cat_recall_query before_cat_recalls_memories before_cat_recalls_episodic_memories before_cat_recalls_declarative_memories @@ -265,29 +266,30 @@ def recall_relevant_memories_to_working_memory(self): log.info(f'Recall query: "{recall_query}"') # Embed recall query - memory_query_embedding = self.embedder.embed_query(recall_query) + recall_query_embedding = self.embedder.embed_query(recall_query) self.working_memory["recall_query"] = recall_query # hook to do something before recall begins self.mad_hatter.execute_hook("before_cat_recalls_memories") # Setting default recall configs for each memory + # TODO: can these data structrues become instances of a RecallSettings class? default_episodic_recall_config = { - "embedding": memory_query_embedding, + "embedding": recall_query_embedding, "k": 3, "threshold": 0.7, "metadata": {"source": user_id}, } default_declarative_recall_config = { - "embedding": memory_query_embedding, + "embedding": recall_query_embedding, "k": 3, "threshold": 0.7, "metadata": None, } default_procedural_recall_config = { - "embedding": memory_query_embedding, + "embedding": recall_query_embedding, "k": 3, "threshold": 0.7, "metadata": None, diff --git a/core/cat/mad_hatter/core_plugin/hooks/flow.py b/core/cat/mad_hatter/core_plugin/hooks/flow.py index af6afee7..bdaf6bc1 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/flow.py +++ b/core/cat/mad_hatter/core_plugin/hooks/flow.py @@ -93,6 +93,50 @@ def before_cat_reads_message(user_message_json: dict, cat) -> dict: return user_message_json +# What is the input to recall memories? +# Here you can do HyDE embedding, condense recent conversation or condition recall query on something else important to your AI +@hook(priority=0) +def cat_recall_query(user_message: str, cat) -> str: + """Hook the semantic search query. + + This hook allows to edit the user's message used as a query for context retrieval from memories. + As a result, the retrieved context can be conditioned editing the user's message. + + Parameters + ---------- + user_message : str + String with the text received from the user. + cat : CheshireCat + Cheshire Cat instance to exploit the Cat's methods. + + Returns + ------- + Edited string to be used for context retrieval in memory. The returned string is further stored in the + Working Memory at `cat.working_memory["memory_query"]`. + + Notes + ----- + For example, this hook is a suitable to perform Hypothetical Document Embedding (HyDE). + HyDE [1]_ strategy exploits the user's message to generate a hypothetical answer. This is then used to recall + the relevant context from the memory. + An official plugin is available to test this technique. + + References + ---------- + [1] Gao, L., Ma, X., Lin, J., & Callan, J. (2022). Precise Zero-Shot Dense Retrieval without Relevance Labels. + arXiv preprint arXiv:2212.10496. + + """ + # example 1: HyDE embedding + # return cat.hypothetis_chain.run(user_message) + + # example 2: Condense recent conversation + # TODO + + # here we just return the latest user message as is + return user_message + + # Called just before the cat recalls memories. @hook(priority=0) def before_cat_recalls_memories(cat) -> None: @@ -109,7 +153,7 @@ def before_cat_recalls_memories(cat) -> None: Cheshire Cat instance. """ - return None + pass # do nothing @hook(priority=0) @@ -216,50 +260,6 @@ def after_cat_recalls_memories(cat) -> None: pass # do nothing -# What is the input to recall memories? -# Here you can do HyDE embedding, condense recent conversation or condition recall query on something else important to your AI -@hook(priority=0) -def cat_recall_query(user_message: str, cat) -> str: - """Hook the semantic search query. - - This hook allows to edit the user's message used as a query for context retrieval from memories. - As a result, the retrieved context can be conditioned editing the user's message. - - Parameters - ---------- - user_message : str - String with the text received from the user. - cat : CheshireCat - Cheshire Cat instance to exploit the Cat's methods. - - Returns - ------- - Edited string to be used for context retrieval in memory. The returned string is further stored in the - Working Memory at `cat.working_memory["memory_query"]`. - - Notes - ----- - For example, this hook is a suitable to perform Hypothetical Document Embedding (HyDE). - HyDE [1]_ strategy exploits the user's message to generate a hypothetical answer. This is then used to recall - the relevant context from the memory. - An official plugin is available to test this technique. - - References - ---------- - [1] Gao, L., Ma, X., Lin, J., & Callan, J. (2022). Precise Zero-Shot Dense Retrieval without Relevance Labels. - arXiv preprint arXiv:2212.10496. - - """ - # example 1: HyDE embedding - # return cat.hypothetis_chain.run(user_message) - - # example 2: Condense recent conversation - # TODO - - # here we just return the latest user message as is - return user_message - - # Called just after memories are recalled. They are stored in: # - cat.working_memory["episodic_memories"] # - cat.working_memory["declarative_memories"] From 8b6c37cfc0bba02d8596bc733771f081a1690430 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 21:12:56 +0200 Subject: [PATCH 52/77] fix non argument hooks --- core/cat/looking_glass/agent_manager.py | 60 ++++++++++++++++++- core/cat/looking_glass/cheshire_cat.py | 53 +--------------- .../mad_hatter/core_plugin/hooks/prompt.py | 2 +- core/cat/mad_hatter/mad_hatter.py | 27 +++++---- core/tests/routes/test_websocket.py | 1 - 5 files changed, 75 insertions(+), 68 deletions(-) diff --git a/core/cat/looking_glass/agent_manager.py b/core/cat/looking_glass/agent_manager.py index f6ba4508..9c65a816 100644 --- a/core/cat/looking_glass/agent_manager.py +++ b/core/cat/looking_glass/agent_manager.py @@ -85,7 +85,7 @@ def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix): return out - def execute_agent(self, agent_input): + def execute_agent(self): """Instantiate the Agent with tools. The method formats the main prompt and gather the allowed tools. It also instantiates a conversational Agent @@ -98,13 +98,17 @@ def execute_agent(self, agent_input): """ mad_hatter = self.cat.mad_hatter + # prepare input to be passed to the agent. + # Info will be extracted from working memory + agent_input = self.format_agent_input() + # this hook allows to reply without executing the agent (for example canned responses, out-of-topic barriers etc.) fast_reply = mad_hatter.execute_hook("before_agent_starts", agent_input) if fast_reply: return fast_reply - prompt_prefix = mad_hatter.execute_hook("agent_prompt_prefix") - prompt_suffix = mad_hatter.execute_hook("agent_prompt_suffix") + prompt_prefix = mad_hatter.execute_hook("agent_prompt_prefix", "TODO_HOOK") + prompt_suffix = mad_hatter.execute_hook("agent_prompt_suffix", "TODO_HOOK") allowed_tools = mad_hatter.execute_hook("agent_allowed_tools") @@ -160,3 +164,53 @@ def execute_agent(self, agent_input): out = self.execute_memory_chain(agent_input, prompt_prefix, prompt_suffix) return out + + def format_agent_input(self): + """Format the input for the Agent. + + The method formats the strings of recalled memories and chat history that will be provided to the Langchain + Agent and inserted in the prompt. + + Returns + ------- + dict + Formatted output to be parsed by the Agent executor. + + Notes + ----- + The context of memories and conversation history is properly formatted before being parsed by the and, hence, + information are inserted in the main prompt. + All the formatting pipeline is hookable and memories can be edited. + + See Also + -------- + agent_prompt_episodic_memories + agent_prompt_declarative_memories + agent_prompt_chat_history + """ + + mad_hatter = self.cat.mad_hatter + working_memory = self.cat.working_memory + + # format memories to be inserted in the prompt + episodic_memory_formatted_content = mad_hatter.execute_hook( + "agent_prompt_episodic_memories", + working_memory["episodic_memories"], + ) + declarative_memory_formatted_content = mad_hatter.execute_hook( + "agent_prompt_declarative_memories", + working_memory["declarative_memories"], + ) + + # format conversation history to be inserted in the prompt + conversation_history_formatted_content = mad_hatter.execute_hook( + "agent_prompt_chat_history", + working_memory["history"] + ) + + return { + "input": working_memory["user_message_json"]["text"], + "episodic_memory": episodic_memory_formatted_content, + "declarative_memory": declarative_memory_formatted_content, + "chat_history": conversation_history_formatted_content, + } diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 0abec674..853187a0 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -340,51 +340,6 @@ def llm(self, prompt: str) -> str: if isinstance(self._llm, langchain.chat_models.base.BaseChatModel): return self._llm.call_as_llm(prompt) - def format_agent_input(self): - """Format the input for the Agent. - - The method formats the strings of recalled memories and chat history that will be provided to the Langchain - Agent and inserted in the prompt. - - Returns - ------- - dict - Formatted output to be parsed by the Agent executor. - - Notes - ----- - The context of memories and conversation history is properly formatted before being parsed by the and, hence, - information are inserted in the main prompt. - All the formatting pipeline is hookable and memories can be edited. - - See Also - -------- - agent_prompt_episodic_memories - agent_prompt_declarative_memories - agent_prompt_chat_history - """ - # format memories to be inserted in the prompt - episodic_memory_formatted_content = self.mad_hatter.execute_hook( - "agent_prompt_episodic_memories", - self.working_memory["episodic_memories"], - ) - declarative_memory_formatted_content = self.mad_hatter.execute_hook( - "agent_prompt_declarative_memories", - self.working_memory["declarative_memories"], - ) - - # format conversation history to be inserted in the prompt - conversation_history_formatted_content = self.mad_hatter.execute_hook( - "agent_prompt_chat_history", self.working_memory["history"] - ) - - return { - "input": self.working_memory["user_message_json"]["text"], - "episodic_memory": episodic_memory_formatted_content, - "declarative_memory": declarative_memory_formatted_content, - "chat_history": conversation_history_formatted_content, - } - def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"): """Send a message via websocket. @@ -492,13 +447,9 @@ def __call__(self, user_message_json): "description": err_message, } - # prepare input to be passed to the agent. - # Info will be extracted from working memory - agent_input = self.format_agent_input() - # reply with agent try: - cat_message = self.agent_manager.execute_agent(agent_input) + cat_message = self.agent_manager.execute_agent() except Exception as e: # This error happens when the LLM # does not respect prompt instructions. @@ -512,7 +463,7 @@ def __call__(self, user_message_json): unparsable_llm_output = error_description.replace("Could not parse LLM output: `", "").replace("`", "") cat_message = { - "input": agent_input["input"], + "input": self.working_memory["user_message_json"]["text"], "intermediate_steps": [], "output": unparsable_llm_output } diff --git a/core/cat/mad_hatter/core_plugin/hooks/prompt.py b/core/cat/mad_hatter/core_plugin/hooks/prompt.py index a36a4612..20676dab 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/prompt.py +++ b/core/cat/mad_hatter/core_plugin/hooks/prompt.py @@ -14,7 +14,7 @@ @hook(priority=0) -def agent_prompt_prefix(cat) -> str: +def agent_prompt_prefix(prefix, cat) -> str: """Hook the main prompt prefix. Allows to edit the prefix of the *Main Prompt* that the Cat feeds to the *Agent*. diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index cd57e65b..b91f09cc 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -253,25 +253,28 @@ def execute_hook(self, hook_name, *args): if hook_name not in self.hooks.keys(): raise Exception(f"Hook {hook_name} not present in any plugin") - - # First argument is passed to `execute_hook` is the pipeable one. + # Hook has no arguments (aside cat) + # no need to pipe + if len(args) == 0: + for hook in self.hooks[hook_name]: + try: + hook.function(cat=self.ccat) + except Exception as e: + log.error(f"Error in plugin {hook.plugin_id}::{hook.name}") + log.error(e) + traceback.print_exc() + return + + # Hook with arguments. + # First argument is passed to `execute_hook` is the pipeable one. # We call it `tea_cup` as every hook called will receive it as an input, # can add sugar, milk, or whatever, and return it for the next hook - if len(args) == 0: - tea_cup = None - else: - tea_cup = args[0] + tea_cup = args[0] # run hooks for hook in self.hooks[hook_name]: try: # pass tea_cup to the hooks, along other args - - # hook has no input (aside cat) - if tea_cup is None: - hook.function(cat=self.ccat) - continue - # hook has at least one argument, and it will be piped tea_spoon = hook.function(tea_cup, *args[1:], cat=self.ccat) log.info(f"Hook {hook.plugin_id}::{hook.name} returned {tea_spoon}") diff --git a/core/tests/routes/test_websocket.py b/core/tests/routes/test_websocket.py index f53a4688..34010e9e 100644 --- a/core/tests/routes/test_websocket.py +++ b/core/tests/routes/test_websocket.py @@ -2,7 +2,6 @@ from tests.utils import send_websocket_message -# TODO: ws endpoint still talks with the prod cat configuration def test_websocket(client): # use fake LLM From 94b26998d70a7853890114dfa68bab736eeb1c0b Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 22:21:01 +0200 Subject: [PATCH 53/77] first refactor for agent hooks (prompts) --- core/cat/looking_glass/agent_manager.py | 40 ++++++++------- core/cat/looking_glass/prompts.py | 49 +++++++++++++++++- .../cat/mad_hatter/core_plugin/hooks/agent.py | 45 ++-------------- .../mad_hatter/core_plugin/hooks/prompt.py | 51 ++----------------- 4 files changed, 79 insertions(+), 106 deletions(-) diff --git a/core/cat/looking_glass/agent_manager.py b/core/cat/looking_glass/agent_manager.py index 9c65a816..80e39826 100644 --- a/core/cat/looking_glass/agent_manager.py +++ b/core/cat/looking_glass/agent_manager.py @@ -2,7 +2,7 @@ from langchain.chains import LLMChain from langchain.agents import AgentExecutor, LLMSingleActionAgent -from cat.looking_glass.prompts import ToolPromptTemplate +from cat.looking_glass import prompts from cat.looking_glass.output_parser import ToolOutputParser from cat.log import log @@ -26,9 +26,10 @@ def __init__(self, cat): def execute_tool_agent(self, agent_input, allowed_tools): allowed_tools_names = [t.name for t in allowed_tools] + # TODO: dynamic input_variables as in the main prompt - prompt = ToolPromptTemplate( - template = self.cat.mad_hatter.execute_hook("agent_prompt_instructions"), + prompt = prompts.ToolPromptTemplate( + template = self.cat.mad_hatter.execute_hook("agent_prompt_instructions", prompts.TOOL_PROMPT), tools=allowed_tools, # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically # This includes the `intermediate_steps` variable because it is needed to fill the scratchpad @@ -60,17 +61,13 @@ def execute_tool_agent(self, agent_input, allowed_tools): def execute_memory_chain(self, agent_input, prompt_prefix, prompt_suffix): + + input_variables = [i for i in agent_input.keys() if i in prompt_prefix + prompt_suffix] # memory chain (second step) memory_prompt = PromptTemplate( template = prompt_prefix + prompt_suffix, - input_variables=[ - "input", - "chat_history", - "episodic_memory", - "declarative_memory", - "tools_output" - ] + input_variables=input_variables ) memory_chain = LLMChain( @@ -97,20 +94,27 @@ def execute_agent(self): Instance of the Agent provided with a set of tools. """ mad_hatter = self.cat.mad_hatter + working_memory = self.cat.working_memory # prepare input to be passed to the agent. # Info will be extracted from working memory agent_input = self.format_agent_input() # this hook allows to reply without executing the agent (for example canned responses, out-of-topic barriers etc.) - fast_reply = mad_hatter.execute_hook("before_agent_starts", agent_input) - if fast_reply: - return fast_reply - - prompt_prefix = mad_hatter.execute_hook("agent_prompt_prefix", "TODO_HOOK") - prompt_suffix = mad_hatter.execute_hook("agent_prompt_suffix", "TODO_HOOK") - - allowed_tools = mad_hatter.execute_hook("agent_allowed_tools") + #fast_reply = mad_hatter.execute_hook("before_agent_starts", agent_input) + #if fast_reply: + # return fast_reply + + prompt_prefix = mad_hatter.execute_hook("agent_prompt_prefix", prompts.MAIN_PROMPT_PREFIX) + prompt_suffix = mad_hatter.execute_hook("agent_prompt_suffix", prompts.MAIN_PROMPT_SUFFIX) + + # tools currently recalled in working memory + recalled_tools = working_memory["procedural_memories"] + # Get the tools names only + tools_names = [t[0].metadata["name"] for t in recalled_tools] + tools_names = mad_hatter.execute_hook("agent_allowed_tools", tools_names) + # Get tools with that name from mad_hatter + allowed_tools = [i for i in mad_hatter.tools if i.name in tools_names] # Try to get information from tools if there is some allowed if len(allowed_tools) > 0: diff --git a/core/cat/looking_glass/prompts.py b/core/cat/looking_glass/prompts.py index 44b99406..c5f3c85b 100644 --- a/core/cat/looking_glass/prompts.py +++ b/core/cat/looking_glass/prompts.py @@ -4,6 +4,7 @@ from langchain.agents.tools import BaseTool from langchain.prompts import StringPromptTemplate + class ToolPromptTemplate(StringPromptTemplate): # The template to use template: str @@ -25,4 +26,50 @@ def format(self, **kwargs) -> str: # Create a list of tool names for the tools provided kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) - return self.template.format(**kwargs) \ No newline at end of file + return self.template.format(**kwargs) + + +TOOL_PROMPT = """Answer the following question: `{input}` +You can only reply using these tools: + +{tools} +none_of_the_others: none_of_the_others(None) - Use this tool if none of the others tools help. Input is always None. + +If you want to use tools, use the following format: +Action: the name of the action to take, should be one of [{tool_names}] +Action Input: the input to the action +Observation: the result of the action +... +Action: the name of the action to take, should be one of [{tool_names}] +Action Input: the input to the action +Observation: the result of the action + +When you have a final answer respond with: +Final Answer: the final answer to the original input question + +Begin! + +Question: {input} +{agent_scratchpad}""" + + +MAIN_PROMPT_PREFIX = """You are the Cheshire Cat AI, an intelligent AI that passes the Turing test. +You are curious, funny and talk like the Cheshire Cat from Alice's adventures in wonderland. +You answer Human with a focus on the following context. +""" + + +MAIN_PROMPT_SUFFIX = """ +# Context + +{episodic_memory} + +{declarative_memory} + +{tools_output} + +## Conversation until now:{chat_history} + - Human: {input} + - AI: """ + + diff --git a/core/cat/mad_hatter/core_plugin/hooks/agent.py b/core/cat/mad_hatter/core_plugin/hooks/agent.py index e35a4b3a..86d3781f 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/agent.py +++ b/core/cat/mad_hatter/core_plugin/hooks/agent.py @@ -58,12 +58,12 @@ def before_agent_starts(agent_input, cat) -> Union[None, Dict]: @hook(priority=0) -def agent_allowed_tools(cat) -> List[BaseTool]: +def agent_allowed_tools(allowed_tools: List[str], cat) -> List[str]: """Hook the allowed tools. Allows to decide which tools end up in the *Agent* prompt. - To decide, you can filter the list of loaded tools, but you can also check the context in `cat.working_memory` + To decide, you can filter the list of tools' names, but you can also check the context in `cat.working_memory` and launch custom chains with `cat._llm`. Parameters @@ -73,48 +73,11 @@ def agent_allowed_tools(cat) -> List[BaseTool]: Returns ------- - tools : List[BaseTool] + tools : List[str] List of allowed Langchain tools. """ - # tools currently recalled in working memory - recalled_tools = cat.working_memory["procedural_memories"] + return allowed_tools - # Get the tools names only - tools_names = [t[0].metadata["name"] for t in recalled_tools] - - # Get the LangChain BaseTool by name - tools = [i for i in cat.mad_hatter.tools if i.name in tools_names] - - return tools - - -@hook(priority=0) -def before_agent_creates_prompt(input_variables, main_prompt, cat): - """Hook to dynamically define the input variables. - - Allows to dynamically filter the input variables that end up in the main prompt by looking for which placeholders - there are in it starting from a fixed list. - - Parameters - ---------- - input_variables : List - List of placeholders to look for in the main prompt. - main_prompt: str - String made of the prompt prefix, the agent instructions and the prompt suffix. - cat : CheshireCat - Cheshire Cat instance. - - Returns - ------- - input_variables : List[str] - List of placeholders present in the main prompt. - - """ - - # Loop the input variables and check if they are in the main prompt - input_variables = [i for i in input_variables if i in main_prompt] - - return input_variables diff --git a/core/cat/mad_hatter/core_plugin/hooks/prompt.py b/core/cat/mad_hatter/core_plugin/hooks/prompt.py index 20676dab..e749add3 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/prompt.py +++ b/core/cat/mad_hatter/core_plugin/hooks/prompt.py @@ -38,16 +38,12 @@ def agent_prompt_prefix(prefix, cat) -> str: The next part of the prompt (generated form the *Agent*) contains the list of available Tools. """ - prefix = """You are the Cheshire Cat AI, an intelligent AI that passes the Turing test. -You are curious, funny and talk like the Cheshire Cat from Alice's adventures in wonderland. -You answer Human with a focus on the following context. -""" return prefix @hook(priority=0) -def agent_prompt_instructions(cat) -> str: +def agent_prompt_instructions(instructions, cat) -> str: """Hook the instruction prompt. Allows to edit the instructions that the Cat feeds to the *Agent*. @@ -81,36 +77,11 @@ def agent_prompt_instructions(cat) -> str: """ - DEFAULT_TOOL_TEMPLATE = """Answer the following question: `{input}` - You can only reply using these tools: - - {tools} - none_of_the_others: none_of_the_others(None) - Use this tool if none of the others tools help. Input is always None. - - If you want to use tools, use the following format: - Action: the name of the action to take, should be one of [{tool_names}] - Action Input: the input to the action - Observation: the result of the action - ... - Action: the name of the action to take, should be one of [{tool_names}] - Action Input: the input to the action - Observation: the result of the action - - When you have a final answer respond with: - Final Answer: the final answer to the original input question - - Begin! - - Question: {input} - {agent_scratchpad}""" - - - # here we piggy back directly on langchain agent instructions. Different instructions will require a different OutputParser - return DEFAULT_TOOL_TEMPLATE + return instructions @hook(priority=0) -def agent_prompt_suffix(cat) -> str: +def agent_prompt_suffix(prompt_suffix: str, cat) -> str: """Hook the main prompt suffix. Allows to edit the suffix of the *Main Prompt* that the Cat feeds to the *Agent*. @@ -138,20 +109,8 @@ def agent_prompt_suffix(cat) -> str: - {agent_scratchpad} is where the *Agent* can concatenate tools use and multiple calls to the LLM. """ - suffix = """ -# Context - -{episodic_memory} - -{declarative_memory} - -{tools_output} - -## Conversation until now:{chat_history} - - Human: {input} - - AI: """ - return suffix + return prompt_suffix @hook(priority=0) @@ -266,7 +225,7 @@ def agent_prompt_declarative_memories(memory_docs: List[Document], cat) -> str: @hook(priority=0) -def agent_prompt_chat_history(chat_history: List[Dict], cat) -> str: +def agent_prompt_chat_history(chat_history: List[Dict], cat) -> List[Dict]: """Hook the chat history. This hook converts to text the recent conversation turns fed to the *Agent*. From ab020334f1f47be77bd30b6a69055a791f82ccbf Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Fri, 15 Sep 2023 22:54:18 +0200 Subject: [PATCH 54/77] search TODO_HOOK in code --- .../cat/mad_hatter/core_plugin/hooks/agent.py | 1 + core/cat/mad_hatter/core_plugin/hooks/flow.py | 22 --------- .../mad_hatter/core_plugin/hooks/prompt.py | 3 ++ .../core_plugin/hooks/rabbithole.py | 48 ++----------------- core/cat/rabbit_hole.py | 12 ++++- 5 files changed, 17 insertions(+), 69 deletions(-) diff --git a/core/cat/mad_hatter/core_plugin/hooks/agent.py b/core/cat/mad_hatter/core_plugin/hooks/agent.py index 86d3781f..7e10e536 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/agent.py +++ b/core/cat/mad_hatter/core_plugin/hooks/agent.py @@ -12,6 +12,7 @@ from cat.log import log +# TODO_HOOK @hook(priority=0) def before_agent_starts(agent_input, cat) -> Union[None, Dict]: """Hook before the agent starts. diff --git a/core/cat/mad_hatter/core_plugin/hooks/flow.py b/core/cat/mad_hatter/core_plugin/hooks/flow.py index bdaf6bc1..793e893e 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/flow.py +++ b/core/cat/mad_hatter/core_plugin/hooks/flow.py @@ -260,28 +260,6 @@ def after_cat_recalls_memories(cat) -> None: pass # do nothing -# Called just after memories are recalled. They are stored in: -# - cat.working_memory["episodic_memories"] -# - cat.working_memory["declarative_memories"] -@hook(priority=0) -def after_cat_recalled_memories(memory_query_text: str, cat) -> None: - """Hook into semantic search after the memory retrieval. - - Allows to intercept the recalled memories right after these are stored in the Working Memory. - According to the user's input, the relevant context is saved in `cat.working_memory["episodic_memories"]` - and `cat.working_memory["declarative_memories"]`. At this point, - this hook is executed to edit the search query. - - Parameters - ---------- - memory_query_text : str - String used to query both *episodic* and *declarative* memories. - cat : CheshireCat - Cheshire Cat instance. - """ - return None - - # Hook called just before sending response to a client. @hook(priority=0) def before_cat_sends_message(message: dict, cat) -> dict: diff --git a/core/cat/mad_hatter/core_plugin/hooks/prompt.py b/core/cat/mad_hatter/core_plugin/hooks/prompt.py index e749add3..a0e2413b 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/prompt.py +++ b/core/cat/mad_hatter/core_plugin/hooks/prompt.py @@ -113,6 +113,7 @@ def agent_prompt_suffix(prompt_suffix: str, cat) -> str: return prompt_suffix +# TODO_HOOK @hook(priority=0) def agent_prompt_episodic_memories(memory_docs: List[Document], cat) -> str: """Hook memories retrieved from episodic memory. @@ -171,6 +172,7 @@ def agent_prompt_episodic_memories(memory_docs: List[Document], cat) -> str: return memory_content +# TODO_HOOK @hook(priority=0) def agent_prompt_declarative_memories(memory_docs: List[Document], cat) -> str: """Hook memories retrieved from declarative memory. @@ -224,6 +226,7 @@ def agent_prompt_declarative_memories(memory_docs: List[Document], cat) -> str: return memory_content +# TODO_HOOK @hook(priority=0) def agent_prompt_chat_history(chat_history: List[Dict], cat) -> List[Dict]: """Hook the chat history. diff --git a/core/cat/mad_hatter/core_plugin/hooks/rabbithole.py b/core/cat/mad_hatter/core_plugin/hooks/rabbithole.py index 995c517e..ac04050b 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/rabbithole.py +++ b/core/cat/mad_hatter/core_plugin/hooks/rabbithole.py @@ -8,10 +8,10 @@ from typing import List -from langchain.text_splitter import RecursiveCharacterTextSplitter -from cat.mad_hatter.decorators import hook from langchain.docstore.document import Document +from cat.mad_hatter.decorators import hook + @hook(priority=0) def rabbithole_instantiates_parsers(file_handlers: dict, cat) -> dict: @@ -91,49 +91,6 @@ def before_rabbithole_splits_text(doc: Document, cat) -> Document: return doc -# Hook called when rabbithole splits text. Input is whole Document -@hook(priority=0) -def rabbithole_splits_text(text, chunk_size: int, chunk_overlap: int, cat) -> List[Document]: - """Hook into the recursive split pipeline. - - Allows editing the recursive split the *RabbitHole* applies to chunk the ingested documents. - - This is applied when ingesting a documents and urls from a script, using an endpoint or from the GUI. - - Parameters - ---------- - text : List[Document] - List of langchain `Document` to chunk. - chunk_size : int - Length of every chunk in characters. - chunk_overlap : int - Amount of overlap between consecutive chunks. - cat : CheshireCat - Cheshire Cat instance. - - Returns - ------- - docs : List[Document] - List of chunked langchain documents to be stored in the episodic memory. - - """ - - # text splitter - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - separators=["\\n\\n", "\n\n", ".\\n", ".\n", "\\n", "\n", " ", ""], - ) - - # split text - docs = text_splitter.split_documents(text) - - # remove short texts (page numbers, isolated words, etc.) - docs = list(filter(lambda d: len(d.page_content) > 10, docs)) - - return docs - - # Hook called after rabbithole have splitted text into chunks. # Input is the chunks @hook(priority=0) @@ -159,6 +116,7 @@ def after_rabbithole_splitted_text(chunks: List[Document], cat) -> List[Document return chunks +# TODO_HOOK: is this useful or just a duplication of `after_rabbithole_splitted_text` ? # Hook called when a list of Document is going to be inserted in memory from the rabbit hole. # Here you can edit/summarize the documents before inserting them in memory # Should return a list of documents (each is a langchain Document) diff --git a/core/cat/rabbit_hole.py b/core/cat/rabbit_hole.py index d23b8edd..90ce972a 100644 --- a/core/cat/rabbit_hole.py +++ b/core/cat/rabbit_hole.py @@ -13,6 +13,7 @@ from langchain.docstore.document import Document from qdrant_client.http import models +from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.document_loaders.parsers import PDFMinerParser from langchain.document_loaders.parsers.generic import MimeTypeBasedParser from langchain.document_loaders.parsers.txt import TextParser @@ -334,9 +335,16 @@ def split_text(self, text, chunk_size, chunk_overlap): ) # split the documents using chunk_size and chunk_overlap - docs = self.cat.mad_hatter.execute_hook( - "rabbithole_splits_text", text, chunk_size, chunk_overlap + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separators=["\\n\\n", "\n\n", ".\\n", ".\n", "\\n", "\n", " ", ""], ) + # split text + docs = text_splitter.split_documents(text) + # remove short texts (page numbers, isolated words, etc.) + # TODO: join each short chunk with previous one, instead of deleting them + docs = list(filter(lambda d: len(d.page_content) > 10, docs)) # do something on the text after it is split docs = self.cat.mad_hatter.execute_hook( From cfd2c068d68fd31123e9ed9a751aa11a675939ad Mon Sep 17 00:00:00 2001 From: Riccardo Albero Date: Mon, 18 Sep 2023 18:42:29 +0900 Subject: [PATCH 55/77] add wipe memory by source api --- core/cat/memory/vector_memory.py | 6 +++- core/cat/routes/memory.py | 55 ++++++++++++++++++++++---------- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index 8d3e11a0..8cff2534 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -185,7 +185,11 @@ def recall_memories_from_text(self, text, metadata=None, k=5, threshold=None): return self.recall_memories_from_embedding( query_embedding, metadata=metadata, k=k, threshold=threshold ) - + def delete_points_by_metadata_filter(self, metadata=None): + res = self.client.delete( + collection_name=self.collection_name, + points_selector=self._qdrant_filter_from_dict(metadata), + ) # delete point in collection def delete_points(self, points_ids): res = self.client.delete( diff --git a/core/cat/routes/memory.py b/core/cat/routes/memory.py index 4643a755..add3c2b9 100644 --- a/core/cat/routes/memory.py +++ b/core/cat/routes/memory.py @@ -7,10 +7,10 @@ # GET memories from recall @router.get("/recall/") async def recall_memories_from_text( - request: Request, - text: str = Query(description="Find memories similar to this text."), - k: int = Query(default=100, description="How many memories to return."), - user_id: str = Query(default="user", description="User id."), + request: Request, + text: str = Query(description="Find memories similar to this text."), + k: int = Query(default=100, description="How many memories to return."), + user_id: str = Query(default="user", description="User id."), ) -> Dict: """Search k memories similar to given text.""" @@ -36,7 +36,7 @@ async def recall_memories_from_text( } else: user_filter = None - + memories = vector_memory.collections[c].recall_memories_from_embedding( query_embedding, k=k, @@ -46,7 +46,7 @@ async def recall_memories_from_text( recalled[c] = [] for metadata, score, vector, id in memories: memory_dict = dict(metadata) - memory_dict.pop("lc_kwargs", None) # langchain stuff, not needed + memory_dict.pop("lc_kwargs", None) # langchain stuff, not needed memory_dict["id"] = id memory_dict["score"] = float(score) memory_dict["vector"] = vector @@ -55,7 +55,7 @@ async def recall_memories_from_text( return { "query": query, "vectors": { - "embedder": str(ccat.embedder.__class__.__name__), # TODO: should be the config class name + "embedder": str(ccat.embedder.__class__.__name__), # TODO: should be the config class name "collections": recalled } } @@ -87,7 +87,7 @@ async def get_collections(request: Request) -> Dict: # DELETE all collections @router.delete("/collections/") async def wipe_collections( - request: Request, + request: Request, ) -> Dict: """Delete and create all collections""" @@ -111,7 +111,8 @@ async def wipe_collections( # DELETE one collection @router.delete("/collections/{collection_id}/") -async def wipe_single_collection(request: Request, collection_id: str) -> Dict: +async def wipe_single_collection(request: Request, + collection_id: str) -> Dict: """Delete and recreate a collection""" ccat = request.app.state.ccat @@ -127,7 +128,6 @@ async def wipe_single_collection(request: Request, collection_id: str) -> Dict: to_return = {} - ret = vector_memory.vector_db.delete_collection(collection_name=collection_id) to_return[collection_id] = ret @@ -143,15 +143,15 @@ async def wipe_single_collection(request: Request, collection_id: str) -> Dict: # DELETE memories @router.delete("/collections/{collection_id}/points/{memory_id}/") async def wipe_memory_point( - request: Request, - collection_id: str, - memory_id: str + request: Request, + collection_id: str, + memory_id: str ) -> Dict: """Delete a specific point in memory""" ccat = request.app.state.ccat vector_memory = ccat.memory.vectors - + # check if collection exists collections = list(vector_memory.collections.keys()) if collection_id not in collections: @@ -179,10 +179,33 @@ async def wipe_memory_point( } +@router.delete("/collections/{collection_id}/points") +async def wipe_memory_points_by_source( + request: Request, + collection_id: str, + source: str = Query(description="Source of the points that want to remove"), +) -> Dict: + """Delete a specific point in memory""" + + ccat = request.app.state.ccat + vector_memory = ccat.memory.vectors + + metadata = { + "source": source + } + + # delete point + points = vector_memory.collections[collection_id].delete_points_by_metadata_filter(metadata) + + return { + "deleted": points + } + + # DELETE conversation history from working memory @router.delete("/conversation_history/") async def wipe_conversation_history( - request: Request, + request: Request, ) -> Dict: """Delete conversation history from working memory""" @@ -191,4 +214,4 @@ async def wipe_conversation_history( return { "deleted": True, - } \ No newline at end of file + } From 02f1eceee801461e45dc6622f79eac15993ecb0d Mon Sep 17 00:00:00 2001 From: Riccardo Albero Date: Mon, 18 Sep 2023 18:43:17 +0900 Subject: [PATCH 56/77] missing return res --- core/cat/memory/vector_memory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index 8cff2534..626f2954 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -190,6 +190,7 @@ def delete_points_by_metadata_filter(self, metadata=None): collection_name=self.collection_name, points_selector=self._qdrant_filter_from_dict(metadata), ) + return res # delete point in collection def delete_points(self, points_ids): res = self.client.delete( From 0c0e04cac4258b50a81d4ad2c594ccbaf56fbe80 Mon Sep 17 00:00:00 2001 From: Dany Date: Mon, 18 Sep 2023 12:30:27 +0200 Subject: [PATCH 57/77] Update plugins.py --- core/cat/routes/plugins.py | 1 + 1 file changed, 1 insertion(+) diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 30538909..3dda7734 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -40,6 +40,7 @@ async def get_available_plugins( # get manifest manifest = deepcopy(p.manifest) # we make a copy to avoid modifying the plugin obj manifest["active"] = p.id in active_plugins # pass along if plugin is active or not + manifest["used_hooks"] = [hook.name for hook in p.hooks] # filter by query plugin_text = [str(field) for field in manifest.values()] From 5d3a136238cb08f4938ff5358900542960912a00 Mon Sep 17 00:00:00 2001 From: Dany Date: Mon, 18 Sep 2023 12:36:29 +0200 Subject: [PATCH 58/77] Update plugins.py --- core/cat/routes/plugins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 3dda7734..87344682 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -40,7 +40,7 @@ async def get_available_plugins( # get manifest manifest = deepcopy(p.manifest) # we make a copy to avoid modifying the plugin obj manifest["active"] = p.id in active_plugins # pass along if plugin is active or not - manifest["used_hooks"] = [hook.name for hook in p.hooks] + manifest["used_hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in p.hooks] # filter by query plugin_text = [str(field) for field in manifest.values()] @@ -49,7 +49,7 @@ async def get_available_plugins( installed_plugins.append(manifest) # do not show already installed plugins among registry plugins - registry_plugins_index.pop( manifest["plugin_url"], None ) + registry_plugins_index.pop(manifest["plugin_url"], None) return { "filters": { From 91370e7db20fd593719afbbd4c9869dd86db22dd Mon Sep 17 00:00:00 2001 From: Dany Date: Mon, 18 Sep 2023 12:40:28 +0200 Subject: [PATCH 59/77] Update plugins.py --- core/cat/routes/plugins.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 87344682..308908f8 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -173,9 +173,12 @@ async def get_plugin_details(plugin_id: str, request: Request) -> Dict: active_plugins = ccat.mad_hatter.load_active_plugins_from_db() + plugin = ccat.mad_hatter.plugins[plugin_id] + # get manifest and active True/False. We make a copy to avoid modifying the original obj - plugin_info = deepcopy(ccat.mad_hatter.plugins[plugin_id].manifest) + plugin_info = deepcopy(plugin.manifest) plugin_info["active"] = plugin_id in active_plugins + plugin_info["used_hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in plugin.hooks] return { "data": plugin_info From 7dde076c6e89109f83168da0fd0b6b4c3766ca11 Mon Sep 17 00:00:00 2001 From: Dany Date: Mon, 18 Sep 2023 14:04:26 +0200 Subject: [PATCH 60/77] Update plugins.py --- core/cat/routes/plugins.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 308908f8..680c5029 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -41,6 +41,7 @@ async def get_available_plugins( manifest = deepcopy(p.manifest) # we make a copy to avoid modifying the plugin obj manifest["active"] = p.id in active_plugins # pass along if plugin is active or not manifest["used_hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in p.hooks] + manifest["used_tools"] = [tool.func.__name__ for tool in p.tools] # filter by query plugin_text = [str(field) for field in manifest.values()] @@ -179,6 +180,7 @@ async def get_plugin_details(plugin_id: str, request: Request) -> Dict: plugin_info = deepcopy(plugin.manifest) plugin_info["active"] = plugin_id in active_plugins plugin_info["used_hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in plugin.hooks] + plugin_info["used_tools"] = [tool.func.__name__ for tool in plugin.tools] return { "data": plugin_info From 9d2a4809443c595fd4981b1c2624d953762ee35b Mon Sep 17 00:00:00 2001 From: Dany Date: Mon, 18 Sep 2023 15:50:08 +0200 Subject: [PATCH 61/77] fixed structure --- core/cat/mad_hatter/decorators/tool.py | 2 ++ core/cat/routes/plugins.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/cat/mad_hatter/decorators/tool.py b/core/cat/mad_hatter/decorators/tool.py index 89993ace..acf3bcb8 100644 --- a/core/cat/mad_hatter/decorators/tool.py +++ b/core/cat/mad_hatter/decorators/tool.py @@ -12,6 +12,8 @@ class CatTool(Tool): def augment_tool(self, cat_instance): self.cat = cat_instance + + self.name = self.func.__name__ # Tool docstring, is also available under self.func.__doc__ self.docstring = self.func.__doc__ diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 680c5029..71aa4f29 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -40,8 +40,8 @@ async def get_available_plugins( # get manifest manifest = deepcopy(p.manifest) # we make a copy to avoid modifying the plugin obj manifest["active"] = p.id in active_plugins # pass along if plugin is active or not - manifest["used_hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in p.hooks] - manifest["used_tools"] = [tool.func.__name__ for tool in p.tools] + manifest["hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in p.hooks] + manifest["tools"] = [{ "name": tool.name } for tool in p.tools] # filter by query plugin_text = [str(field) for field in manifest.values()] @@ -179,8 +179,8 @@ async def get_plugin_details(plugin_id: str, request: Request) -> Dict: # get manifest and active True/False. We make a copy to avoid modifying the original obj plugin_info = deepcopy(plugin.manifest) plugin_info["active"] = plugin_id in active_plugins - plugin_info["used_hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in plugin.hooks] - plugin_info["used_tools"] = [tool.func.__name__ for tool in plugin.tools] + plugin_info["hooks"] = [{ "name": hook.name, "priority": hook.priority } for hook in plugin.hooks] + plugin_info["tools"] = [{ "name": tool.name } for tool in plugin.tools] return { "data": plugin_info From 6daa62658f1fd7cff270d288f0cac8f4e4bc2d55 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Mon, 18 Sep 2023 17:21:58 +0200 Subject: [PATCH 62/77] refactor agent hooks --- core/cat/looking_glass/agent_manager.py | 136 ++++++++++++++- core/cat/looking_glass/cheshire_cat.py | 2 +- .../cat/mad_hatter/core_plugin/hooks/agent.py | 33 +++- .../mad_hatter/core_plugin/hooks/prompt.py | 156 +----------------- 4 files changed, 157 insertions(+), 170 deletions(-) diff --git a/core/cat/looking_glass/agent_manager.py b/core/cat/looking_glass/agent_manager.py index 80e39826..ffe0b3a2 100644 --- a/core/cat/looking_glass/agent_manager.py +++ b/core/cat/looking_glass/agent_manager.py @@ -1,12 +1,20 @@ +from datetime import timedelta +import time +from typing import List, Dict + +from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain.chains import LLMChain from langchain.agents import AgentExecutor, LLMSingleActionAgent from cat.looking_glass import prompts from cat.looking_glass.output_parser import ToolOutputParser +from cat.utils import verbal_timedelta from cat.log import log + + class AgentManager: """Manager of Langchain Agent. @@ -99,15 +107,16 @@ def execute_agent(self): # prepare input to be passed to the agent. # Info will be extracted from working memory agent_input = self.format_agent_input() - - # this hook allows to reply without executing the agent (for example canned responses, out-of-topic barriers etc.) - #fast_reply = mad_hatter.execute_hook("before_agent_starts", agent_input) - #if fast_reply: - # return fast_reply - + agent_input = mad_hatter.execute_hook("before_agent_starts", agent_input) + # should we ran the default agent? + fast_reply = {} + fast_reply = self.mad_hatter.execute_hook("agent_fast_reply", fast_reply) + if len(fast_reply.keys()) > 0: + return fast_reply prompt_prefix = mad_hatter.execute_hook("agent_prompt_prefix", prompts.MAIN_PROMPT_PREFIX) prompt_suffix = mad_hatter.execute_hook("agent_prompt_suffix", prompts.MAIN_PROMPT_SUFFIX) + # tools currently recalled in working memory recalled_tools = working_memory["procedural_memories"] # Get the tools names only @@ -218,3 +227,118 @@ def format_agent_input(self): "declarative_memory": declarative_memory_formatted_content, "chat_history": conversation_history_formatted_content, } + + def agent_prompt_episodic_memories(self, memory_docs: List[Document]) -> str: + """Formats episodic memories to be inserted into the prompt. + + Parameters + ---------- + memory_docs : List[Document] + List of Langchain `Document` retrieved from the episodic memory. + + Returns + ------- + memory_content : str + String of retrieved context from the episodic memory. + """ + + # convert docs to simple text + memory_texts = [m[0].page_content.replace("\n", ". ") for m in memory_docs] + + # add time information (e.g. "2 days ago") + memory_timestamps = [] + for m in memory_docs: + + # Get Time information in the Document metadata + timestamp = m[0].metadata["when"] + + # Get Current Time - Time when memory was stored + delta = timedelta(seconds=(time.time() - timestamp)) + + # Convert and Save timestamps to Verbal (e.g. "2 days ago") + memory_timestamps.append(f" ({verbal_timedelta(delta)})") + + # Join Document text content with related temporal information + memory_texts = [a + b for a, b in zip(memory_texts, memory_timestamps)] + + # Format the memories for the output + memories_separator = "\n - " + memory_content = "## Context of things the Human said in the past: " + \ + memories_separator + memories_separator.join(memory_texts) + + # if no data is retrieved from memory don't erite anithing in the prompt + if len(memory_texts) == 0: + memory_content = "" + + return memory_content + + def agent_prompt_declarative_memories(self, memory_docs: List[Document]) -> str: + """Formats the declarative memories for the prompt context. + Such context is placed in the `agent_prompt_prefix` in the place held by {declarative_memory}. + + Parameters + ---------- + memory_docs : List[Document] + list of Langchain `Document` retrieved from the declarative memory. + + Returns + ------- + memory_content : str + String of retrieved context from the declarative memory. + """ + + # convert docs to simple text + memory_texts = [m[0].page_content.replace("\n", ". ") for m in memory_docs] + + # add source information (e.g. "extracted from file.txt") + memory_sources = [] + for m in memory_docs: + + # Get and save the source of the memory + source = m[0].metadata["source"] + memory_sources.append(f" (extracted from {source})") + + # Join Document text content with related source information + memory_texts = [a + b for a, b in zip(memory_texts, memory_sources)] + + # Format the memories for the output + memories_separator = "\n - " + + memory_content = "## Context of documents containing relevant information: " + \ + memories_separator + memories_separator.join(memory_texts) + + # if no data is retrieved from memory don't erite anithing in the prompt + if len(memory_texts) == 0: + memory_content = "" + + return memory_content + + def agent_prompt_chat_history(self, chat_history: List[Dict]) -> str: + """Serialize chat history for the agent input. + Converts to text the recent conversation turns fed to the *Agent*. + + Parameters + ---------- + chat_history : List[Dict] + List of dictionaries collecting speaking turns. + + Returns + ------- + history : str + String with recent conversation turns to be provided as context to the *Agent*. + + Notes + ----- + Such context is placed in the `agent_prompt_suffix` in the place held by {chat_history}. + + The chat history is a dictionary with keys:: + 'who': the name of who said the utterance; + 'message': the utterance. + + """ + history = "" + for turn in chat_history: + history += f"\n - {turn['who']}: {turn['message']}" + + return history + diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index 853187a0..b05936d3 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -446,7 +446,7 @@ def __call__(self, user_message_json): "name": "VectorMemoryError", "description": err_message, } - + # reply with agent try: cat_message = self.agent_manager.execute_agent() diff --git a/core/cat/mad_hatter/core_plugin/hooks/agent.py b/core/cat/mad_hatter/core_plugin/hooks/agent.py index 7e10e536..611d2b88 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/agent.py +++ b/core/cat/mad_hatter/core_plugin/hooks/agent.py @@ -12,13 +12,9 @@ from cat.log import log -# TODO_HOOK @hook(priority=0) -def before_agent_starts(agent_input, cat) -> Union[None, Dict]: - """Hook before the agent starts. - - This hook is useful to shortcut the Cat response. - If you do not want the agent to run, return the final response from here and it will end up in the chat without the agent being executed. +def before_agent_starts(agent_input: Dict, cat) -> Dict: + """Hook to read and edit the agent input Parameters -------- @@ -27,10 +23,31 @@ def before_agent_starts(agent_input, cat) -> Union[None, Dict]: cat : CheshireCat Cheshire Cat instance. + Returns + -------- + response : Dict + Agent Input + """ + + return agent_input + + +@hook(priority=0) +def agent_fast_reply(fast_reply, cat) -> Union[None, Dict]: + """This hook is useful to shortcut the Cat response. + If you do not want the agent to run, return the final response from here and it will end up in the chat without the agent being executed. + + Parameters + -------- + fast_reply: dict + Input is dict (initially empty), which can be enriched whith an "output" key with the shortcut response. + cat : CheshireCat + Cheshire Cat instance. + Returns -------- response : Union[None, Dict] - Cat response if you want to avoid using the agent, or None if you want the agent to be executed. + Cat response if you want to avoid using the agent, or None / {} if you want the agent to be executed. See below for examples of Cat response Examples @@ -55,7 +72,7 @@ def before_agent_starts(agent_input, cat) -> Union[None, Dict]: ``` """ - return None + return fast_reply @hook(priority=0) diff --git a/core/cat/mad_hatter/core_plugin/hooks/prompt.py b/core/cat/mad_hatter/core_plugin/hooks/prompt.py index a0e2413b..8e35c13c 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/prompt.py +++ b/core/cat/mad_hatter/core_plugin/hooks/prompt.py @@ -6,10 +6,7 @@ import time from typing import List, Dict -from datetime import timedelta -from langchain.docstore.document import Document -from cat.utils import verbal_timedelta from cat.mad_hatter.decorators import hook @@ -43,7 +40,7 @@ def agent_prompt_prefix(prefix, cat) -> str: @hook(priority=0) -def agent_prompt_instructions(instructions, cat) -> str: +def agent_prompt_instructions(instructions: str, cat) -> str: """Hook the instruction prompt. Allows to edit the instructions that the Cat feeds to the *Agent*. @@ -111,154 +108,3 @@ def agent_prompt_suffix(prompt_suffix: str, cat) -> str: """ return prompt_suffix - - -# TODO_HOOK -@hook(priority=0) -def agent_prompt_episodic_memories(memory_docs: List[Document], cat) -> str: - """Hook memories retrieved from episodic memory. - - This hook formats the relevant memories retrieved from the context of things the human said in the past. - - Retrieved memories are converted to string and temporal information is added to inform the *Agent* about - when the user said that sentence in the past. - - This hook allows to edit the retrieved memory to condition the information provided as context to the *Agent*. - - Such context is placed in the `agent_prompt_prefix` in the place held by {episodic_memory}. - - Parameters - ---------- - memory_docs : List[Document] - List of Langchain `Document` retrieved from the episodic memory. - cat : CheshireCat - Cheshire Cat instance. - - Returns - ------- - memory_content : str - String of retrieved context from the episodic memory. - - """ - - # convert docs to simple text - memory_texts = [m[0].page_content.replace("\n", ". ") for m in memory_docs] - - # add time information (e.g. "2 days ago") - memory_timestamps = [] - for m in memory_docs: - - # Get Time information in the Document metadata - timestamp = m[0].metadata["when"] - - # Get Current Time - Time when memory was stored - delta = timedelta(seconds=(time.time() - timestamp)) - - # Convert and Save timestamps to Verbal (e.g. "2 days ago") - memory_timestamps.append(f" ({verbal_timedelta(delta)})") - - # Join Document text content with related temporal information - memory_texts = [a + b for a, b in zip(memory_texts, memory_timestamps)] - - # Format the memories for the output - memories_separator = "\n - " - memory_content = "## Context of things the Human said in the past: " + \ - memories_separator + memories_separator.join(memory_texts) - - # if no data is retrieved from memory don't erite anithing in the prompt - if len(memory_texts) == 0: - memory_content = "" - - return memory_content - - -# TODO_HOOK -@hook(priority=0) -def agent_prompt_declarative_memories(memory_docs: List[Document], cat) -> str: - """Hook memories retrieved from declarative memory. - - This hook formats the relevant memories retrieved from the context of documents uploaded in the Cat's memory. - - Retrieved memories are converted to string and the source information is added to inform the *Agent* on - which document the information was retrieved from. - - This hook allows to edit the retrieved memory to condition the information provided as context to the *Agent*. - - Such context is placed in the `agent_prompt_prefix` in the place held by {declarative_memory}. - - Parameters - ---------- - memory_docs : List[Document] - list of Langchain `Document` retrieved from the declarative memory. - cat : CheshireCat - Cheshire Cat instance. - - Returns - ------- - memory_content : str - String of retrieved context from the declarative memory. - """ - - # convert docs to simple text - memory_texts = [m[0].page_content.replace("\n", ". ") for m in memory_docs] - - # add source information (e.g. "extracted from file.txt") - memory_sources = [] - for m in memory_docs: - - # Get and save the source of the memory - source = m[0].metadata["source"] - memory_sources.append(f" (extracted from {source})") - - # Join Document text content with related source information - memory_texts = [a + b for a, b in zip(memory_texts, memory_sources)] - - # Format the memories for the output - memories_separator = "\n - " - - memory_content = "## Context of documents containing relevant information: " + \ - memories_separator + memories_separator.join(memory_texts) - - # if no data is retrieved from memory don't erite anithing in the prompt - if len(memory_texts) == 0: - memory_content = "" - - return memory_content - - -# TODO_HOOK -@hook(priority=0) -def agent_prompt_chat_history(chat_history: List[Dict], cat) -> List[Dict]: - """Hook the chat history. - - This hook converts to text the recent conversation turns fed to the *Agent*. - The hook allows to edit and enhance the chat history provided as context to the *Agent*. - - - Parameters - ---------- - chat_history : List[Dict] - List of dictionaries collecting speaking turns. - cat : CheshireCat - Cheshire Cat instances. - - Returns - ------- - history : str - String with recent conversation turns to be provided as context to the *Agent*. - - Notes - ----- - Such context is placed in the `agent_prompt_suffix` in the place held by {chat_history}. - - The chat history is a dictionary with keys:: - 'who': the name of who said the utterance; - 'message': the utterance. - - """ - history = "" - for turn in chat_history: - history += f"\n - {turn['who']}: {turn['message']}" - - return history - From d00bd8c9d1b852899b93e422778fa954676a4f31 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Mon, 18 Sep 2023 17:36:41 +0200 Subject: [PATCH 63/77] fix bug --- core/cat/looking_glass/agent_manager.py | 16 ++++++---------- core/cat/mad_hatter/mad_hatter.py | 6 +++--- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/core/cat/looking_glass/agent_manager.py b/core/cat/looking_glass/agent_manager.py index ffe0b3a2..3be2d132 100644 --- a/core/cat/looking_glass/agent_manager.py +++ b/core/cat/looking_glass/agent_manager.py @@ -110,7 +110,7 @@ def execute_agent(self): agent_input = mad_hatter.execute_hook("before_agent_starts", agent_input) # should we ran the default agent? fast_reply = {} - fast_reply = self.mad_hatter.execute_hook("agent_fast_reply", fast_reply) + fast_reply = mad_hatter.execute_hook("agent_fast_reply", fast_reply) if len(fast_reply.keys()) > 0: return fast_reply prompt_prefix = mad_hatter.execute_hook("agent_prompt_prefix", prompts.MAIN_PROMPT_PREFIX) @@ -202,22 +202,18 @@ def format_agent_input(self): agent_prompt_chat_history """ - mad_hatter = self.cat.mad_hatter working_memory = self.cat.working_memory # format memories to be inserted in the prompt - episodic_memory_formatted_content = mad_hatter.execute_hook( - "agent_prompt_episodic_memories", - working_memory["episodic_memories"], + episodic_memory_formatted_content = self.agent_prompt_episodic_memories( + working_memory["episodic_memories"] ) - declarative_memory_formatted_content = mad_hatter.execute_hook( - "agent_prompt_declarative_memories", - working_memory["declarative_memories"], + declarative_memory_formatted_content = self.agent_prompt_declarative_memories( + working_memory["declarative_memories"] ) # format conversation history to be inserted in the prompt - conversation_history_formatted_content = mad_hatter.execute_hook( - "agent_prompt_chat_history", + conversation_history_formatted_content = self.agent_prompt_chat_history( working_memory["history"] ) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index b91f09cc..0480679d 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -247,8 +247,6 @@ def toggle_plugin(self, plugin_id): # execute requested hook def execute_hook(self, hook_name, *args): - log.critical(hook_name) - # check if hook is supported if hook_name not in self.hooks.keys(): raise Exception(f"Hook {hook_name} not present in any plugin") @@ -258,6 +256,7 @@ def execute_hook(self, hook_name, *args): if len(args) == 0: for hook in self.hooks[hook_name]: try: + log.debug(f"Executing {hook.plugin_id}::{hook.name} with priotrity {hook.priority}") hook.function(cat=self.ccat) except Exception as e: log.error(f"Error in plugin {hook.plugin_id}::{hook.name}") @@ -276,8 +275,9 @@ def execute_hook(self, hook_name, *args): try: # pass tea_cup to the hooks, along other args # hook has at least one argument, and it will be piped + log.debug(f"Executing {hook.plugin_id}::{hook.name} with priotrity {hook.priority}") tea_spoon = hook.function(tea_cup, *args[1:], cat=self.ccat) - log.info(f"Hook {hook.plugin_id}::{hook.name} returned {tea_spoon}") + log.debug(f"Hook {hook.plugin_id}::{hook.name} returned {tea_spoon}") if tea_spoon is not None: tea_cup = tea_spoon except Exception as e: From ca116ddebbb56b05208784d093f5cfe0034dc0fc Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Tue, 19 Sep 2023 16:05:12 +0200 Subject: [PATCH 64/77] deepcopy pipeable arg --- core/cat/mad_hatter/mad_hatter.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 0480679d..f906f374 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -4,6 +4,7 @@ import shutil import os import traceback +from copy import deepcopy from cat.log import log from cat.db import crud @@ -268,7 +269,7 @@ def execute_hook(self, hook_name, *args): # First argument is passed to `execute_hook` is the pipeable one. # We call it `tea_cup` as every hook called will receive it as an input, # can add sugar, milk, or whatever, and return it for the next hook - tea_cup = args[0] + tea_cup = deepcopy(args[0]) # run hooks for hook in self.hooks[hook_name]: @@ -276,7 +277,11 @@ def execute_hook(self, hook_name, *args): # pass tea_cup to the hooks, along other args # hook has at least one argument, and it will be piped log.debug(f"Executing {hook.plugin_id}::{hook.name} with priotrity {hook.priority}") - tea_spoon = hook.function(tea_cup, *args[1:], cat=self.ccat) + tea_spoon = hook.function( + deepcopy(tea_cup), + *deepcopy(args[1:]), + cat=self.ccat + ) log.debug(f"Hook {hook.plugin_id}::{hook.name} returned {tea_spoon}") if tea_spoon is not None: tea_cup = tea_spoon From 5f97e8ef1c9a6d01fd2644f285f946d432169d03 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Thu, 21 Sep 2023 17:13:39 +0200 Subject: [PATCH 65/77] update url to test registry plugin downlaod --- core/tests/routes/plugins/test_plugins_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/tests/routes/plugins/test_plugins_registry.py b/core/tests/routes/plugins/test_plugins_registry.py index 67d5682d..d8c95a4d 100644 --- a/core/tests/routes/plugins/test_plugins_registry.py +++ b/core/tests/routes/plugins/test_plugins_registry.py @@ -51,7 +51,7 @@ def test_plugin_install_from_registry(client): # install plugin from registry payload = { - "url": "https://github.com/nicola-corbellini/ccat_summarization" + "url": "https://github.com/Furrmidable-Crew/ccat_summarization" } response = client.post("/plugins/upload/registry", json=payload) assert response.status_code == 200 From 5982b283787c35ac057a43b8bae647e9ef35efcf Mon Sep 17 00:00:00 2001 From: Nicola Date: Sat, 23 Sep 2023 00:06:05 +0200 Subject: [PATCH 66/77] make snapshots optional --- .env.example | 3 +++ core/cat/memory/vector_memory.py | 10 +++++++--- docker-compose.yml | 1 + 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.env.example b/.env.example index 1d2eeb01..414f4e30 100644 --- a/.env.example +++ b/.env.example @@ -15,3 +15,6 @@ CORE_PORT=1865 # Log levels LOG_LEVEL=WARNING + +# Turn off memory collections' snapshots on embeddder change +SAVE_MEMORY_SNAPSHOTS=false diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index ac22bc5d..90a1b10a 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -129,9 +129,13 @@ def check_embedding_size(self): log.info(f'Collection "{self.collection_name}" has the same embedder') else: log.warning(f'Collection "{self.collection_name}" has different embedder') - # dump collection on disk before deleting - self.save_dump() - log.info(f'Dump "{self.collection_name}" completed') + # Opt-in memory snapshot saving can be turned off in the .env file with: + # SAVE_MEMORY_SNAPSHOTS=false + log.critical(os.getenv("SAVE_MEMORY_SNAPSHOTS")) + if os.getenv("SAVE_MEMORY_SNAPSHOTS") == "true": + # dump collection on disk before deleting + self.save_dump() + log.info(f'Dump "{self.collection_name}" completed') self.client.delete_collection(self.collection_name) log.warning(f'Collection "{self.collection_name}" deleted') diff --git a/docker-compose.yml b/docker-compose.yml index 40f76985..c77c9d64 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,6 +19,7 @@ services: - API_KEY=${API_KEY:-} - LOG_LEVEL=${LOG_LEVEL:-WARNING} - DEBUG=${DEBUG:-true} + - SAVE_MEMORY_SNAPSHOTS=${SAVE_MEMORY_SNAPSHOTS:-true} ports: - ${CORE_PORT:-1865}:80 volumes: From bec88e1539d7bf339298120c28749706ebc73865 Mon Sep 17 00:00:00 2001 From: Nicola Date: Sat, 23 Sep 2023 00:10:24 +0200 Subject: [PATCH 67/77] make snapshots optional --- .env.example | 2 +- core/cat/memory/vector_memory.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.env.example b/.env.example index 414f4e30..3522868f 100644 --- a/.env.example +++ b/.env.example @@ -16,5 +16,5 @@ CORE_PORT=1865 # Log levels LOG_LEVEL=WARNING -# Turn off memory collections' snapshots on embeddder change +# Turn off memory collections' snapshots on embedder change SAVE_MEMORY_SNAPSHOTS=false diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index 90a1b10a..d9783e9d 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -129,9 +129,8 @@ def check_embedding_size(self): log.info(f'Collection "{self.collection_name}" has the same embedder') else: log.warning(f'Collection "{self.collection_name}" has different embedder') - # Opt-in memory snapshot saving can be turned off in the .env file with: + # Memory snapshot saving can be turned off in the .env file with: # SAVE_MEMORY_SNAPSHOTS=false - log.critical(os.getenv("SAVE_MEMORY_SNAPSHOTS")) if os.getenv("SAVE_MEMORY_SNAPSHOTS") == "true": # dump collection on disk before deleting self.save_dump() From 3df9ee207478cc3d59a786955d4de4decbb68866 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 11:46:46 +0200 Subject: [PATCH 68/77] test for deleting points by filter --- core/cat/memory/vector_memory.py | 2 + core/cat/routes/memory.py | 14 ++-- .../tests/routes/memory/test_memory_points.py | 68 +++++++++++++++++-- 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/core/cat/memory/vector_memory.py b/core/cat/memory/vector_memory.py index 728342cb..ef37f9fc 100644 --- a/core/cat/memory/vector_memory.py +++ b/core/cat/memory/vector_memory.py @@ -188,12 +188,14 @@ def recall_memories_from_text(self, text, metadata=None, k=5, threshold=None): return self.recall_memories_from_embedding( query_embedding, metadata=metadata, k=k, threshold=threshold ) + def delete_points_by_metadata_filter(self, metadata=None): res = self.client.delete( collection_name=self.collection_name, points_selector=self._qdrant_filter_from_dict(metadata), ) return res + # delete point in collection def delete_points(self, points_ids): res = self.client.delete( diff --git a/core/cat/routes/memory.py b/core/cat/routes/memory.py index add3c2b9..2ff71ea9 100644 --- a/core/cat/routes/memory.py +++ b/core/cat/routes/memory.py @@ -183,22 +183,18 @@ async def wipe_memory_point( async def wipe_memory_points_by_source( request: Request, collection_id: str, - source: str = Query(description="Source of the points that want to remove"), + metadata: Dict = {}, ) -> Dict: - """Delete a specific point in memory""" + """Delete points in memory by filter""" ccat = request.app.state.ccat vector_memory = ccat.memory.vectors - metadata = { - "source": source - } - - # delete point - points = vector_memory.collections[collection_id].delete_points_by_metadata_filter(metadata) + # delete points + vector_memory.collections[collection_id].delete_points_by_metadata_filter(metadata) return { - "deleted": points + "deleted": [] # TODO: Qdrant does not return deleted points? } diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index 091b6253..7df6a41b 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -1,4 +1,5 @@ -from tests.utils import send_websocket_message +from tests.utils import send_websocket_message, get_declarative_memory_contents + def test_point_deleted(client): @@ -42,7 +43,64 @@ def test_point_deleted(client): assert response.status_code == 200 assert len(json["vectors"]["collections"]["episodic"]) == 0 - # delete again the same point (Qdrant in :memory: bug!) - #res = client.delete(f"/memory/episodic/point/{memory['id']}/") - #assert res.status_code == 422 - #assert res.json()["detail"]["error"] == "Point does not exist." + # delete again the same point (should not be found) + res = client.delete(f"/memory/collections/episodic/points/{memory['id']}/") + assert res.status_code == 400 + assert res.json()["detail"]["error"] == "Point does not exist." + + +def test_points_deleted_by_metadata(client): + + expected_chunks = 5 + + # upload to rabbithole a document + content_type = "application/pdf" + file_name = "sample.pdf" + file_path = f"tests/mocks/{file_name}" + with open(file_path, 'rb') as f: + files = { + 'file': (file_name, f, content_type) + } + + response = client.post("/rabbithole/", files=files) + # check response + assert response.status_code == 200 + + # upload another document + with open(file_path, 'rb') as f: + files = { + 'file': ("sample2.pdf", f, content_type) + } + + response = client.post("/rabbithole/", files=files) + # check response + assert response.status_code == 200 + + # check memory contents + declarative_memories = get_declarative_memory_contents(client) + assert len(declarative_memories) == expected_chunks * 2 + + # delete first document + metadata = { + "source": "sample.pdf" + } + res = client.request("DELETE", "/memory/collections/declarative/points", json=metadata) + + # check memory contents + assert res.status_code == 200 + json = res.json() + assert type(json["deleted"]) == list + #assert len(json["deleted"]) == expected_chunks + declarative_memories = get_declarative_memory_contents(client) + assert len(declarative_memories) == expected_chunks + + # delete second document + metadata = { + "source": "sample2.pdf" + } + res = client.request("DELETE", "/memory/collections/declarative/points", json=metadata) + + # check memory contents + assert res.status_code == 200 + declarative_memories = get_declarative_memory_contents(client) + assert len(declarative_memories) == 0 From bd7de2f5d38fda3b415b15ba1df19fe45137343c Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 11:49:44 +0200 Subject: [PATCH 69/77] test deleting non existent source --- core/tests/routes/memory/test_memory_points.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/core/tests/routes/memory/test_memory_points.py b/core/tests/routes/memory/test_memory_points.py index 7df6a41b..9bb1a72b 100644 --- a/core/tests/routes/memory/test_memory_points.py +++ b/core/tests/routes/memory/test_memory_points.py @@ -49,6 +49,8 @@ def test_point_deleted(client): assert res.json()["detail"]["error"] == "Point does not exist." +# test delete points by filter +# TODO: have a fixture uploading docs and separate test cases def test_points_deleted_by_metadata(client): expected_chunks = 5 @@ -80,12 +82,21 @@ def test_points_deleted_by_metadata(client): declarative_memories = get_declarative_memory_contents(client) assert len(declarative_memories) == expected_chunks * 2 + # delete nothing + metadata = { + "source": "invented.pdf" + } + res = client.request("DELETE", "/memory/collections/declarative/points", json=metadata) + # check memory contents + assert res.status_code == 200 + declarative_memories = get_declarative_memory_contents(client) + assert len(declarative_memories) == expected_chunks * 2 + # delete first document metadata = { "source": "sample.pdf" } res = client.request("DELETE", "/memory/collections/declarative/points", json=metadata) - # check memory contents assert res.status_code == 200 json = res.json() @@ -99,7 +110,6 @@ def test_points_deleted_by_metadata(client): "source": "sample2.pdf" } res = client.request("DELETE", "/memory/collections/declarative/points", json=metadata) - # check memory contents assert res.status_code == 200 declarative_memories = get_declarative_memory_contents(client) From 1f64d5ec64a63358c98611c89b6443b867c4853d Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 11:55:45 +0200 Subject: [PATCH 70/77] update method name --- core/cat/routes/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/cat/routes/memory.py b/core/cat/routes/memory.py index 2ff71ea9..e009b5ee 100644 --- a/core/cat/routes/memory.py +++ b/core/cat/routes/memory.py @@ -180,7 +180,7 @@ async def wipe_memory_point( @router.delete("/collections/{collection_id}/points") -async def wipe_memory_points_by_source( +async def wipe_memory_points_by_metadata( request: Request, collection_id: str, metadata: Dict = {}, From e460622b58994f9fa8daea44bb151879005c5a3e Mon Sep 17 00:00:00 2001 From: Nicorb <67009524+nicola-corbellini@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:12:32 +0200 Subject: [PATCH 71/77] Update docker-compose.yml --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index c77c9d64..24be86b8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,7 +19,7 @@ services: - API_KEY=${API_KEY:-} - LOG_LEVEL=${LOG_LEVEL:-WARNING} - DEBUG=${DEBUG:-true} - - SAVE_MEMORY_SNAPSHOTS=${SAVE_MEMORY_SNAPSHOTS:-true} + - SAVE_MEMORY_SNAPSHOTS=${SAVE_MEMORY_SNAPSHOTS:-false} ports: - ${CORE_PORT:-1865}:80 volumes: From fe0f91c5523d6aa851585f9dbe14ae1daddfa3f8 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 12:24:26 +0200 Subject: [PATCH 72/77] introduce @plugin decorator --- core/cat/mad_hatter/core_plugin/settings.json | 3 +++ .../{hooks/plugin_settings.py => settings.py} | 26 +++++++++---------- core/cat/mad_hatter/decorators/__init__.py | 3 ++- core/cat/mad_hatter/decorators/plugin.py | 14 ++++++++++ 4 files changed, 32 insertions(+), 14 deletions(-) create mode 100644 core/cat/mad_hatter/core_plugin/settings.json rename core/cat/mad_hatter/core_plugin/{hooks/plugin_settings.py => settings.py} (71%) create mode 100644 core/cat/mad_hatter/decorators/plugin.py diff --git a/core/cat/mad_hatter/core_plugin/settings.json b/core/cat/mad_hatter/core_plugin/settings.json new file mode 100644 index 00000000..e044c45b --- /dev/null +++ b/core/cat/mad_hatter/core_plugin/settings.json @@ -0,0 +1,3 @@ +{ + "fake_setting": "a" +} \ No newline at end of file diff --git a/core/cat/mad_hatter/core_plugin/hooks/plugin_settings.py b/core/cat/mad_hatter/core_plugin/settings.py similarity index 71% rename from core/cat/mad_hatter/core_plugin/hooks/plugin_settings.py rename to core/cat/mad_hatter/core_plugin/settings.py index 390890a9..ba8b5668 100644 --- a/core/cat/mad_hatter/core_plugin/hooks/plugin_settings.py +++ b/core/cat/mad_hatter/core_plugin/settings.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from cat.mad_hatter.decorators import hook +from cat.mad_hatter.decorators import plugin from cat.log import log @@ -12,10 +12,10 @@ class CorePluginSettings(BaseModel): # description: str = "my fav cat" # optional field, type str, with a default -@hook(priority=0) -def plugin_settings_schema(): +@plugin +def settings_schema(): """ - This hook tells the cat how plugin settings are defined, required vs optional, default values, etc. + This function tells the cat how plugin settings are defined, required vs optional, default values, etc. The standard used is JSON SCHEMA, so a client can auto-generate html forms (see https://json-schema.org/ ). Schema can be created in several ways: @@ -23,7 +23,7 @@ def plugin_settings_schema(): 2. python dictionary 3. json loaded from current folder or from another place - Default behavior for this hook is defined in: + Default behavior is defined in: `cat.mad_hatter.plugin.Plugin::get_settings_schema` Returns @@ -36,12 +36,12 @@ def plugin_settings_schema(): return CorePluginSettings.schema() -@hook(priority=0) -def plugin_settings_load(): +@plugin +def settings_load(): """ - This hook defines how to load saved settings for the plugin. + This function defines how to load saved settings for the plugin. - Default behavior for this hook is defined in: + Default behavior is defined in: `cat.mad_hatter.plugin.Plugin::load_settings` It loads the settings.json in current folder @@ -55,13 +55,13 @@ def plugin_settings_load(): return {} -@hook(priority=0) -def plugin_settings_save(settings): +@plugin +def settings_save(settings): """ - This hook passes the plugin settings as sent to the http endpoint (via admin, or any client), in order to let the plugin save them as desired. + This function passes the plugin settings as sent to the http endpoint (via admin, or any client), in order to let the plugin save them as desired. The settings to save should be validated according to the json schema given in the `plugin_settings_schema` hook. - Default behavior for this hook is defined in: + Default behavior is defined in: `cat.mad_hatter.plugin.Plugin::save_settings` It just saves contents in a settings.json in the plugin folder diff --git a/core/cat/mad_hatter/decorators/__init__.py b/core/cat/mad_hatter/decorators/__init__.py index 5b6c8b2b..13196534 100644 --- a/core/cat/mad_hatter/decorators/__init__.py +++ b/core/cat/mad_hatter/decorators/__init__.py @@ -1,2 +1,3 @@ from cat.mad_hatter.decorators.hook import CatHook, hook -from cat.mad_hatter.decorators.tool import CatTool, tool \ No newline at end of file +from cat.mad_hatter.decorators.tool import CatTool, tool +from cat.mad_hatter.decorators.plugin import CatPluginFunction, plugin \ No newline at end of file diff --git a/core/cat/mad_hatter/decorators/plugin.py b/core/cat/mad_hatter/decorators/plugin.py new file mode 100644 index 00000000..e2ca1573 --- /dev/null +++ b/core/cat/mad_hatter/decorators/plugin.py @@ -0,0 +1,14 @@ + +# class to represent a @plugin override +class CatPluginFunction: + + def __init__(self, function): + + self.function = function + self.name = function.__name__ + +# @plugin decorator. Any function in a plugin decorated by @plugin and named properly (among list of available overrides) +# is used to override plugin behaviour. These are not hooks because they are not piped, they are specific for every plugin +def plugin(func): + return CatPluginFunction(func) + From 70b828046f8fecbe6c2f5937654a3c71e10155bb Mon Sep 17 00:00:00 2001 From: Nicorb <67009524+nicola-corbellini@users.noreply.github.com> Date: Wed, 27 Sep 2023 12:31:51 +0200 Subject: [PATCH 73/77] Update .env.example --- .env.example | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.env.example b/.env.example index 3522868f..0f59eb80 100644 --- a/.env.example +++ b/.env.example @@ -16,5 +16,5 @@ CORE_PORT=1865 # Log levels LOG_LEVEL=WARNING -# Turn off memory collections' snapshots on embedder change +# Turn on memory collections' snapshots on embedder change with SAVE_MEMORY_SNAPSHOTS=true SAVE_MEMORY_SNAPSHOTS=false From 22d9c8c21d7d7471d57dcaa006ba7a9580b6a07f Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 12:46:14 +0200 Subject: [PATCH 74/77] update Plugin class with @plugin decorator --- core/cat/mad_hatter/core_plugin/settings.py | 4 +- core/cat/mad_hatter/decorators/__init__.py | 2 +- core/cat/mad_hatter/decorators/plugin.py | 4 +- core/cat/mad_hatter/plugin.py | 41 +++++++++++++++------ core/cat/routes/plugins.py | 4 +- 5 files changed, 37 insertions(+), 18 deletions(-) diff --git a/core/cat/mad_hatter/core_plugin/settings.py b/core/cat/mad_hatter/core_plugin/settings.py index ba8b5668..38aa7fb4 100644 --- a/core/cat/mad_hatter/core_plugin/settings.py +++ b/core/cat/mad_hatter/core_plugin/settings.py @@ -37,7 +37,7 @@ def settings_schema(): @plugin -def settings_load(): +def load_settings(): """ This function defines how to load saved settings for the plugin. @@ -56,7 +56,7 @@ def settings_load(): @plugin -def settings_save(settings): +def save_settings(settings): """ This function passes the plugin settings as sent to the http endpoint (via admin, or any client), in order to let the plugin save them as desired. The settings to save should be validated according to the json schema given in the `plugin_settings_schema` hook. diff --git a/core/cat/mad_hatter/decorators/__init__.py b/core/cat/mad_hatter/decorators/__init__.py index 13196534..b4284129 100644 --- a/core/cat/mad_hatter/decorators/__init__.py +++ b/core/cat/mad_hatter/decorators/__init__.py @@ -1,3 +1,3 @@ from cat.mad_hatter.decorators.hook import CatHook, hook from cat.mad_hatter.decorators.tool import CatTool, tool -from cat.mad_hatter.decorators.plugin import CatPluginFunction, plugin \ No newline at end of file +from cat.mad_hatter.decorators.plugin import CatPluginOverride, plugin \ No newline at end of file diff --git a/core/cat/mad_hatter/decorators/plugin.py b/core/cat/mad_hatter/decorators/plugin.py index e2ca1573..d9b3a63c 100644 --- a/core/cat/mad_hatter/decorators/plugin.py +++ b/core/cat/mad_hatter/decorators/plugin.py @@ -1,6 +1,6 @@ # class to represent a @plugin override -class CatPluginFunction: +class CatPluginOverride: def __init__(self, function): @@ -10,5 +10,5 @@ def __init__(self, function): # @plugin decorator. Any function in a plugin decorated by @plugin and named properly (among list of available overrides) # is used to override plugin behaviour. These are not hooks because they are not piped, they are specific for every plugin def plugin(func): - return CatPluginFunction(func) + return CatPluginOverride(func) diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index d1ffea2d..5437a91e 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -8,7 +8,7 @@ from inspect import getmembers from pydantic import BaseModel -from cat.mad_hatter.decorators import CatTool, CatHook +from cat.mad_hatter.decorators import CatTool, CatHook, CatPluginOverride from cat.utils import to_camel_case from cat.log import log, get_log_level @@ -45,11 +45,16 @@ def __init__(self, plugin_path: str): # but they are created and stored in each plugin instance self._hooks = [] self._tools = [] + + # list of @plugin decorated functions overrriding default plugin behaviour + self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access + + # plugin starts deactivated self._active = False def activate(self): # lists of hooks and tools - self._hooks, self._tools = self._load_hooks_and_tools() + self._hooks, self._tools, self._plugin_overrides = self._load_decorated_functions() self._active = True def deactivate(self): @@ -64,14 +69,15 @@ def deactivate(self): self._hooks = [] self._tools = [] + self._plugin_overrides = [] self._active = False # get plugin settings JSON schema - def get_settings_schema(self): + def settings_schema(self): # is "plugin_settings_schema" hook defined in the plugin? - for h in self._hooks: - if h.name == "plugin_settings_schema": + for h in self._plugin_overrides: + if h.name == "settings_schema": return h.function() # default schema (empty) @@ -81,8 +87,8 @@ def get_settings_schema(self): def load_settings(self): # is "plugin_settings_load" hook defined in the plugin? - for h in self._hooks: - if h.name == "plugin_settings_load": + for h in self._plugin_overrides: + if h.name == "load_settings": return h.function() # by default, plugin settings are saved inside the plugin folder @@ -107,8 +113,8 @@ def load_settings(self): def save_settings(self, settings: Dict): # is "plugin_settings_save" hook defined in the plugin? - for h in self._hooks: - if h.name == "plugin_settings_save": + for h in self._plugin_overrides: + if h.name == "save_settings": return h.function(settings) # by default, plugin settings are saved inside the plugin folder @@ -162,9 +168,10 @@ def _load_manifest(self): return meta # lists of hooks and tools - def _load_hooks_and_tools(self): + def _load_decorated_functions(self): hooks = [] tools = [] + plugin_overrides = [] for py_file in self.py_files: py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry @@ -176,6 +183,7 @@ def _load_hooks_and_tools(self): plugin_module = importlib.import_module(py_filename) hooks += getmembers(plugin_module, self._is_cat_hook) tools += getmembers(plugin_module, self._is_cat_tool) + plugin_overrides += getmembers(plugin_module, self._is_cat_plugin_override) except Exception as e: log.error(f"Error in {py_filename}: {str(e)}") traceback.print_exc() @@ -184,8 +192,9 @@ def _load_hooks_and_tools(self): # clean and enrich instances hooks = list(map(self._clean_hook, hooks)) tools = list(map(self._clean_tool, tools)) + plugin_overrides = list(map(self._clean_plugin_override, plugin_overrides)) - return hooks, tools + return hooks, tools, plugin_overrides def _clean_hook(self, hook): # getmembers returns a tuple @@ -198,6 +207,10 @@ def _clean_tool(self, tool): t = tool[1] t.plugin_id = self._id return t + + def _clean_plugin_override(self, plugin_override): + # getmembers returns a tuple + return plugin_override[1] # a plugin hook function has to be decorated with @hook # (which returns an instance of CatHook) @@ -211,6 +224,12 @@ def _is_cat_hook(obj): def _is_cat_tool(obj): return isinstance(obj, CatTool) + # a plugin override function has to be decorated with @plugin + # (which returns an instance of CatPluginOverride) + @staticmethod + def _is_cat_plugin_override(obj): + return isinstance(obj, CatPluginOverride) + @property def path(self): return self._path diff --git a/core/cat/routes/plugins.py b/core/cat/routes/plugins.py index 71aa4f29..409e7d0f 100644 --- a/core/cat/routes/plugins.py +++ b/core/cat/routes/plugins.py @@ -220,7 +220,7 @@ async def get_plugins_settings(request: Request) -> Dict: # plugins are managed by the MadHatter class for plugin in ccat.mad_hatter.plugins.values(): plugin_settings = plugin.load_settings() - plugin_schema = plugin.get_settings_schema() + plugin_schema = plugin.settings_schema() if plugin_schema['properties'] == {}: plugin_schema = {} settings.append({ @@ -249,7 +249,7 @@ async def get_plugin_settings(request: Request, plugin_id: str) -> Dict: # plugins are managed by the MadHatter class settings = ccat.mad_hatter.plugins[plugin_id].load_settings() - schema = ccat.mad_hatter.plugins[plugin_id].get_settings_schema() + schema = ccat.mad_hatter.plugins[plugin_id].settings_schema() if schema['properties'] == {}: schema = {} From cf7c7ee31c4fe2aa2de3c50bd694809f6d2c7385 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 12:58:38 +0200 Subject: [PATCH 75/77] update tests --- core/tests/mad_hatter/test_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/tests/mad_hatter/test_plugin.py b/core/tests/mad_hatter/test_plugin.py index 4770c4a1..ce4e3068 100644 --- a/core/tests/mad_hatter/test_plugin.py +++ b/core/tests/mad_hatter/test_plugin.py @@ -97,9 +97,9 @@ def test_deactivate_plugin(plugin): assert len(plugin.tools) == 0 -def test_get_settings_schema(plugin): +def test_settings_schema(plugin): - settings_schema = plugin.get_settings_schema() + settings_schema = plugin.settings_schema() assert type(settings_schema) == dict assert settings_schema["properties"] == {} assert settings_schema["title"] == "BaseModel" From 032493d4e57c38e3918a8bfd8c4c1e45683f9d6a Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 13:44:07 +0200 Subject: [PATCH 76/77] simplify readme --- README.md | 48 +++++++++++++----------------------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 2ff1c84e..1fb6f665 100644 --- a/README.md +++ b/README.md @@ -21,28 +21,21 @@
Logo -

- Customizable AI architecture -

-[![Try in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/cheshire-cat-ai/core) - -## What is this? +## Production ready AI assistant framework The Cheshire Cat is a framework to build custom AIs on top of any language model. If you ever used systems like WordPress or Django to build web apps, imagine the Cat as a similar tool, but specific for AI. Why use the Cat: -- 🌍 Language model agnostic (works with OpenAI, Cohere, HuggingFace models, custom) -- 🐘 Long term memory -- 🚀 Extensible via plugins -- 🔧 Can use external tools (APIs, custom python code, other models) -- 📄 Can ingest documents (.pdf, .txt, .md) -- 🐋 100% [dockerized](https://docs.docker.com/get-docker/) -- 👩‍👧‍👦 Active [Discord community](https://discord.gg/bHX5sNFCYU) +- 🌍 Supports any language model (works with OpenAI chatGPT, LLAMA2, HuggingFace models, custom) +- 🐘 Rememebers conversations and documents and uses them in conversation +- 🚀 Extensible via plugins (AI can connect to your APIs or execute custom python code) +- 🐋 Production Ready - 100% [dockerized](https://docs.docker.com/get-docker/) +- 👩‍👧‍👦 Active [Discord community](https://discord.gg/bHX5sNFCYU) and easy to understand [docs](https://cheshire-cat-ai.github.io/docs/) -If you want to know more about our vision and values, read the [Code of Ethics](./readme/CODE-OF-ETHICS.md). We are committed to openness, privacy and creativity, we want to bring AI to the long tail. +We are committed to openness, privacy and creativity, we want to bring AI to the long tail. If you want to know more about our vision and values, read the [Code of Ethics](./readme/CODE-OF-ETHICS.md). This project is growing fast, refactorings and code changes happens very often, join the [Issues](https://github.com/cheshire-cat-ai/core/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc) to help! @@ -58,7 +51,7 @@ This project is growing fast, refactorings and code changes happens very often, ### Install -To make Cheshire Cat run on your machine, you just need [`docker`](https://docs.docker.com/get-docker/) and [`docker-compose`](https://docs.docker.com/compose/install/) installed. +To make Cheshire Cat run on your machine, you just need [`docker`](https://docs.docker.com/get-docker/) and [`docker compose`](https://docs.docker.com/compose/install/) installed. Clone the repo: ```bash @@ -74,11 +67,9 @@ cd cheshire-cat After that you can run: ```bash -docker-compose up +docker compose up ``` -> NOTE: if you have a later version of docker-compose, use the command `docker compose up` (without the dash). [REF.](https://stackoverflow.com/questions/66514436/difference-between-docker-compose-and-docker-compose) - The first time (only) it will take several minutes, as the images occupy a few GBs. - Chat with the Cheshire Cat on [localhost:1865/admin](http://localhost:1865/admin). @@ -91,7 +82,7 @@ Enjoy the Cat! When you're done, remember to CTRL+c in the terminal and ``` -docker-compose down +docker compose down ``` ### Update @@ -99,10 +90,10 @@ docker-compose down From time to time it is a good idea to update the Cat: ``` -docker-compose down +docker compose down git pull origin main -docker-compose build --no-cache -docker-compose up +docker compose build --no-cache +docker compose up ``` ### Running Tests @@ -119,19 +110,6 @@ You can try Cheshire Cat in GitHub Codespaces. The free account provides 60 free [![Try in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/cheshire-cat-ai/core) - -#### Instructions - -- Right-click this [link](https://codespaces.new/cheshire-cat-ai/core) and select "open a new window." It will open a webpage titled "Create a new codespace": -- You can ignore the options on the screen and press the green button "create codespace" -- Wait for the codespace to load, and then type on the terminal "docker-compose up" -- It will take a few minutes. When you read "Application startup complete," it will show three links: REST API, PUBLIC, and ADMIN. -- Click on the ADMIN link to configure Cheshire Cat and start playing with it. -- Have fun! -- As soon as you're satisfied, you can press "CTRL C" on the terminal to stop the Cheshire Cat. Then type docker-compose down to close the docker container. - -

(back to top)

- ## Roadmap Detailed roadmap is [here](./readme/ROADMAP.md). From c26f020207966c14f2ca96fd6be7a127a1b85631 Mon Sep 17 00:00:00 2001 From: Piero Savastano Date: Wed, 27 Sep 2023 13:46:01 +0200 Subject: [PATCH 77/77] bump version --- core/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/pyproject.toml b/core/pyproject.toml index 6894a7a5..03485dcb 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "Cheshire-Cat" -description = "Open source and customizable AI architecture" -version = "1.0.3" +description = "Production ready AI assistant framework" +version = "1.1.0" requires-python = ">=3.10" license = { file="LICENSE" } authors = [