feat: add temp dataset parameter to BQ and FS vectorstores
eliasecchig committed Oct 2, 2024
1 parent 88ffe40 commit faeb4c1
Showing 4 changed files with 16 additions and 22 deletions.
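As context for the diff below: the new temp_dataset_name field is added on BaseBigQueryVectorStore, so it is shared by both the BigQuery and Feature Store vector stores, and it controls which BigQuery dataset holds the temporary upload tables. A minimal usage sketch follows; the import paths, the embedding model, and every "my_*" value are assumptions for illustration, not part of this commit.

    # Sketch only: imports and values below are assumed, not taken from this commit.
    from langchain_google_community import BigQueryVectorStore
    from langchain_google_vertexai import VertexAIEmbeddings

    embedding_model = VertexAIEmbeddings(model_name="text-embedding-004")  # assumed model name

    store = BigQueryVectorStore(
        project_id="my-project",              # hypothetical GCP project
        location="us-central1",
        embedding=embedding_model,
        dataset_name="my_dataset",            # hypothetical main dataset
        temp_dataset_name="my_temp_dataset",  # new parameter added by this commit
        table_name="my_table",                # hypothetical table
    )
    store.add_texts(["apple", "banana"])

If temp_dataset_name is omitted, the validator falls back to the previous behavior and derives the name from dataset_name, as shown in the first changed file below.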
Changed file 1 of 4
@@ -68,6 +68,7 @@ class BaseBigQueryVectorStore(VectorStore, BaseModel, ABC):
     content_field: str = "content"
     embedding_field: str = "embedding"
     doc_id_field: str = "doc_id"
+    temp_dataset_name: Optional[str] = None
     credentials: Optional[Any] = None
     embedding_dimension: Optional[int] = None
     extra_fields: Union[Dict[str, str], None] = None
@@ -138,17 +139,20 @@ def validate_vals(self) -> Self:
         )
         if self.embedding_dimension is None:
             self.embedding_dimension = len(self.embedding.embed_query("test"))
+        if self.temp_dataset_name is None:
+            self.temp_dataset_name = f"{self.dataset_name}_temp"
         full_table_id = f"{self.project_id}.{self.dataset_name}.{self.table_name}"
         self._full_table_id = full_table_id
-        temp_dataset_id = f"{self.dataset_name}_temp"
         if not check_bq_dataset_exists(
             client=self._bq_client, dataset_id=self.dataset_name
         ):
             self._bq_client.create_dataset(dataset=self.dataset_name, exists_ok=True)
         if not check_bq_dataset_exists(
-            client=self._bq_client, dataset_id=temp_dataset_id
+            client=self._bq_client, dataset_id=self.temp_dataset_name
         ):
-            self._bq_client.create_dataset(dataset=temp_dataset_id, exists_ok=True)
+            self._bq_client.create_dataset(
+                dataset=self.temp_dataset_name, exists_ok=True
+            )
         table_ref = bigquery.TableReference.from_string(full_table_id)
         self._bq_client.create_table(table_ref, exists_ok=True)
         self._logger.info(
@@ -235,18 +239,6 @@ def _validate_bq_table(self) -> Any:
         self._logger.debug(f"Table {self.full_table_id} validated")
         return table_ref
 
-    def _initialize_bq_table(self) -> Any:
-        """Validates or creates the BigQuery table."""
-        from google.cloud import bigquery  # type: ignore[attr-defined]
-
-        self._bq_client.create_dataset(dataset=self.dataset_name, exists_ok=True)
-        self._bq_client.create_dataset(
-            dataset=f"{self.dataset_name}_temp", exists_ok=True
-        )
-        table_ref = bigquery.TableReference.from_string(self.full_table_id)
-        self._bq_client.create_table(table_ref, exists_ok=True)
-        return table_ref
-
     def add_texts(  # type: ignore[override]
         self,
         texts: List[str],
Changed file 2 of 4
@@ -393,7 +393,7 @@ def _create_temp_bq_table(
         df[self.embedding_field] = embeddings
         table_id = (
             f"{self.project_id}."
-            f"{self.dataset_name}_temp."
+            f"{self.temp_dataset_name}."
             f"{self.table_name}_{uuid.uuid4().hex}"
         )
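Putting the two hunks above together, the naming now resolves as in this small sketch; all concrete values are assumed examples, and only the string formatting mirrors the code in this diff.

    # Illustration of the naming logic, with assumed example values.
    import uuid

    project_id = "my-project"      # hypothetical
    dataset_name = "my_dataset"    # hypothetical
    table_name = "docs"            # hypothetical
    temp_dataset_name = None       # parameter left unset

    # Default applied in validate_vals: fall back to "<dataset_name>_temp".
    if temp_dataset_name is None:
        temp_dataset_name = f"{dataset_name}_temp"  # -> "my_dataset_temp"

    # Temporary upload table id built in _create_temp_bq_table.
    temp_table_id = f"{project_id}.{temp_dataset_name}.{table_name}_{uuid.uuid4().hex}"
    print(temp_table_id)  # e.g. my-project.my_dataset_temp.docs_<32 hex chars>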
Changed file 3 of 4
@@ -12,6 +12,7 @@
 from tests.integration_tests.fake import FakeEmbeddings
 
 TEST_DATASET = "langchain_test_dataset"
+TEST_TEMP_DATASET = "temp_langchain_test_dataset"
 TEST_TABLE_NAME = f"langchain_test_table{str(random.randint(1,100000))}"
 TEST_FOS_NAME = "langchain_test_fos"
 EMBEDDING_SIZE = 768
@@ -34,7 +35,8 @@ def store_bq_vectorstore(request: pytest.FixtureRequest) -> BigQueryVectorStore:
         project_id=os.environ.get("PROJECT_ID", None),  # type: ignore[arg-type]
         embedding=embedding_model,
         location="us-central1",
-        dataset_name=TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+        dataset_name=TEST_DATASET,
+        temp_dataset_name=TEST_TEMP_DATASET,
         table_name=TEST_TABLE_NAME,
     )
     TestBigQueryVectorStore_bq_vectorstore.store_bq_vectorstore.add_texts(
@@ -44,7 +46,7 @@ def store_bq_vectorstore(request: pytest.FixtureRequest) -> BigQueryVectorStore:
 
     def teardown() -> None:
         bigquery.Client(location="us-central1").delete_dataset(
-            TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+            TEST_DATASET,
             delete_contents=True,
             not_found_ok=True,
         )
@@ -73,14 +75,15 @@ def existing_store_bq_vectorstore(
             project_id=os.environ.get("PROJECT_ID", None),  # type: ignore[arg-type]
             embedding=embedding_model,
             location="us-central1",
-            dataset_name=TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+            dataset_name=TEST_DATASET,
+            temp_dataset_name=TEST_TEMP_DATASET,
             table_name=TEST_TABLE_NAME,
         )
     )
 
     def teardown() -> None:
         bigquery.Client(location="us-central1").delete_dataset(
-            TestBigQueryVectorStore_bq_vectorstore.dataset_name,
+            TEST_DATASET,
             delete_contents=True,
             not_found_ok=True,
         )
@@ -92,7 +95,6 @@ def teardown() -> None:
 class TestBigQueryVectorStore_bq_vectorstore:
     """BigQueryVectorStore tests class."""
 
-    dataset_name = TEST_DATASET
     store_bq_vectorstore: BigQueryVectorStore
     existing_store_bq_vectorstore: BigQueryVectorStore
     texts = ["apple", "ice cream", "Saturn", "candy", "banana"]
Changed file 4 of 4
@@ -13,7 +13,7 @@
 # Feature Online store is static to avoid cold start setup time during testing
 TEST_DATASET = "langchain_test_dataset"
 TEST_TABLE_NAME = f"langchain_test_table{str(random.randint(1,100000))}"
-TEST_FOS_NAME = "langchain_test_fos"
+TEST_FOS_NAME = "langchain_test_fos2"
 TEST_VIEW_NAME = f"test{str(random.randint(1,100000))}"
 EMBEDDING_SIZE = 768
