From f2a3b4a1e56e7e63396a9aa5b5eced7af5da0f0a Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 2 Oct 2024 12:00:53 +0100
Subject: [PATCH] feat: add temp dataset parameter to BQ and FS vectorstores

---
 .../bq_storage_vectorstores/_base.py          | 24 +++++++------------
 .../bq_storage_vectorstores/bigquery.py       |  2 +-
 .../test_feature_store_bq_vectorstore.py      | 12 ++++++----
 .../test_feature_store_fs_vectorstore.py      |  2 +-
 4 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/libs/community/langchain_google_community/bq_storage_vectorstores/_base.py b/libs/community/langchain_google_community/bq_storage_vectorstores/_base.py
index d72bad29..cdb5d4ef 100644
--- a/libs/community/langchain_google_community/bq_storage_vectorstores/_base.py
+++ b/libs/community/langchain_google_community/bq_storage_vectorstores/_base.py
@@ -47,6 +47,8 @@ class BaseBigQueryVectorStore(VectorStore, BaseModel, ABC):
         content_field: Name of the column storing document content (default: "content").
         embedding_field: Name of the column storing text embeddings (default:
             "embedding").
+        temp_dataset_name: Name of the BigQuery dataset to be used to upload temporary
+            BQ tables. If None, will default to "{dataset_name}_temp".
         doc_id_field: Name of the column storing document IDs (default: "doc_id").
         credentials: Optional Google Cloud credentials object.
         embedding_dimension: Dimension of the embedding vectors (inferred if not
@@ -68,6 +70,7 @@ class BaseBigQueryVectorStore(VectorStore, BaseModel, ABC):
     content_field: str = "content"
     embedding_field: str = "embedding"
     doc_id_field: str = "doc_id"
+    temp_dataset_name: Optional[str] = None
     credentials: Optional[Any] = None
     embedding_dimension: Optional[int] = None
     extra_fields: Union[Dict[str, str], None] = None
@@ -138,17 +141,20 @@ def validate_vals(self) -> Self:
         )
         if self.embedding_dimension is None:
             self.embedding_dimension = len(self.embedding.embed_query("test"))
+        if self.temp_dataset_name is None:
+            self.temp_dataset_name = f"{self.dataset_name}_temp"
         full_table_id = f"{self.project_id}.{self.dataset_name}.{self.table_name}"
         self._full_table_id = full_table_id
-        temp_dataset_id = f"{self.dataset_name}_temp"
         if not check_bq_dataset_exists(
             client=self._bq_client, dataset_id=self.dataset_name
         ):
             self._bq_client.create_dataset(dataset=self.dataset_name, exists_ok=True)
         if not check_bq_dataset_exists(
-            client=self._bq_client, dataset_id=temp_dataset_id
+            client=self._bq_client, dataset_id=self.temp_dataset_name
         ):
