Skip to content

Commit

Permalink
Merge pull request #485 from cheshire-cat-ai/develop
Browse files Browse the repository at this point in the history
version 1.1.2
  • Loading branch information
pieroit authored Oct 17, 2023
2 parents c96bbb7 + f667a96 commit 1e72089
Show file tree
Hide file tree
Showing 14 changed files with 209 additions and 144 deletions.
2 changes: 1 addition & 1 deletion core/cat/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_settings_by_category(category: str) -> List[Dict]:
def create_setting(payload: models.Setting) -> Dict:

# Missing fields (setting_id, updated_at) are filled automatically by pydantic
get_db().insert(payload.dict())
get_db().insert(payload.model_dump())

# retrieve the record we just created
new_record = get_setting_by_id(payload.setting_id)
Expand Down
2 changes: 1 addition & 1 deletion core/cat/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def generate_timestamp():
class SettingBody(BaseModel):
name: str
value: Union[Dict, List]
category: Optional[str]
category: Optional[str] = None

# actual setting class, with additional auto generated id and update time
class Setting(SettingBody):
Expand Down
68 changes: 41 additions & 27 deletions core/cat/factory/embedder.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,78 @@
from typing import Type
import langchain
from pydantic import PyObject, BaseSettings
from pydantic import BaseModel, ConfigDict

from cat.factory.custom_embedder import DumbEmbedder, CustomOpenAIEmbeddings


# Base class to manage LLM configuration.
class EmbedderSettings(BaseSettings):
class EmbedderSettings(BaseModel):
# class instantiating the embedder
_pyclass: None
_pyclass: Type = None

# This is related to pydantic, because "model_*" attributes are protected.
# We deactivate the protection because langchain relies on several "model_*" named attributes
model_config = ConfigDict(
protected_namespaces=()
)

# instantiate an Embedder from configuration
@classmethod
def get_embedder_from_config(cls, config):
if cls._pyclass is None:
raise Exception(
"Embedder configuration class has self._pyclass = None. Should be a valid Embedder class"
"Embedder configuration class has self._pyclass==None. Should be a valid Embedder class"
)
return cls._pyclass(**config)
return cls._pyclass.default(**config)


class EmbedderFakeConfig(EmbedderSettings):
size: int = 128
_pyclass: PyObject = langchain.embeddings.FakeEmbeddings
_pyclass: Type = langchain.embeddings.FakeEmbeddings

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra = {
"humanReadableName": "Default Embedder",
"description": "Configuration for default embedder. It just outputs random numbers.",
}
)


class EmbedderDumbConfig(EmbedderSettings):

_pyclass = PyObject = DumbEmbedder
_pyclass: Type = DumbEmbedder

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra = {
"humanReadableName": "Dumb Embedder",
"description": "Configuration for default embedder. It encodes the pairs of characters",
}
)


class EmbedderLlamaCppConfig(EmbedderSettings):
url: str
_pyclass = PyObject = CustomOpenAIEmbeddings
_pyclass: Type = CustomOpenAIEmbeddings

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra = {
"humanReadableName": "Self-hosted llama-cpp-python embedder",
"description": "Self-hosted llama-cpp-python embedder",
}
)


class EmbedderOpenAIConfig(EmbedderSettings):
openai_api_key: str
model: str = "text-embedding-ada-002"
_pyclass: PyObject = langchain.embeddings.OpenAIEmbeddings
_pyclass: Type = langchain.embeddings.OpenAIEmbeddings

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra = {
"humanReadableName": "OpenAI Embedder",
"description": "Configuration for OpenAI embeddings",
}
)


# https://python.langchain.com/en/latest/_modules/langchain/embeddings/openai.html#OpenAIEmbeddings
Expand All @@ -73,37 +84,40 @@ class EmbedderAzureOpenAIConfig(EmbedderSettings):
openai_api_version: str
deployment: str

_pyclass: PyObject = langchain.embeddings.OpenAIEmbeddings
_pyclass: Type = langchain.embeddings.OpenAIEmbeddings

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra = {
"humanReadableName": "Azure OpenAI Embedder",
"description": "Configuration for Azure OpenAI embeddings",
}
)


class EmbedderCohereConfig(EmbedderSettings):
cohere_api_key: str
model: str = "embed-multilingual-v2.0"
_pyclass: PyObject = langchain.embeddings.CohereEmbeddings
_pyclass: Type = langchain.embeddings.CohereEmbeddings

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra = {
"humanReadableName": "Cohere Embedder",
"description": "Configuration for Cohere embeddings",
}
)


class EmbedderHuggingFaceHubConfig(EmbedderSettings):
repo_id: str = "sentence-transformers/all-MiniLM-L12-v2"
huggingfacehub_api_token: str
_pyclass: PyObject = langchain.embeddings.HuggingFaceHubEmbeddings
_pyclass: Type = langchain.embeddings.HuggingFaceHubEmbeddings

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra = {
"humanReadableName": "HuggingFace Hub Embedder",
"description": "Configuration for HuggingFace Hub embeddings",
}
)


SUPPORTED_EMDEDDING_MODELS = [
Expand All @@ -120,7 +134,7 @@ class Config:
# EMBEDDER_SCHEMAS contains metadata to let any client know which fields are required to create the language embedder.
EMBEDDER_SCHEMAS = {}
for config_class in SUPPORTED_EMDEDDING_MODELS:
schema = config_class.schema()
schema = config_class.model_json_schema()

# useful for clients in order to call the correct config endpoints
schema["languageEmbedderName"] = schema["title"]
Expand Down
Loading

0 comments on commit 1e72089

Please sign in to comment.