diff --git a/src/vanna/pgvector/pgvector.py b/src/vanna/pgvector/pgvector.py index cf0c2a23..3cddeb46 100644 --- a/src/vanna/pgvector/pgvector.py +++ b/src/vanna/pgvector/pgvector.py @@ -28,20 +28,20 @@ def __init__(self, config=None): if config and "embedding_function" in config: self.embedding_function = config.get("embedding_function") else: - from sentence_transformers import SentenceTransformer - self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-l6-v2") + from langchain_huggingface import HuggingFaceEmbeddings + self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") - self.sql_vectorstore = PGVector( + self.sql_collection = PGVector( embeddings=self.embedding_function, collection_name="sql", connection=self.connection_string, ) - self.ddl_vectorstore = PGVector( + self.ddl_collection = PGVector( embeddings=self.embedding_function, collection_name="ddl", connection=self.connection_string, ) - self.documentation_vectorstore = PGVector( + self.documentation_collection = PGVector( embeddings=self.embedding_function, collection_name="documentation", connection=self.connection_string, @@ -94,16 +94,16 @@ def get_collection(self, collection_name): case _: raise ValueError("Specified collection does not exist.") - async def get_similar_question_sql(self, question: str) -> list: + def get_similar_question_sql(self, question: str) -> list: documents = self.sql_collection.similarity_search(query=question, k=self.n_results) return [ast.literal_eval(document.page_content) for document in documents] - async def get_related_ddl(self, question: str, **kwargs) -> list: - documents = await self.ddl_collection.similarity_search(query=question, k=self.n_results) + def get_related_ddl(self, question: str, **kwargs) -> list: + documents = self.ddl_collection.similarity_search(query=question, k=self.n_results) return [document.page_content for document in documents] - async def get_related_documentation(self, question: str, **kwargs) -> list: - documents = await self.documentation_collection.similarity_search(query=question, k=self.n_results) + def get_related_documentation(self, question: str, **kwargs) -> list: + documents = self.documentation_collection.similarity_search(query=question, k=self.n_results) return [document.page_content for document in documents] def train( @@ -251,15 +251,3 @@ def remove_collection(self, collection_name: str) -> bool: def generate_embedding(self, *args, **kwargs): pass - - def submit_prompt(self, *args, **kwargs): - pass - - def system_message(self, message: str) -> any: - return {"role": "system", "content": message} - - def user_message(self, message: str) -> any: - return {"role": "user", "content": message} - - def assistant_message(self, message: str) -> any: - return {"role": "assistant", "content": message} diff --git a/tests/test_pgvector.py b/tests/test_pgvector.py index 8c9344a5..2d82f273 100644 --- a/tests/test_pgvector.py +++ b/tests/test_pgvector.py @@ -3,22 +3,47 @@ from dotenv import load_dotenv # from vanna.pgvector import PG_VectorStore +# from vanna.openai import OpenAI_Chat +# assume .env file placed next to file with provided env vars load_dotenv() -# Removing thiese tests for now until the dependencies are sorted out # def get_vanna_connection_string(): # server = os.environ.get("PG_SERVER") # driver = "psycopg" -# port = 5434 +# port = os.environ.get("PG_PORT", 5432) # database = os.environ.get("PG_DATABASE") # username = os.environ.get("PG_USERNAME") # password = os.environ.get("PG_PASSWORD") -# return f"postgresql+psycopg://{username}:{password}@{server}:{port}/{database}" +# def test_pgvector_e2e(): +# # configure Vanna to use OpenAI and PGVector +# class VannaCustom(PG_VectorStore, OpenAI_Chat): +# def __init__(self, config=None): +# PG_VectorStore.__init__(self, config=config) +# OpenAI_Chat.__init__(self, config=config) + +# vn = VannaCustom(config={ +# 'api_key': os.environ['OPENAI_API_KEY'], +# 'model': 'gpt-3.5-turbo', +# "connection_string": get_vanna_connection_string(), +# }) +# # connect to SQLite database +# vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +# # train Vanna on DDLs +# df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") +# for ddl in df_ddl['sql'].to_list(): +# vn.train(ddl=ddl) +# assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default + +# question = "What are the top 7 customers by sales?" +# sql = vn.generate_sql(question) +# df = vn.run_sql(sql) +# assert len(df) == 7 + +# # test if Vanna can generate an answer +# answer = vn.ask(question) +# assert answer is not None -# def test_pgvector(): -# connection_string = get_vanna_connection_string() -# pgclient = PG_VectorStore(config={"connection_string": connection_string}) -# assert pgclient is not None