From aa22c50e54a6d98e07b3ae6572e096f6ca34c35c Mon Sep 17 00:00:00 2001 From: Ishika Shah Date: Wed, 13 Mar 2024 11:22:44 +0530 Subject: [PATCH] improvement_002: make llm instances configurable. --- mindsql/_utils/constants.py | 1 + mindsql/llms/googlegenai.py | 4 ++-- mindsql/llms/llama.py | 8 ++++---- mindsql/llms/open_ai.py | 12 ++++++------ 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mindsql/_utils/constants.py b/mindsql/_utils/constants.py index 56b30f5..9f5dea1 100644 --- a/mindsql/_utils/constants.py +++ b/mindsql/_utils/constants.py @@ -27,6 +27,7 @@ GOOGLE_GEN_AI_VALUE_ERROR = "For GoogleGenAI, config must be provided with an api_key" GOOGLE_GEN_AI_APIKEY_ERROR = "config must contain a Google AI Studio api_key" LLAMA_VALUE_ERROR = "For LlamaAI, config must be provided with a model_path" +CONFIG_REQUIRED_ERROR = "Configuration is required." LLAMA_PROMPT_EXCEPTION = "Prompt cannot be empty." OPENAI_VALUE_ERROR = "OpenAI API key is required" OPENAI_PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty." 
diff --git a/mindsql/llms/googlegenai.py b/mindsql/llms/googlegenai.py index 765529d..9e80582 100644 --- a/mindsql/llms/googlegenai.py +++ b/mindsql/llms/googlegenai.py @@ -20,9 +20,9 @@ def __init__(self, config=None): if 'api_key' not in config: raise ValueError(GOOGLE_GEN_AI_APIKEY_ERROR) - api_key = config['api_key'] + api_key = config.pop('api_key') genai.configure(api_key=api_key) - self.model = genai.GenerativeModel('gemini-pro') + self.model = genai.GenerativeModel('gemini-pro', **config) def system_message(self, message: str) -> any: """ diff --git a/mindsql/llms/llama.py b/mindsql/llms/llama.py index 1ddeedf..7147ddd 100644 --- a/mindsql/llms/llama.py +++ b/mindsql/llms/llama.py @@ -1,6 +1,6 @@ from llama_cpp import Llama -from .._utils.constants import LLAMA_VALUE_ERROR, LLAMA_PROMPT_EXCEPTION +from .._utils.constants import LLAMA_VALUE_ERROR, LLAMA_PROMPT_EXCEPTION, CONFIG_REQUIRED_ERROR from .illm import ILlm @@ -16,13 +16,13 @@ def __init__(self, config=None): None """ if config is None: - raise ValueError("") + raise ValueError(CONFIG_REQUIRED_ERROR) if 'model_path' not in config: raise ValueError(LLAMA_VALUE_ERROR) - path = config['model_path'] + path = config.pop('model_path') - self.model = Llama(model_path=path) + self.model = Llama(model_path=path, **config) def system_message(self, message: str) -> any: """ diff --git a/mindsql/llms/open_ai.py b/mindsql/llms/open_ai.py index 097a1ec..5cf63a9 100644 --- a/mindsql/llms/open_ai.py +++ b/mindsql/llms/open_ai.py @@ -1,7 +1,7 @@ from openai import OpenAI -from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION from .illm import ILlm
+from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION class OpenAi(ILlm): @@ -16,6 +16,7 @@ def __init__(self, config=None, client=None): Returns: None """ + self.config = config self.client = client if client is not None: @@ -24,9 +25,8 @@ if 'api_key' not in config: raise ValueError(OPENAI_VALUE_ERROR) - - if 'api_key' in config: - self.client = OpenAI(api_key=config['api_key']) + api_key = config.pop('api_key') + self.client = OpenAI(api_key=api_key, **config) def system_message(self, message: str) -> any: """ @@ -82,6 +82,6 @@ def invoke(self, prompt, **kwargs) -> str: model = self.config.get("model", "gpt-3.5-turbo") temperature = kwargs.get("temperature", 0.1) max_tokens = kwargs.get("max_tokens", 500) - response = self.client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}], max_tokens=max_tokens, stop=None, - temperature=temperature) + response = self.client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}], + max_tokens=max_tokens, stop=None, temperature=temperature) return response.choices[0].message.content