-            self._bq_client.create_dataset(dataset=temp_dataset_id, exists_ok=True)
+            self._bq_client.create_dataset(
+                dataset=self.temp_dataset_name, exists_ok=True
+            )
         table_ref = bigquery.TableReference.from_string(full_table_id)
         self._bq_client.create_table(table_ref, exists_ok=True)
         self._logger.info(
@@ -235,18 +241,6 @@ def _validate_bq_table(self) -> Any:
         self._logger.debug(f"Table {self.full_table_id} validated")
         return table_ref
 
-    def _initialize_bq_table(self) -> Any:
-        """Validates or creates the BigQuery table."""
-        from google.cloud import bigquery  # type: ignore[attr-defined]
-
-        self._bq_client.create_dataset(dataset=self.dataset_name, exists_ok=True)
-        self._bq_client.create_dataset(
-            dataset=f"{self.dataset_name}_temp", exists_ok=True
-        )
-        table_ref = bigquery.TableReference.from_string(self.full_table_id)
-        self._bq_client.create_table(table_ref, exists_ok=True)
-        return table_ref
-
     def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
diff --git a/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py b/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py
index af2fa62f..15e19436 100644
--- a/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py
+++ b/libs/community/langchain_google_community/bq_storage_vectorstores/bigquery.py
@@ -393,7 +393,7 @@ def _create_temp_bq_table(
         df[self.embedding_field] = embeddings
         table_id = (
             f"{self.project_id}."
-            f"{self.dataset_name}_temp."
+            f"{self.temp_dataset_name}."
             f"{self.table_name}_{uuid.uuid4().hex}"
         )
 
diff --git a/libs/community/tests/integration_tests/feature_store/test_feature_store_bq_vectorstore.py b/libs/community/tests/integration_tests/feature_store/test_feature_store_bq_vectorstore.py
index cc4c0637..1b22de77 100644
--- a/libs/community/tests/integration_tests/feature_store/test_feature_store_bq_vectorstore.py
+++ b/libs/community/tests/integration_tests/feature_store/test_feature_store_bq_vectorstore.py
@@ -12,6 +12,7 @@
 from tests.integration_tests.fake import FakeEmbeddings
 
 TEST_DATASET = "langchain_test_dataset"
+TEST_TEMP_DATASET = "temp_langchain_test_dataset"
 TEST_TABLE_NAME = f"langchain_test_table{str(random.randint(1,100000))}"
 TEST_FOS_NAME = "langchain_test_fos"
 EMBEDDING_SIZE = 768
@@ -34,7 +35,8 @@ def store_bq_vectorstore(request: pytest.FixtureRequest) -> BigQueryVectorStore:
         project_id=os.environ.get("PROJECT_ID", None),  # type: ignore[arg-type]
         embedding=embedding_model,
         location="us-central1",
-        dataset_name=TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+        dataset_name=TEST_DATASET,
+        temp_dataset_name=TEST_TEMP_DATASET,
         table_name=TEST_TABLE_NAME,
     )
     TestBigQueryVectorStore_bq_vectorstore.store_bq_vectorstore.add_texts(
@@ -44,7 +46,7 @@ def store_bq_vectorstore(request: pytest.FixtureRequest) -> BigQueryVectorStore:
 
     def teardown() -> None:
         bigquery.Client(location="us-central1").delete_dataset(
-            TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+            TEST_DATASET,
             delete_contents=True,
             not_found_ok=True,
         )
@@ -73,14 +75,15 @@ def existing_store_bq_vectorstore(
             project_id=os.environ.get("PROJECT_ID", None),  # type: ignore[arg-type]
             embedding=embedding_model,
             location="us-central1",
-            dataset_name=TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+            dataset_name=TEST_DATASET,
+            temp_dataset_name=TEST_TEMP_DATASET,
             table_name=TEST_TABLE_NAME,
         )
     )
 
     def teardown() -> None:
         bigquery.Client(location="us-central1").delete_dataset(
-            TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+            TEST_DATASET,
             delete_contents=True,
             not_found_ok=True,
         )
@@ -92,7 +95,6 @@ def teardown() -> None:
 class TestBigQueryVectorStore_bq_vectorstore:
     """BigQueryVectorStore tests class."""
 
-    dataset_name = TEST_DATASET
     store_bq_vectorstore: BigQueryVectorStore
     existing_store_bq_vectorstore: BigQueryVectorStore
     texts = ["apple", "ice cream", "Saturn", "candy", "banana"]
diff --git a/libs/community/tests/integration_tests/feature_store/test_feature_store_fs_vectorstore.py b/libs/community/tests/integration_tests/feature_store/test_feature_store_fs_vectorstore.py
index b85b6957..2d566c91 100644
--- a/libs/community/tests/integration_tests/feature_store/test_feature_store_fs_vectorstore.py
+++ b/libs/community/tests/integration_tests/feature_store/test_feature_store_fs_vectorstore.py
@@ -13,7 +13,7 @@
 # Feature Online store is static to avoid cold start setup time during testing
 TEST_DATASET = "langchain_test_dataset"
 TEST_TABLE_NAME = f"langchain_test_table{str(random.randint(1,100000))}"
-TEST_FOS_NAME = "langchain_test_fos"
+TEST_FOS_NAME = "langchain_test_fos2"
 TEST_VIEW_NAME = f"test{str(random.randint(1,100000))}"
 EMBEDDING_SIZE = 768