Skip to content

Commit

Permalink
Merge pull request #6 from Mindinventory/improvement_002
Browse files Browse the repository at this point in the history
Make llm instances configurable.
  • Loading branch information
siddhant-mi authored Mar 13, 2024
2 parents 66137f7 + aa22c50 commit 20d677c
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
1 change: 1 addition & 0 deletions mindsql/_utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
GOOGLE_GEN_AI_VALUE_ERROR = "For GoogleGenAI, config must be provided with an api_key"
GOOGLE_GEN_AI_APIKEY_ERROR = "config must contain a Google AI Studio api_key"
LLAMA_VALUE_ERROR = "For LlamaAI, config must be provided with a model_path"
CONFIG_REQUIRED_ERROR = "Configuration is required."
LLAMA_PROMPT_EXCEPTION = "Prompt cannot be empty."
OPENAI_VALUE_ERROR = "OpenAI API key is required"
OPENAI_PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
Expand Down
4 changes: 2 additions & 2 deletions mindsql/llms/googlegenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def __init__(self, config=None):

if 'api_key' not in config:
raise ValueError(GOOGLE_GEN_AI_APIKEY_ERROR)
api_key = config['api_key']
api_key = config.pop('api_key')
genai.configure(api_key=api_key)
self.model = genai.GenerativeModel('gemini-pro')
self.model = genai.GenerativeModel('gemini-pro', **config)

def system_message(self, message: str) -> any:
"""
Expand Down
8 changes: 4 additions & 4 deletions mindsql/llms/llama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from llama_cpp import Llama

from .._utils.constants import LLAMA_VALUE_ERROR, LLAMA_PROMPT_EXCEPTION
from .._utils.constants import LLAMA_VALUE_ERROR, LLAMA_PROMPT_EXCEPTION, CONFIG_REQUIRED_ERROR
from .illm import ILlm


Expand All @@ -16,13 +16,13 @@ def __init__(self, config=None):
None
"""
if config is None:
raise ValueError("")
raise ValueError(CONFIG_REQUIRED_ERROR)

if 'model_path' not in config:
raise ValueError(LLAMA_VALUE_ERROR)
path = config['model_path']
path = config.pop('model_path')

self.model = Llama(model_path=path)
self.model = Llama(model_path=path, **config)

def system_message(self, message: str) -> any:
"""
Expand Down
12 changes: 6 additions & 6 deletions mindsql/llms/open_ai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from openai import OpenAI

from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION
from . import ILlm
from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION


class OpenAi(ILlm):
Expand All @@ -16,6 +16,7 @@ def __init__(self, config=None, client=None):
Returns:
None
"""
self.config = config
self.client = client

if client is not None:
Expand All @@ -24,9 +25,8 @@ def __init__(self, config=None, client=None):

if 'api_key' not in config:
raise ValueError(OPENAI_VALUE_ERROR)

if 'api_key' in config:
self.client = OpenAI(api_key=config['api_key'])
api_key = config.pop('api_key')
self.client = OpenAI(api_key=api_key, **config)

def system_message(self, message: str) -> any:
"""
Expand Down Expand Up @@ -82,6 +82,6 @@ def invoke(self, prompt, **kwargs) -> str:
model = self.config.get("model", "gpt-3.5-turbo")
temperature = kwargs.get("temperature", 0.1)
max_tokens = kwargs.get("max_tokens", 500)
response = self.client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}], max_tokens=max_tokens, stop=None,
temperature=temperature)
response = self.client.chat.completions.create(model=model, messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens, stop=None, temperature=temperature)
return response.choices[0].message.content

0 comments on commit 20d677c

Please sign in to comment.