Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #438

Merged
merged 23 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
25a107c
Add LLMLlamaCppConfig in llm factory
AlessandroSpallina Aug 22, 2023
a9c1b3e
Add hyperparameters to LLMLlamaCppConfig
AlessandroSpallina Aug 23, 2023
1bda39d
Remove unused library
AlessandroSpallina Aug 23, 2023
e167600
Merge branch 'cheshire-cat-ai:main' into feature/llama-support
AlessandroSpallina Aug 22, 2023
c512139
Catch plugin activation error
Pingdred Aug 23, 2023
2ff84f9
Return HTTP 500 if an error occurred during plugin activation
Pingdred Aug 23, 2023
dc43195
Merge branch 'plugin_error' of https://github.com/Pingdred/cheshire-c…
pieroit Aug 24, 2023
0f3c23b
review PR; fix try/except blocks in MadHatter and Plugin
pieroit Aug 24, 2023
6d7bc98
Merge branch 'Pingdred-plugin_error' into develop
pieroit Aug 24, 2023
55a303a
Merge branch 'feature/llama-support' of https://github.com/Alessandro…
pieroit Aug 24, 2023
c4b84f6
review PR LLAMA2 adapter
pieroit Aug 24, 2023
74dd8e6
Merge branch 'AlessandroSpallina-feature/llama-support' into develop
pieroit Aug 24, 2023
1ef99b8
Add support for llama-cpp-python embedder
AlessandroSpallina Aug 26, 2023
bad72cf
Add automatic llama embedder when llama 2 is used as LLM
AlessandroSpallina Aug 27, 2023
724c47a
Update upload.py
zAlweNy26 Aug 28, 2023
c3c9389
Update upload.py
zAlweNy26 Aug 28, 2023
9b1b05e
Merge pull request #437 from zAlweNy26/mimetypes-endpoint
pieroit Aug 28, 2023
83a7f4e
Fix warning when initializing CustomOpenAI
AlessandroSpallina Aug 28, 2023
4af3b10
Merge branch 'feature/llama-embedder' of https://github.com/Alessandr…
pieroit Aug 28, 2023
de2b350
add docstring
pieroit Aug 28, 2023
5d14fb1
return selected embedder even if not in DB
pieroit Aug 28, 2023
6ddbc1c
Merge branch 'AlessandroSpallina-feature/llama-embedder' into develop
pieroit Aug 28, 2023
dc2e9be
fix test
pieroit Aug 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import string
import json
from typing import List
from itertools import combinations
from sklearn.feature_extraction.text import CountVectorizer
from langchain.embeddings.base import Embeddings
import httpx


class DumbEmbedder(Embeddings):
Expand Down Expand Up @@ -41,3 +44,24 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_query(self, text: str) -> List[float]:
"""Embed a string of text and returns the embedding vector as a list of floats."""
return self.embedder.transform([text]).astype(float).todense().tolist()[0]


class CustomOpenAIEmbeddings(Embeddings):
    """Embedder backed by a self-hosted llama-cpp-python server.

    Talks to the OpenAI-compatible ``/v1/embeddings`` endpoint exposed by
    the llama-cpp-python instance reachable at *url*.
    """

    def __init__(self, url):
        # Full endpoint of the OpenAI-compatible embeddings API.
        self.url = os.path.join(url, "v1/embeddings")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts; returns one embedding vector per text."""
        response = httpx.post(self.url, data=json.dumps({"input": texts}), timeout=None)
        response.raise_for_status()
        return [item["embedding"] for item in response.json()["data"]]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single string and return its embedding vector."""
        response = httpx.post(self.url, data=json.dumps({"input": text}), timeout=None)
        response.raise_for_status()
        return response.json()["data"][0]["embedding"]

26 changes: 26 additions & 0 deletions core/cat/factory/custom_llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import Optional, List, Any, Mapping, Dict
import requests
from langchain.llms.base import LLM
from langchain.llms.openai import OpenAI


