Skip to content

Commit

Permalink
Merge pull request #156 from vanna-ai/fix-doc-vs-documentation
Browse files (browse the repository at this point in the history)
Make documentation argument name consistent
  • Loading branch information
zainhoda authored Jan 17, 2024
2 parents a0c9320 + 6ce8951 commit 0684b5a
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 118 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.0.32"
version = "0.0.33"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand Down
27 changes: 14 additions & 13 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import os
import re
import sqlite3
import traceback

from abc import ABC, abstractmethod
from typing import List, Tuple, Union
from urllib.parse import urlparse
Expand All @@ -12,7 +12,6 @@
import plotly.express as px
import plotly.graph_objects as go
import requests
import re

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
Expand Down Expand Up @@ -50,8 +49,8 @@ def generate_followup_questions(self, question: str, **kwargs) -> str:
**kwargs,
)
llm_response = self.submit_prompt(prompt, **kwargs)
numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE)

numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
return numbers_removed.split("\n")

def generate_questions(self, **kwargs) -> list[str]:
Expand All @@ -65,7 +64,7 @@ def generate_questions(self, **kwargs) -> list[str]:
"""
question_sql = self.get_similar_question_sql(question="", **kwargs)

return [q['question'] for q in question_sql]
return [q["question"] for q in question_sql]

# ----------------- Use Any Embeddings API ----------------- #
@abstractmethod
Expand Down Expand Up @@ -94,7 +93,7 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
pass

@abstractmethod
def add_documentation(self, doc: str, **kwargs) -> str:
def add_documentation(self, documentation: str, **kwargs) -> str:
pass

@abstractmethod
Expand All @@ -120,12 +119,12 @@ def get_sql_prompt(

@abstractmethod
def get_followup_questions_prompt(
self,
question: str,
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs
doc_list: list,
**kwargs,
):
pass

Expand Down Expand Up @@ -829,9 +828,11 @@ def get_plotly_figure(
fig = ldict.get("fig", None)
except Exception as e:
# Inspect data types
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
categorical_cols = df.select_dtypes(
include=["object", "category"]
).columns.tolist()

# Decision-making for plot type
if len(numeric_cols) >= 2:
# Use the first two numeric columns for a scatter plot
Expand Down
62 changes: 34 additions & 28 deletions src/vanna/chromadb/chromadb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from abc import abstractmethod

import chromadb
import pandas as pd
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import pandas as pd

from ..base import VannaBase

Expand Down Expand Up @@ -47,7 +47,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
"sql": sql,
}
)
id = str(uuid.uuid4())+"-sql"
id = str(uuid.uuid4()) + "-sql"
self.sql_collection.add(
documents=question_sql_json,
embeddings=self.generate_embedding(question_sql_json),
Expand All @@ -57,19 +57,19 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
return id

def add_ddl(self, ddl: str, **kwargs) -> str:
id = str(uuid.uuid4())+"-ddl"
id = str(uuid.uuid4()) + "-ddl"
self.ddl_collection.add(
documents=ddl,
embeddings=self.generate_embedding(ddl),
ids=id,
)
return id

def add_documentation(self, doc: str, **kwargs) -> str:
id = str(uuid.uuid4())+"-doc"
def add_documentation(self, documentation: str, **kwargs) -> str:
id = str(uuid.uuid4()) + "-doc"
self.documentation_collection.add(
documents=doc,
embeddings=self.generate_embedding(doc),
documents=documentation,
embeddings=self.generate_embedding(documentation),
ids=id,
)
return id
Expand All @@ -81,15 +81,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

if sql_data is not None:
# Extract the documents and ids
documents = [json.loads(doc) for doc in sql_data['documents']]
ids = sql_data['ids']
documents = [json.loads(doc) for doc in sql_data["documents"]]
ids = sql_data["ids"]

# Create a DataFrame
df_sql = pd.DataFrame({
'id': ids,
'question': [doc['question'] for doc in documents],
'content': [doc['sql'] for doc in documents]
})
df_sql = pd.DataFrame(
{
"id": ids,
"question": [doc["question"] for doc in documents],
"content": [doc["sql"] for doc in documents],
}
)

df_sql["training_data_type"] = "sql"

Expand All @@ -99,15 +101,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

if ddl_data is not None:
# Extract the documents and ids
documents = [doc for doc in ddl_data['documents']]
ids = ddl_data['ids']
documents = [doc for doc in ddl_data["documents"]]
ids = ddl_data["ids"]

# Create a DataFrame
df_ddl = pd.DataFrame({
'id': ids,
'question': [None for doc in documents],
'content': [doc for doc in documents]
})
df_ddl = pd.DataFrame(
{
"id": ids,
"question": [None for doc in documents],
"content": [doc for doc in documents],
}
)

df_ddl["training_data_type"] = "ddl"

Expand All @@ -117,15 +121,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

if doc_data is not None:
# Extract the documents and ids
documents = [doc for doc in doc_data['documents']]
ids = doc_data['ids']
documents = [doc for doc in doc_data["documents"]]
ids = doc_data["ids"]

# Create a DataFrame
df_doc = pd.DataFrame({
'id': ids,
'question': [None for doc in documents],
'content': [doc for doc in documents]
})
df_doc = pd.DataFrame(
{
"id": ids,
"question": [None for doc in documents],
"content": [doc for doc in documents],
}
)

df_doc["training_data_type"] = "documentation"

Expand Down
Loading

0 comments on commit 0684b5a

Please sign in to comment.