diff --git a/pyproject.toml b/pyproject.toml
index 4c0d3943..1e7d0eb4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,12 +33,12 @@ bigquery = ["google-cloud-bigquery"]
 snowflake = ["snowflake-connector-python"]
 duckdb = ["duckdb"]
 google = ["google-generativeai", "google-cloud-aiplatform"]
-all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
+all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
 test = ["tox"]
 chromadb = ["chromadb"]
 openai = ["openai"]
 qianfan = ["qianfan"]
-mistralai = ["mistralai"]
+mistralai = ["mistralai>=1.0.0"]
 anthropic = ["anthropic"]
 gemini = ["google-generativeai"]
 marqo = ["marqo"]
diff --git a/src/vanna/mistral/mistral.py b/src/vanna/mistral/mistral.py
index f5c89ac8..11d0b813 100644
--- a/src/vanna/mistral/mistral.py
+++ b/src/vanna/mistral/mistral.py
@@ -1,5 +1,7 @@
-from mistralai.client import MistralClient
-from mistralai.models.chat_completion import ChatMessage
+import os
+
+from mistralai import Mistral as MistralClient
+from mistralai import UserMessage
 
 from ..base import VannaBase
 
@@ -23,13 +25,13 @@ def __init__(self, config=None):
         self.model = model
 
     def system_message(self, message: str) -> any:
-        return ChatMessage(role="system", content=message)
+        return {"role": "system", "content": message}
 
     def user_message(self, message: str) -> any:
-        return ChatMessage(role="user", content=message)
+        return {"role": "user", "content": message}
 
     def assistant_message(self, message: str) -> any:
-        return ChatMessage(role="assistant", content=message)
+        return {"role": "assistant", "content": message}
 
     def generate_sql(self, question: str, **kwargs) -> str:
         # Use the super generate_sql
@@ -41,7 +43,7 @@ def generate_sql(self, question: str, **kwargs) -> str:
         return sql
 
     def submit_prompt(self, prompt, **kwargs) -> str:
-        chat_response = self.client.chat(
+        chat_response = self.client.chat.complete(
            model=self.model,
            messages=prompt,
        )
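
For reference, here is a minimal sketch of how the mistralai >= 1.0.0 client surface touched by this diff is called standalone. This is not part of the PR; the model name and the MISTRAL_API_KEY environment variable are placeholders.

```python
# Sketch (assumption, not from the PR): direct use of the mistralai >= 1.0.0
# client, matching the calls the migrated Mistral class now relies on.
import os

from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

# Messages are plain role/content dicts (the old ChatMessage class is gone;
# mistralai also ships model classes such as UserMessage if preferred).
response = client.chat.complete(
    model="mistral-large-latest",  # placeholder model name
    messages=[
        {"role": "system", "content": "You write SQL for the user's question."},
        {"role": "user", "content": "How many customers signed up last month?"},
    ],
)

print(response.choices[0].message.content)
```

The wrapper keeps the same shape: system_message/user_message/assistant_message now return these dicts directly, and submit_prompt passes them to client.chat.complete instead of the removed client.chat.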