class CustomOpenAIEmbeddings(Embeddings):
    """Use LLAMA2 as embedder by calling a self-hosted llama-cpp-python instance.

    Requests are sent to the ``v1/embeddings`` endpoint below the configured
    base URL. The response JSON is expected to carry a ``data`` list whose
    items each hold an ``embedding`` vector.
    """

    def __init__(self, url: str):
        """Derive the embeddings endpoint from the server base URL.

        Args:
            url: Base URL of the server, with or without a trailing slash.
        """
        # Plain string joining instead of os.path.join: the latter uses the
        # OS path separator (a backslash on Windows), which is wrong for URLs.
        self.url = f"{url.rstrip('/')}/v1/embeddings"

    def _request_embeddings(self, payload_input):
        """POST ``payload_input`` as the ``input`` field and return the ``data`` list.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
        """
        # `json=` lets httpx serialise the body and set the JSON content type.
        # `timeout=None` disables the client timeout, as in the original code.
        ret = httpx.post(self.url, json={"input": payload_input}, timeout=None)
        ret.raise_for_status()
        return ret.json()['data']

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts and return one embedding vector per text."""
        # Nothing to embed: skip the network round trip.
        if not texts:
            return []
        return [e['embedding'] for e in self._request_embeddings(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text and return its embedding vector."""
        return self._request_embeddings(text)[0]['embedding']
class CustomOpenAI(OpenAI):
    """OpenAI LLM wrapper whose API base points at a self-hosted server.

    The API base is set to ``<url>/v1``.
    """

    # Base URL of the self-hosted server.
    url: str

    def __init__(self, **kwargs):
        """Build the LLM from flat configuration keyword arguments.

        ``repeat_penalty`` and ``top_k`` are required and, together with the
        optional comma-separated ``stop`` string, are moved into
        ``model_kwargs``. All remaining keyword arguments (including ``url``)
        are forwarded to the parent constructor.

        Raises:
            KeyError: if ``repeat_penalty``, ``top_k`` or ``url`` is missing.
        """
        model_kwargs = {
            'repeat_penalty': kwargs.pop('repeat_penalty'),
            'top_k': kwargs.pop('top_k')
        }

        # `stop` arrives as a single comma-separated string.
        stop = kwargs.pop('stop', None)
        if stop:
            model_kwargs['stop'] = stop.split(',')

        # NOTE(review): the blank API key looks like a placeholder for a
        # server that does not check credentials - confirm.
        super().__init__(
            openai_api_key=" ",
            model_kwargs=model_kwargs,
            **kwargs
        )

        self.url = kwargs['url']
        # Plain string joining instead of os.path.join: the latter uses the
        # OS path separator (a backslash on Windows), which is wrong for URLs.
        self.openai_api_base = f"{self.url.rstrip('/')}/v1"
class EmbedderLlamaCppConfig(EmbedderSettings):
    # Settings for the self-hosted llama-cpp-python embedder.

    # Base URL of the server.
    # NOTE(review): presumably forwarded to CustomOpenAIEmbeddings(url=...)
    # by EmbedderSettings - confirm there.
    url: str
    # Embedder class tied to this configuration. This must be an annotated
    # assignment: the previous `_pyclass = PyObject = ...` was a chained
    # assignment that also bound a stray `PyObject` class attribute and
    # dropped the annotation the sibling settings classes use.
    _pyclass: PyObject = CustomOpenAIEmbeddings

    class Config:
        schema_extra = {
            "humanReadableName": "Self-hosted llama-cpp-python embedder",
            "description": "Self-hosted llama-cpp-python embedder",
        }
class LLMLlamaCppConfig(LLMSettings):
    # Settings for an LLM served by a self-hosted llama-cpp-python instance.
    # The values are consumed by CustomOpenAI (see `_pyclass`).

    # Base URL of the server; CustomOpenAI appends "v1" to build the API base.
    url: str
    temperature: float = 0.01
    max_tokens: int = 512
    # Comma-separated stop sequences: CustomOpenAI.__init__ splits this on ",".
    stop: str = "Human:,###"
    # CustomOpenAI.__init__ moves `top_k` and `repeat_penalty` into model_kwargs.
    top_k: int = 40
    top_p: float = 0.95
    repeat_penalty: float = 1.1
    # LLM class tied to this configuration.
    _pyclass: PyObject = CustomOpenAI

    class Config:
        schema_extra = {
            "humanReadableName": "Self-hosted llama-cpp-python",
            "description": "Self-hosted llama-cpp-python compatible LLM",
        }
