Skip to content

Commit

Permalink
Merge pull request #593 from vanna-ai/mistral-1-fix
Browse files Browse the repository at this point in the history
Mistral 1.0.0
  • Loading branch information
zainhoda authored Aug 9, 2024
2 parents 75776e3 + 7208d63 commit d39dae3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
qianfan = ["qianfan"]
mistralai = ["mistralai"]
mistralai = ["mistralai>=1.0.0"]
anthropic = ["anthropic"]
gemini = ["google-generativeai"]
marqo = ["marqo"]
Expand Down
14 changes: 8 additions & 6 deletions src/vanna/mistral/mistral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
import os

from mistralai import Mistral as MistralClient
from mistralai import UserMessage

from ..base import VannaBase

Expand All @@ -23,13 +25,13 @@ def __init__(self, config=None):
self.model = model

def system_message(self, message: str) -> any:
return ChatMessage(role="system", content=message)
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return ChatMessage(role="user", content=message)
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return ChatMessage(role="assistant", content=message)
return {"role": "assistant", "content": message}

def generate_sql(self, question: str, **kwargs) -> str:
# Use the super generate_sql
Expand All @@ -41,7 +43,7 @@ def generate_sql(self, question: str, **kwargs) -> str:
return sql

def submit_prompt(self, prompt, **kwargs) -> str:
chat_response = self.client.chat(
chat_response = self.client.chat.complete(
model=self.model,
messages=prompt,
)
Expand Down

0 comments on commit d39dae3

Please sign in to comment.