class LLMDefault(LLM):
Expand Down Expand Up @@ -61,3 +63,27 @@ def _identifying_params(self) -> Mapping[str, Any]:
"auth_key": self.auth_key,
"options": self.options
}


class CustomOpenAI(OpenAI):
    """LLM client for a self-hosted llama-cpp-python server.

    Reuses langchain's OpenAI wrapper but points it at the server's
    OpenAI-compatible ``/v1`` API. The API key is a placeholder because
    the local server does not validate it.
    """
    url: str

    def __init__(self, **kwargs):
        # llama.cpp-specific sampling parameters are unknown to the OpenAI
        # wrapper, so they travel inside `model_kwargs`. Defaults mirror
        # LLMLlamaCppConfig, so a missing key no longer raises KeyError.
        model_kwargs = {
            'repeat_penalty': kwargs.pop('repeat_penalty', 1.1),
            'top_k': kwargs.pop('top_k', 40)
        }

        # `stop` arrives as a comma-separated string; the API expects a list.
        stop = kwargs.pop('stop', None)
        if stop:
            model_kwargs['stop'] = stop.split(',')

        super().__init__(
            openai_api_key=" ",
            model_kwargs=model_kwargs,
            **kwargs
        )

        self.url = kwargs['url']
        self.openai_api_base = os.path.join(self.url, "v1")

16 changes: 14 additions & 2 deletions core/cat/factory/embedder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import langchain
from pydantic import PyObject, BaseSettings

from cat.factory.dumb_embedder import DumbEmbedder
from cat.factory.custom_embedder import DumbEmbedder, CustomOpenAIEmbeddings


# Base class to manage LLM configuration.
Expand Down Expand Up @@ -36,11 +36,22 @@ class EmbedderDumbConfig(EmbedderSettings):

class Config:
schema_extra = {
"name_human_readable": "Dumb Embedder",
"humanReadableName": "Dumb Embedder",
"description": "Configuration for default embedder. It encodes the pairs of characters",
}


class EmbedderLlamaCppConfig(EmbedderSettings):
    """Configuration for a self-hosted llama-cpp-python embedder."""
    url: str
    # Fix: annotated form `_pyclass: PyObject = ...` (was the chained
    # assignment `_pyclass = PyObject = ...`, which also created a stray
    # `PyObject` class attribute), matching the convention used by the
    # other embedder and LLM config classes in this factory.
    _pyclass: PyObject = CustomOpenAIEmbeddings

    class Config:
        schema_extra = {
            "humanReadableName": "Self-hosted llama-cpp-python embedder",
            "description": "Self-hosted llama-cpp-python embedder",
        }


class EmbedderOpenAIConfig(EmbedderSettings):
openai_api_key: str
model: str = "text-embedding-ada-002"
Expand Down Expand Up @@ -98,6 +109,7 @@ class Config:
SUPPORTED_EMDEDDING_MODELS = [
EmbedderDumbConfig,
EmbedderFakeConfig,
EmbedderLlamaCppConfig,
EmbedderOpenAIConfig,
EmbedderAzureOpenAIConfig,
EmbedderCohereConfig,
Expand Down
21 changes: 19 additions & 2 deletions core/cat/factory/llm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import langchain
from typing import Dict
from typing import Dict, List
import json
from pydantic import PyObject, BaseSettings

from cat.factory.custom_llm import LLMDefault, LLMCustom
from cat.factory.custom_llm import LLMDefault, LLMCustom, CustomOpenAI


# Base class to manage LLM configuration.
Expand Down Expand Up @@ -65,6 +65,22 @@ class Config:
}


class LLMLlamaCppConfig(LLMSettings):
    """Settings for a self-hosted llama-cpp-python LLM (OpenAI-compatible API)."""
    # Base URL of the llama-cpp-python server.
    url: str
    temperature: float = 0.01
    max_tokens: int = 512
    # Comma-separated stop sequences; CustomOpenAI splits this into a list.
    stop: str = "Human:,###"
    top_k: int = 40
    top_p: float = 0.95
    repeat_penalty: float = 1.1
    # Class instantiated from these settings.
    _pyclass: PyObject = CustomOpenAI

    class Config:
        schema_extra = {
            "humanReadableName": "Self-hosted llama-cpp-python",
            "description": "Self-hosted llama-cpp-python compatible LLM",
        }

