From bf2b8fe8598fe828acd9577579574692037e9585 Mon Sep 17 00:00:00 2001 From: marioradix Date: Wed, 27 Sep 2023 23:31:52 +0200 Subject: [PATCH] Fixed the rate limit issue for openai embedder --- core/cat/factory/embedder.py | 21 ++++++++++++++++++++- core/cat/looking_glass/cheshire_cat.py | 9 ++++++--- core/cat/routes/embedder.py | 16 ++++++++++++++++ core/cat/utils.py | 25 ++++++++++++++++++++++++- 4 files changed, 66 insertions(+), 5 deletions(-) diff --git a/core/cat/factory/embedder.py b/core/cat/factory/embedder.py index 464a1f68..788b9f09 100644 --- a/core/cat/factory/embedder.py +++ b/core/cat/factory/embedder.py @@ -2,7 +2,7 @@ from pydantic import PyObject, BaseSettings from cat.factory.custom_embedder import DumbEmbedder, CustomOpenAIEmbeddings - +from cat.utils import check_openai_key_valid # Base class to manage LLM configuration. class EmbedderSettings(BaseSettings): @@ -62,6 +62,16 @@ class Config: "humanReadableName": "OpenAI Embedder", "description": "Configuration for OpenAI embeddings", } + + # instantiate an open ai Embedder from configuration with checking for the validity of the key + @classmethod + def get_embedder_from_config(cls, config): + if cls._pyclass is None: + raise Exception( + "Embedder configuration class has self._pyclass = None. Should be a valid Embedder class" + ) + check_openai_key_valid(config["openai_api_key"]) + return cls._pyclass(**config) # https://python.langchain.com/en/latest/_modules/langchain/embeddings/openai.html#OpenAIEmbeddings @@ -81,6 +91,15 @@ class Config: "description": "Configuration for Azure OpenAI embeddings", } + # instantiate an open ai Embedder from configuration with checking for the validity of the key + @classmethod + def get_embedder_from_config(cls, config): + if cls._pyclass is None: + raise Exception( + "Embedder configuration class has self._pyclass = None. Should be a valid Embedder class" + ) + check_openai_key_valid(config["openai_api_key"]) + return cls._pyclass(**config) class EmbedderCohereConfig(EmbedderSettings): cohere_api_key: str diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py index b05936d3..6a7f885d 100644 --- a/core/cat/looking_glass/cheshire_cat.py +++ b/core/cat/looking_glass/cheshire_cat.py @@ -22,7 +22,7 @@ from langchain import HuggingFaceHub from langchain.chat_models import AzureChatOpenAI from cat.factory.custom_llm import CustomOpenAI - +from openai.error import RateLimitError MSG_TYPES = Literal["notification", "chat", "error"] @@ -159,8 +159,11 @@ def get_language_embedder(self) -> embedders.EmbedderSettings: # obtain configuration and instantiate Embedder selected_embedder_config = crud.get_setting_by_name(name=selected_embedder_class) - embedder = FactoryClass.get_embedder_from_config(selected_embedder_config["value"]) - + try: + embedder = FactoryClass.get_embedder_from_config(selected_embedder_config["value"]) + except Exception as e: + embedder = embedders.EmbedderDumbConfig.get_embedder_from_config({}) + log.log(f"A problem occured while loading the embedder, the default embadder will be used, {str(e)}", "WARNING") return embedder # OpenAI embedder diff --git a/core/cat/routes/embedder.py b/core/cat/routes/embedder.py index 6293ebb7..c0079ec9 100644 --- a/core/cat/routes/embedder.py +++ b/core/cat/routes/embedder.py @@ -5,6 +5,9 @@ from cat.factory.embedder import EMBEDDER_SCHEMAS, SUPPORTED_EMDEDDING_MODELS from cat.db import crud, models from cat.log import log +from cat.utils import check_openai_key_valid + +import openai router = APIRouter() @@ -105,6 +108,18 @@ def upsert_embedder_setting( } ) + # check if the openai key is valid + if "openai_api_key" in payload: + try: + check_openai_key_valid(payload["openai_api_key"]) + except Exception as e: + raise HTTPException( + status_code=400, + detail={ + "error": str(e) + } + ) + # create the setting and upsert it final_setting = crud.upsert_setting_by_name( models.Setting(name=languageEmbedderName, category=EMBEDDER_CATEGORY, value=payload) @@ -129,3 +144,4 @@ def upsert_embedder_setting( ccat.mad_hatter.embed_tools() return status + diff --git a/core/cat/utils.py b/core/cat/utils.py index 8792a99e..9219e308 100644 --- a/core/cat/utils.py +++ b/core/cat/utils.py @@ -1,7 +1,7 @@ """Various utiles used from the projects.""" from datetime import timedelta - +import openai def to_camel_case(text :str ) -> str: """Format string to camel case. @@ -67,3 +67,26 @@ def verbal_timedelta(td: timedelta) -> str: return "{} ago".format(abs_delta) else: return "{} ago".format(abs_delta) + +def check_openai_key_valid(key): + """send a request to the openai api to test the key. + + The fucntion send a request to the openai api using the given key and check the response, if the key is not valid an exception will be raised. + + Parameters + ---------- + key : str + openai key. + + Returns + ------- + str + response to the query. + """ + openai.api_key = key + response = openai.Completion.create( + engine="davinci", + prompt="test for the key", + max_tokens=5 + ) + return response