From 6ce895121fa4d05276e90a1e992234e36b8adb98 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Wed, 17 Jan 2024 11:06:54 -0500 Subject: [PATCH] Make documentation argument name consistent --- pyproject.toml | 2 +- src/vanna/base/base.py | 27 +++--- src/vanna/chromadb/chromadb_vector.py | 62 +++++++------ src/vanna/marqo/marqo.py | 126 +++++++++++++------------- src/vanna/vannadb/vannadb_vector.py | 33 ++++--- 5 files changed, 132 insertions(+), 118 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 418e8910..8da49c9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "vanna" -version = "0.0.32" +version = "0.0.33" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 5e44ef85..20adcfc2 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1,8 +1,8 @@ import json import os +import re import sqlite3 import traceback - from abc import ABC, abstractmethod from typing import List, Tuple, Union from urllib.parse import urlparse @@ -12,7 +12,6 @@ import plotly.express as px import plotly.graph_objects as go import requests -import re from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError from ..types import TrainingPlan, TrainingPlanItem @@ -50,8 +49,8 @@ def generate_followup_questions(self, question: str, **kwargs) -> str: **kwargs, ) llm_response = self.submit_prompt(prompt, **kwargs) - - numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE) + + numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE) return numbers_removed.split("\n") def generate_questions(self, **kwargs) -> list[str]: @@ -65,7 +64,7 @@ def generate_questions(self, **kwargs) -> list[str]: """ question_sql = self.get_similar_question_sql(question="", **kwargs) - return [q['question'] for q in question_sql] + return [q["question"] for q in question_sql] # ----------------- Use Any Embeddings API ----------------- # @abstractmethod @@ -94,7 +93,7 @@ def add_ddl(self, ddl: str, **kwargs) -> str: pass @abstractmethod - def add_documentation(self, doc: str, **kwargs) -> str: + def add_documentation(self, documentation: str, **kwargs) -> str: pass @abstractmethod @@ -120,12 +119,12 @@ def get_sql_prompt( @abstractmethod def get_followup_questions_prompt( - self, - question: str, + self, + question: str, question_sql_list: list, ddl_list: list, - doc_list: list, - **kwargs + doc_list: list, + **kwargs, ): pass @@ -829,9 +828,11 @@ def get_plotly_figure( fig = ldict.get("fig", None) except Exception as e: # Inspect data types - numeric_cols = df.select_dtypes(include=['number']).columns.tolist() - categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist() - + numeric_cols = df.select_dtypes(include=["number"]).columns.tolist() + categorical_cols = df.select_dtypes( + include=["object", "category"] + ).columns.tolist() + # Decision-making for plot type if len(numeric_cols) >= 2: # Use the first two numeric columns for a scatter plot diff --git a/src/vanna/chromadb/chromadb_vector.py b/src/vanna/chromadb/chromadb_vector.py index 7af60a80..796c08a4 100644 --- a/src/vanna/chromadb/chromadb_vector.py +++ b/src/vanna/chromadb/chromadb_vector.py @@ -3,9 +3,9 @@ from abc import abstractmethod import chromadb +import pandas as pd from chromadb.config import Settings from chromadb.utils import embedding_functions -import pandas as pd from ..base import VannaBase @@ -47,7 +47,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: "sql": sql, } ) - id = str(uuid.uuid4())+"-sql" + id = str(uuid.uuid4()) + "-sql" self.sql_collection.add( documents=question_sql_json, embeddings=self.generate_embedding(question_sql_json), @@ -57,7 +57,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: return id def add_ddl(self, ddl: str, **kwargs) -> str: - id = str(uuid.uuid4())+"-ddl" + id = str(uuid.uuid4()) + "-ddl" self.ddl_collection.add( documents=ddl, embeddings=self.generate_embedding(ddl), @@ -65,11 +65,11 @@ def add_ddl(self, ddl: str, **kwargs) -> str: ) return id - def add_documentation(self, doc: str, **kwargs) -> str: - id = str(uuid.uuid4())+"-doc" + def add_documentation(self, documentation: str, **kwargs) -> str: + id = str(uuid.uuid4()) + "-doc" self.documentation_collection.add( - documents=doc, - embeddings=self.generate_embedding(doc), + documents=documentation, + embeddings=self.generate_embedding(documentation), ids=id, ) return id @@ -81,15 +81,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: if sql_data is not None: # Extract the documents and ids - documents = [json.loads(doc) for doc in sql_data['documents']] - ids = sql_data['ids'] + documents = [json.loads(doc) for doc in sql_data["documents"]] + ids = sql_data["ids"] # Create a DataFrame - df_sql = pd.DataFrame({ - 'id': ids, - 'question': [doc['question'] for doc in documents], - 'content': [doc['sql'] for doc in documents] - }) + df_sql = pd.DataFrame( + { + "id": ids, + "question": [doc["question"] for doc in documents], + "content": [doc["sql"] for doc in documents], + } + ) df_sql["training_data_type"] = "sql" @@ -99,15 +101,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: if ddl_data is not None: # Extract the documents and ids - documents = [doc for doc in ddl_data['documents']] - ids = ddl_data['ids'] + documents = [doc for doc in ddl_data["documents"]] + ids = ddl_data["ids"] # Create a DataFrame - df_ddl = pd.DataFrame({ - 'id': ids, - 'question': [None for doc in documents], - 'content': [doc for doc in documents] - }) + df_ddl = pd.DataFrame( + { + "id": ids, + "question": [None for doc in documents], + "content": [doc for doc in documents], + } + ) df_ddl["training_data_type"] = "ddl" @@ -117,15 +121,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: if doc_data is not None: # Extract the documents and ids - documents = [doc for doc in doc_data['documents']] - ids = doc_data['ids'] + documents = [doc for doc in doc_data["documents"]] + ids = doc_data["ids"] # Create a DataFrame - df_doc = pd.DataFrame({ - 'id': ids, - 'question': [None for doc in documents], - 'content': [doc for doc in documents] - }) + df_doc = pd.DataFrame( + { + "id": ids, + "question": [None for doc in documents], + "content": [doc for doc in documents], + } + ) df_doc["training_data_type"] = "documentation" diff --git a/src/vanna/marqo/marqo.py b/src/vanna/marqo/marqo.py index 168745c7..2dd50c3e 100644 --- a/src/vanna/marqo/marqo.py +++ b/src/vanna/marqo/marqo.py @@ -3,7 +3,6 @@ from abc import abstractmethod import marqo - import pandas as pd from ..base import VannaBase @@ -12,7 +11,7 @@ class Marqo_VectorStore(VannaBase): def __init__(self, config=None): VannaBase.__init__(self, config=config) - + if config is not None and "marqo_url" in config: marqo_url = config["marqo_url"] else: @@ -22,7 +21,7 @@ def __init__(self, config=None): marqo_model = config["marqo_model"] else: marqo_model = "hf/all_datasets_v4_MiniLM-L6" - + self.mq = marqo.Client(url=marqo_url) for index in ["vanna-sql", "vanna-ddl", "vanna-doc"]: @@ -33,18 +32,17 @@ def __init__(self, config=None): print(f"Marqo index {index} already exists") pass - def generate_embedding(self, data: str, **kwargs) -> list[float]: # Marqo doesn't need to generate embeddings - pass + pass def add_question_sql(self, question: str, sql: str, **kwargs) -> str: - id = str(uuid.uuid4())+"-sql" - question_sql_dict ={ - "question": question, - "sql": sql, - "_id": id, - } + id = str(uuid.uuid4()) + "-sql" + question_sql_dict = { + "question": question, + "sql": sql, + "_id": id, + } self.mq.index("vanna-sql").add_documents( [question_sql_dict], @@ -54,11 +52,11 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: return id def add_ddl(self, ddl: str, **kwargs) -> str: - id = str(uuid.uuid4())+"-ddl" - ddl_dict ={ - "ddl": ddl, - "_id": id, - } + id = str(uuid.uuid4()) + "-ddl" + ddl_dict = { + "ddl": ddl, + "_id": id, + } self.mq.index("vanna-ddl").add_documents( [ddl_dict], @@ -66,13 +64,13 @@ def add_ddl(self, ddl: str, **kwargs) -> str: ) return id - def add_documentation(self, doc: str, **kwargs) -> str: - id = str(uuid.uuid4())+"-doc" - doc_dict ={ - "doc": doc, - "_id": id, - } - + def add_documentation(self, documentation: str, **kwargs) -> str: + id = str(uuid.uuid4()) + "-doc" + doc_dict = { + "doc": documentation, + "_id": id, + } + self.mq.index("vanna-doc").add_documents( [doc_dict], tensor_fields=["doc"], @@ -80,31 +78,37 @@ def add_documentation(self, doc: str, **kwargs) -> str: return id def get_training_data(self, **kwargs) -> pd.DataFrame: - data = [] - - for hit in self.mq.index('vanna-doc').search("", limit=1000)['hits']: - data.append({ - "id": hit["_id"], - "training_data_type": "documentation", - "question": "", - "content": hit["doc"], - }) - - for hit in self.mq.index('vanna-ddl').search("", limit=1000)['hits']: - data.append({ - "id": hit["_id"], - "training_data_type": "ddl", - "question": "", - "content": hit["ddl"], - }) - - for hit in self.mq.index('vanna-sql').search("", limit=1000)['hits']: - data.append({ - "id": hit["_id"], - "training_data_type": "sql", - "question": hit["question"], - "content": hit["sql"], - }) + data = [] + + for hit in self.mq.index("vanna-doc").search("", limit=1000)["hits"]: + data.append( + { + "id": hit["_id"], + "training_data_type": "documentation", + "question": "", + "content": hit["doc"], + } + ) + + for hit in self.mq.index("vanna-ddl").search("", limit=1000)["hits"]: + data.append( + { + "id": hit["_id"], + "training_data_type": "ddl", + "question": "", + "content": hit["ddl"], + } + ) + + for hit in self.mq.index("vanna-sql").search("", limit=1000)["hits"]: + data.append( + { + "id": hit["_id"], + "training_data_type": "sql", + "question": hit["question"], + "content": hit["sql"], + } + ) df = pd.DataFrame(data) @@ -127,24 +131,24 @@ def remove_training_data(self, id: str, **kwargs) -> bool: @staticmethod def _extract_documents(data) -> list: # Check if 'hits' key is in the dictionary and if it's a list - if 'hits' in data and isinstance(data['hits'], list): + if "hits" in data and isinstance(data["hits"], list): # Iterate over each item in 'hits' - if len(data['hits']) == 0: + if len(data["hits"]) == 0: return [] # If there is a "doc" key, return the value of that key - if "doc" in data['hits'][0]: - return [hit["doc"] for hit in data['hits']] - + if "doc" in data["hits"][0]: + return [hit["doc"] for hit in data["hits"]] + # If there is a "ddl" key, return the value of that key - if "ddl" in data['hits'][0]: - return [hit["ddl"] for hit in data['hits']] - + if "ddl" in data["hits"][0]: + return [hit["ddl"] for hit in data["hits"]] + # Otherwise return the entire hit return [ - {key: value for key, value in hit.items() if not key.startswith('_')} - for hit in data['hits'] + {key: value for key, value in hit.items() if not key.startswith("_")} + for hit in data["hits"] ] else: # Return an empty list if 'hits' is not found or not a list @@ -152,15 +156,15 @@ def _extract_documents(data) -> list: def get_similar_question_sql(self, question: str, **kwargs) -> list: return Marqo_VectorStore._extract_documents( - self.mq.index('vanna-sql').search(question) + self.mq.index("vanna-sql").search(question) ) def get_related_ddl(self, question: str, **kwargs) -> list: return Marqo_VectorStore._extract_documents( - self.mq.index('vanna-ddl').search(question) + self.mq.index("vanna-ddl").search(question) ) def get_related_documentation(self, question: str, **kwargs) -> list: return Marqo_VectorStore._extract_documents( - self.mq.index('vanna-doc').search(question) + self.mq.index("vanna-doc").search(question) ) diff --git a/src/vanna/vannadb/vannadb_vector.py b/src/vanna/vannadb/vannadb_vector.py index 6378458b..095220ef 100644 --- a/src/vanna/vannadb/vannadb_vector.py +++ b/src/vanna/vannadb/vannadb_vector.py @@ -1,18 +1,21 @@ +import dataclasses +import json +from io import StringIO + +import pandas as pd +import requests + from ..base import VannaBase from ..types import ( - QuestionSQLPair, - StatusWithId, - StringData, DataFrameJSON, + Question, + QuestionSQLPair, Status, + StatusWithId, + StringData, TrainingData, - Question, ) -from io import StringIO -import pandas as pd -import requests -import json -import dataclasses + class VannaDB_VectorStore(VannaBase): def __init__(self, vanna_model: str, vanna_api_key: str, config=None): @@ -105,8 +108,8 @@ def add_ddl(self, ddl: str, **kwargs) -> str: return status.id - def add_documentation(self, doc: str, **kwargs) -> str: - params = [StringData(data=doc)] + def add_documentation(self, documentation: str, **kwargs) -> str: + params = [StringData(data=documentation)] d = self._rpc_call(method="add_documentation", params=params) @@ -167,7 +170,7 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list: training_data = self.related_training_data[question] else: training_data = self.get_related_training_data_cached(question) - + return training_data.questions def get_related_ddl(self, question: str, **kwargs) -> list: @@ -175,7 +178,7 @@ def get_related_ddl(self, question: str, **kwargs) -> list: training_data = self.related_training_data[question] else: training_data = self.get_related_training_data_cached(question) - + return training_data.ddl def get_related_documentation(self, question: str, **kwargs) -> list: @@ -183,5 +186,5 @@ def get_related_documentation(self, question: str, **kwargs) -> list: training_data = self.related_training_data[question] else: training_data = self.get_related_training_data_cached(question) - - return training_data.documentation \ No newline at end of file + + return training_data.documentation