Skip to content

Commit

Permalink
Fixed the rate limit issue for openai embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
marioradix committed Sep 27, 2023
1 parent c26f020 commit bf2b8fe
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 5 deletions.
21 changes: 20 additions & 1 deletion core/cat/factory/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pydantic import PyObject, BaseSettings

from cat.factory.custom_embedder import DumbEmbedder, CustomOpenAIEmbeddings

from cat.utils import check_openai_key_valid

# Base class to manage LLM configuration.
class EmbedderSettings(BaseSettings):
Expand Down Expand Up @@ -62,6 +62,16 @@ class Config:
"humanReadableName": "OpenAI Embedder",
"description": "Configuration for OpenAI embeddings",
}

# instantiate an open ai Embedder from configuration with checking for the validity of the key
@classmethod
def get_embedder_from_config(cls, config):
if cls._pyclass is None:
raise Exception(
"Embedder configuration class has self._pyclass = None. Should be a valid Embedder class"
)
check_openai_key_valid(config["openai_api_key"])
return cls._pyclass(**config)


# https://python.langchain.com/en/latest/_modules/langchain/embeddings/openai.html#OpenAIEmbeddings
Expand All @@ -81,6 +91,15 @@ class Config:
"description": "Configuration for Azure OpenAI embeddings",
}

# instantiate an open ai Embedder from configuration with checking for the validity of the key
@classmethod
def get_embedder_from_config(cls, config):
if cls._pyclass is None:
raise Exception(
"Embedder configuration class has self._pyclass = None. Should be a valid Embedder class"
)
check_openai_key_valid(config["openai_api_key"])
return cls._pyclass(**config)

class EmbedderCohereConfig(EmbedderSettings):
cohere_api_key: str
Expand Down
9 changes: 6 additions & 3 deletions core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from langchain import HuggingFaceHub
from langchain.chat_models import AzureChatOpenAI
from cat.factory.custom_llm import CustomOpenAI

from openai.error import RateLimitError

MSG_TYPES = Literal["notification", "chat", "error"]

Expand Down Expand Up @@ -159,8 +159,11 @@ def get_language_embedder(self) -> embedders.EmbedderSettings:

# obtain configuration and instantiate Embedder
selected_embedder_config = crud.get_setting_by_name(name=selected_embedder_class)
embedder = FactoryClass.get_embedder_from_config(selected_embedder_config["value"])

try:
embedder = FactoryClass.get_embedder_from_config(selected_embedder_config["value"])
except Exception as e:
embedder = embedders.EmbedderDumbConfig.get_embedder_from_config({})
log.log(f"A problem occured while loading the embedder, the default embadder will be used, {str(e)}", "WARNING")
return embedder

# OpenAI embedder
Expand Down
16 changes: 16 additions & 0 deletions core/cat/routes/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from cat.factory.embedder import EMBEDDER_SCHEMAS, SUPPORTED_EMDEDDING_MODELS
from cat.db import crud, models
from cat.log import log
from cat.utils import check_openai_key_valid

import openai

router = APIRouter()

Expand Down Expand Up @@ -105,6 +108,18 @@ def upsert_embedder_setting(
}
)

# check if the openai key is valid
if "openai_api_key" in payload:
try:
check_openai_key_valid(payload["openai_api_key"])
except Exception as e:
raise HTTPException(
status_code=400,
detail={
"error": str(e)
}
)

# create the setting and upsert it
final_setting = crud.upsert_setting_by_name(
models.Setting(name=languageEmbedderName, category=EMBEDDER_CATEGORY, value=payload)
Expand All @@ -129,3 +144,4 @@ def upsert_embedder_setting(
ccat.mad_hatter.embed_tools()

return status

25 changes: 24 additions & 1 deletion core/cat/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Various utiles used from the projects."""

from datetime import timedelta

import openai

def to_camel_case(text :str ) -> str:
"""Format string to camel case.
Expand Down Expand Up @@ -67,3 +67,26 @@ def verbal_timedelta(td: timedelta) -> str:
return "{} ago".format(abs_delta)
else:
return "{} ago".format(abs_delta)

def check_openai_key_valid(key):
"""send a request to the openai api to test the key.
The fucntion send a request to the openai api using the given key and check the response, if the key is not valid an exception will be raised.
Parameters
----------
key : str
openai key.
Returns
-------
str
response to the query.
"""
openai.api_key = key
response = openai.Completion.create(
engine="davinci",
prompt="test for the key",
max_tokens=5
)
return response

0 comments on commit bf2b8fe

Please sign in to comment.