# `This embedder is not a model properly trained diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index ef211be6..8ffe4746 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -3,7 +3,6 @@ import time import shutil import os -from typing import Dict from cat.log import log from cat.db import crud @@ -114,7 +113,9 @@ def load_plugin(self, plugin_path, active): # if plugin is valid, keep a reference self.plugins[plugin.id] = plugin except Exception as e: - log(e, "WARNING") + # Something happened while loading the plugin. + # Print the error and go on with the others. + log(str(e), "ERROR") # Load hooks and tools of the active plugins into MadHatter def sync_hooks_and_tools(self): @@ -212,19 +213,19 @@ def embed_tools(self): # activate / deactivate plugin def toggle_plugin(self, plugin_id): - log(f"toggle plugin {plugin_id}", "WARNING") - if self.plugin_exists(plugin_id): plugin_is_active = plugin_id in self.active_plugins # update list of active plugins if plugin_is_active: + log(f"Toggle plugin {plugin_id}: Deactivate", "WARNING") # Deactivate the plugin self.plugins[plugin_id].deactivate() # Remove the plugin from the list of active plugins self.active_plugins.remove(plugin_id) else: + log(f"Toggle plugin {plugin_id}: Activate", "WARNING") # Activate the plugin self.plugins[plugin_id].activate() # Ass the plugin in the list of active plugins diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index 18f3e4d0..dbe508c0 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -1,14 +1,16 @@ import os +import sys import json import glob import importlib +import traceback from typing import Dict from inspect import getmembers from pydantic import BaseModel from cat.mad_hatter.decorators import CatTool, CatHook from cat.utils import to_camel_case -from cat.log import log +from cat.log import log, get_log_level # this class represents a plugin in memory 
# the plugin itsefl is managed as much as possible unix style @@ -51,12 +53,19 @@ def __init__(self, plugin_path: str, active: bool): self.activate() def activate(self): - self._active = True # lists of hooks and tools self._hooks, self._tools = self._load_hooks_and_tools() + self._active = True def deactivate(self): self._active = False + + # Remove the imported modules + for py_file in self.py_files: + py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry + log(f"Remove module {py_filename}", "DEBUG") + sys.modules.pop(py_filename) + self._hooks = [] self._tools = [] @@ -163,11 +172,19 @@ def _load_hooks_and_tools(self): for py_file in self.py_files: py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry - # save a reference to decorated functions - plugin_module = importlib.import_module(py_filename) - hooks += getmembers(plugin_module, self._is_cat_hook) - tools += getmembers(plugin_module, self._is_cat_tool) + log(f"Import module {py_filename}", "DEBUG") + # save a reference to decorated functions + try: + plugin_module = importlib.import_module(py_filename) + hooks += getmembers(plugin_module, self._is_cat_hook) + tools += getmembers(plugin_module, self._is_cat_tool) + except Exception as e: + log(f"Error in {py_filename}: {str(e)}","ERROR") + if get_log_level() == "DEBUG": + traceback.print_exc() + raise Exception(f"Unable to load the plugin {self._id}") + # clean and enrich instances hooks = list(map(self._clean_hook, hooks)) tools = list(map(self._clean_tool, tools)) diff --git a/core/cat/routes/embedder.py b/core/cat/routes/embedder.py index 082f8686..6293ebb7 100644 --- a/core/cat/routes/embedder.py +++ b/core/cat/routes/embedder.py @@ -2,7 +2,7 @@ from fastapi import Request, APIRouter, Body, HTTPException -from cat.factory.embedder import EMBEDDER_SCHEMAS +from cat.factory.embedder import EMBEDDER_SCHEMAS, SUPPORTED_EMDEDDING_MODELS from cat.db import crud, models from cat.log 
@router.get("/allowed-mimetypes/")
async def get_allowed_mimetypes(request: Request) -> Dict:
    """Retrieve the allowed mimetypes that can be ingested by the Rabbit Hole"""

    # The Rabbit Hole lives on the Cat instance stored in the app state;
    # the keys of its file handler mapping are returned as the allowed types.
    rabbit_hole = request.app.state.ccat.rabbit_hole

    return {"allowed": [*rabbit_hole.file_handlers.keys()]}