class LLMOpenAIChatConfig(LLMSettings):
openai_api_key: str
model_name: str = "gpt-3.5-turbo"
Expand Down Expand Up @@ -230,6 +246,7 @@ class Config:
SUPPORTED_LANGUAGE_MODELS = [
LLMDefaultConfig,
LLMCustomConfig,
LLMLlamaCppConfig,
LLMOpenAIChatConfig,
LLMOpenAIConfig,
LLMCohereConfig,
Expand Down
9 changes: 9 additions & 0 deletions core/cat/mad_hatter/core_plugin/hooks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from langchain import HuggingFaceHub
from langchain.chat_models import AzureChatOpenAI
from cat.mad_hatter.decorators import hook
from cat.factory.custom_llm import CustomOpenAI


@hook(priority=0)
Expand Down Expand Up @@ -142,6 +143,14 @@ def get_language_embedder(cat) -> embedders.EmbedderSettings:
}
)

# Llama-cpp-python
elif type(cat._llm) in [CustomOpenAI]:
embedder = embedders.EmbedderLlamaCppConfig.get_embedder_from_config(
{
"url": cat._llm.url
}
)

else:
# If no embedder matches vendor, and no external embedder is configured, we use the DumbEmbedder.
# `This embedder is not a model properly trained
Expand Down
9 changes: 5 additions & 4 deletions core/cat/mad_hatter/mad_hatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import time
import shutil
import os
from typing import Dict

from cat.log import log
from cat.db import crud
Expand Down Expand Up @@ -114,7 +113,9 @@ def load_plugin(self, plugin_path, active):
# if plugin is valid, keep a reference
self.plugins[plugin.id] = plugin
except Exception as e:
log(e, "WARNING")
# Something happened while loading the plugin.
# Print the error and go on with the others.
log(str(e), "ERROR")

# Load hooks and tools of the active plugins into MadHatter
def sync_hooks_and_tools(self):
Expand Down Expand Up @@ -212,19 +213,19 @@ def embed_tools(self):

# activate / deactivate plugin
def toggle_plugin(self, plugin_id):
log(f"toggle plugin {plugin_id}", "WARNING")

if self.plugin_exists(plugin_id):

plugin_is_active = plugin_id in self.active_plugins

# update list of active plugins
if plugin_is_active:
log(f"Toggle plugin {plugin_id}: Deactivate", "WARNING")
# Deactivate the plugin
self.plugins[plugin_id].deactivate()
# Remove the plugin from the list of active plugins
self.active_plugins.remove(plugin_id)
else:
log(f"Toggle plugin {plugin_id}: Activate", "WARNING")
# Activate the plugin
self.plugins[plugin_id].activate()
# Ass the plugin in the list of active plugins
Expand Down
29 changes: 23 additions & 6 deletions core/cat/mad_hatter/plugin.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import os
import sys
import json
import glob
import importlib
import traceback
from typing import Dict
from inspect import getmembers
from pydantic import BaseModel

from cat.mad_hatter.decorators import CatTool, CatHook
from cat.utils import to_camel_case
from cat.log import log
from cat.log import log, get_log_level

# this class represents a plugin in memory
# the plugin itself is managed as much as possible unix style
Expand Down Expand Up @@ -51,12 +53,19 @@ def __init__(self, plugin_path: str, active: bool):
self.activate()

def activate(self):
self._active = True
# lists of hooks and tools
self._hooks, self._tools = self._load_hooks_and_tools()
self._active = True

def deactivate(self):
self._active = False

# Remove the imported modules
for py_file in self.py_files:
py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry
log(f"Remove module {py_filename}", "DEBUG")
sys.modules.pop(py_filename)

self._hooks = []
self._tools = []

Expand Down Expand Up @@ -163,11 +172,19 @@ def _load_hooks_and_tools(self):
for py_file in self.py_files:
py_filename = py_file.replace("/", ".").replace(".py", "") # this is UGLY I know. I'm sorry

# save a reference to decorated functions
plugin_module = importlib.import_module(py_filename)
hooks += getmembers(plugin_module, self._is_cat_hook)
tools += getmembers(plugin_module, self._is_cat_tool)
log(f"Import module {py_filename}", "DEBUG")

# save a reference to decorated functions
try:
plugin_module = importlib.import_module(py_filename)
hooks += getmembers(plugin_module, self._is_cat_hook)
tools += getmembers(plugin_module, self._is_cat_tool)
except Exception as e:
log(f"Error in {py_filename}: {str(e)}","ERROR")
if get_log_level() == "DEBUG":
traceback.print_exc()
raise Exception(f"Unable to load the plugin {self._id}")

# clean and enrich instances
hooks = list(map(self._clean_hook, hooks))
tools = list(map(self._clean_tool, tools))
Expand Down
10 changes: 8 additions & 2 deletions core/cat/routes/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import Request, APIRouter, Body, HTTPException

from cat.factory.embedder import EMBEDDER_SCHEMAS
from cat.factory.embedder import EMBEDDER_SCHEMAS, SUPPORTED_EMDEDDING_MODELS
from cat.db import crud, models
from cat.log import log

Expand All @@ -20,13 +20,19 @@

# get configured Embedders and configuration schemas
@router.get("/settings/")
def get_embedders_settings() -> Dict:
def get_embedders_settings(request: Request) -> Dict:
"""Get the list of the Embedders"""

# get selected Embedder, if any
selected = crud.get_setting_by_name(name=EMBEDDER_SELECTED_NAME)
if selected is not None:
selected = selected["value"]["name"]
# If DB does not contain a selected embedder, it means an embedder was automatically selected.
# Deduce selected embedder:
ccat = request.app.state.ccat
for embedder_config_class in reversed(SUPPORTED_EMDEDDING_MODELS):
if embedder_config_class._pyclass == type(ccat.embedder):
selected = embedder_config_class.__name__

saved_settings = crud.get_settings_by_category(category=EMBEDDER_CATEGORY)
saved_settings = { s["name"]: s for s in saved_settings }
Expand Down
17 changes: 11 additions & 6 deletions core/cat/routes/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,17 @@ async def toggle_plugin(plugin_id: str, request: Request) -> Dict:
detail = { "error": "Plugin not found" }
)

# toggle plugin
ccat.mad_hatter.toggle_plugin(plugin_id)

return {
"info": f"Plugin {plugin_id} toggled"
}
try:
# toggle plugin
ccat.mad_hatter.toggle_plugin(plugin_id)
return {
"info": f"Plugin {plugin_id} toggled"
}
except Exception as e:
raise HTTPException(
status_code = 500,
detail = { "error": str(e)}
)


@router.get("/{plugin_id}")
Expand Down
13 changes: 13 additions & 0 deletions core/cat/routes/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,16 @@ async def upload_memory(
"content_type": file.content_type,
"info": "Memory is being ingested asynchronously"
}


@router.get("/allowed-mimetypes/")
async def get_allowed_mimetypes(request: Request) -> Dict:
    """List the mimetypes the Rabbit Hole can ingest."""

    cat = request.app.state.ccat

    # Registered file handlers are keyed by mimetype.
    return {"allowed": list(cat.rabbit_hole.file_handlers.keys())}
3 changes: 2 additions & 1 deletion core/tests/routes/embedder/test_embedder_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def test_get_all_embedder_settings(client):
expected_schema = EMBEDDER_SCHEMAS[setting["name"]]
assert dumps(jsonable_encoder(expected_schema)) == dumps(setting["schema"])

assert json["selected_configuration"] == None # no embedder configured at stratup
# automatically selected embedder
assert json["selected_configuration"] == "EmbedderDumbConfig"


def test_get_embedder_settings_non_existent(client):
Expand Down
Loading