From be90a84284e9c5019eae04ab23ddf049b70fe24d Mon Sep 17 00:00:00 2001 From: Cloud Teleport Date: Mon, 10 Jun 2024 18:00:28 +0000 Subject: [PATCH] COPYBARA_INTEGRATE_REVIEW=https://github.com/GoogleCloudPlatform/DataflowTemplates/pull/1365 from meagar:main 9217e2107d23bb4fd76541936b56018180d457b4 PiperOrigin-RevId: 641958196 --- ...igtable_Change_Streams_to_Vector_Search.md | 318 +++++++++++ v2/googlecloud-to-googlecloud/pom.xml | 4 + ...bleChangeStreamsToVectorSearchOptions.java | 216 +++++++ .../BigtableChangeStreamsToVectorSearch.java | 271 +++++++++ ...eStreamMutationToDatapointOperationFn.java | 228 ++++++++ .../DatapointOperationFn.java | 50 ++ .../RemoveDatapointsFn.java | 55 ++ .../UpsertDatapointsFn.java | 55 ++ .../Utils.java | 88 +++ .../package-info.java | 18 + ...ams-to-vector-embeddings-command-spec.json | 7 + ...BigtableChangeStreamsToVectorSearchIT.java | 534 ++++++++++++++++++ .../UtilsTest.java | 140 +++++ .../VectorSearchResourceManager.java | 362 ++++++++++++ .../teleport/v2/templates/DataTypesIt.java | 18 + .../v2/templates/SourceDbToSpannerITBase.java | 1 + .../test/resources/DataTypesIt/data-types.sql | 48 ++ .../resources/DataTypesIt/spanner-schema.sql | 30 + .../terraform/samples/README.md | 20 + .../terraform/samples/multiple-jobs/README.md | 72 +++ .../terraform/samples/multiple-jobs/main.tf | 63 +++ .../samples/multiple-jobs/outputs.tf | 10 + .../samples/multiple-jobs/terraform.tf | 13 + .../samples/multiple-jobs/terraform.tfvars | 63 +++ .../multiple-jobs/terraform_simple.tfvars | 34 ++ .../samples/multiple-jobs/variables.tf | 49 ++ .../SpannerChangeStreamToGcsMultiShardIT.java | 81 +++ .../avro/GenericRecordTypeConvertor.java | 4 + .../avro/GenericRecordTypeConvertorTest.java | 13 + 29 files changed, 2865 insertions(+) create mode 100644 v2/googlecloud-to-googlecloud/README_Bigtable_Change_Streams_to_Vector_Search.md create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/options/BigtableChangeStreamsToVectorSearchOptions.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearch.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/ChangeStreamMutationToDatapointOperationFn.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/DatapointOperationFn.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/RemoveDatapointsFn.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UpsertDatapointsFn.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/Utils.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/package-info.java create mode 100644 v2/googlecloud-to-googlecloud/src/main/resources/bigtable-changestreams-to-vector-embeddings-command-spec.json create mode 100644 v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearchIT.java create mode 100644 
v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UtilsTest.java create mode 100644 v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/VectorSearchResourceManager.java create mode 100644 v2/sourcedb-to-spanner/terraform/samples/README.md create mode 100644 v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/README.md create mode 100644 v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/main.tf create mode 100644 v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/outputs.tf create mode 100644 v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tf create mode 100644 v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tfvars create mode 100644 v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform_simple.tfvars create mode 100644 v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/variables.tf diff --git a/v2/googlecloud-to-googlecloud/README_Bigtable_Change_Streams_to_Vector_Search.md b/v2/googlecloud-to-googlecloud/README_Bigtable_Change_Streams_to_Vector_Search.md new file mode 100644 index 0000000000..1e4a899042 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/README_Bigtable_Change_Streams_to_Vector_Search.md @@ -0,0 +1,318 @@ + +Bigtable Change Streams to Vector Search template +--- +Streaming pipeline. Streams Bigtable data change records and writes them into +Vertex AI Vector Search using Dataflow Runner V2. + + +:memo: This is a Google-provided template! Please +check [Provided templates documentation](https://cloud.google.com/dataflow/docs/guides/templates/provided/bigtable-change-streams-to-vector-search) +on how to use it without having to build from sources using [Create job from template](https://console.cloud.google.com/dataflow/createjob?template=Bigtable_Change_Streams_to_Vector_Search). + +:bulb: This is a generated documentation based +on [Metadata Annotations](https://github.com/GoogleCloudPlatform/DataflowTemplates#metadata-annotations) +. Do not change this file directly. + +## Parameters + +### Required parameters + +* **embeddingColumn** : The fully qualified column name where the embeddings are stored. In the format cf:col. +* **embeddingByteSize** : The byte size of each entry in the embeddings array. Use 4 for Float, and 8 for Double. Defaults to: 4. +* **vectorSearchIndex** : The Vector Search Index where changes will be streamed, in the format 'projects/{projectID}/locations/{region}/indexes/{indexID}' (no leading or trailing spaces) (Example: projects/123/locations/us-east1/indexes/456). +* **bigtableChangeStreamAppProfile** : The application profile is used to distinguish workload in Cloud Bigtable. +* **bigtableReadInstanceId** : The ID of the Cloud Bigtable instance that contains the table. +* **bigtableReadTableId** : The Cloud Bigtable table to read from. + +### Optional parameters + +* **bigtableMetadataTableTableId** : Table ID used for creating the metadata table. +* **crowdingTagColumn** : The fully qualified column name where the crowding tag is stored. In the format cf:col. +* **allowRestrictsMappings** : The comma separated fully qualified column names of the columns that should be used as the `allow` restricts, with their alias. In the format cf:col->alias. +* **denyRestrictsMappings** : The comma separated fully qualified column names of the columns that should be used as the `deny` restricts, with their alias. In the format cf:col->alias. 
+* **intNumericRestrictsMappings** : The comma separated fully qualified column names of the columns that should be used as integer `numeric_restricts`, with their alias. In the format cf:col->alias.
+* **floatNumericRestrictsMappings** : The comma separated fully qualified column names of the columns that should be used as float (4 bytes) `numeric_restricts`, with their alias. In the format cf:col->alias.
+* **doubleNumericRestrictsMappings** : The comma separated fully qualified column names of the columns that should be used as double (8 bytes) `numeric_restricts`, with their alias. In the format cf:col->alias.
+* **upsertMaxBatchSize** : The maximum number of upserts to buffer before upserting the batch to the Vector Search Index. Batches will be sent when there are either upsertMaxBatchSize records ready, or when any record has been waiting for upsertMaxBufferDuration. (Example: 10).
+* **upsertMaxBufferDuration** : The maximum delay before a batch of upserts is sent to Vector Search. Batches will be sent when there are either upsertMaxBatchSize records ready, or when any record has been waiting for upsertMaxBufferDuration. Allowed formats are: Ns (for seconds, example: 5s), Nm (for minutes, example: 12m), Nh (for hours, example: 2h). (Example: 10s). Defaults to: 10s.
+* **deleteMaxBatchSize** : The maximum number of deletes to buffer before deleting the batch from the Vector Search Index. Batches will be sent when there are either deleteMaxBatchSize records ready, or when any record has been waiting for deleteMaxBufferDuration. (Example: 10).
+* **deleteMaxBufferDuration** : The maximum delay before a batch of deletes is sent to Vector Search. Batches will be sent when there are either deleteMaxBatchSize records ready, or when any record has been waiting for deleteMaxBufferDuration. Allowed formats are: Ns (for seconds, example: 5s), Nm (for minutes, example: 12m), Nh (for hours, example: 2h). (Example: 10s). Defaults to: 10s.
+* **dlqDirectory** : The path to store any unprocessed records with the reason they failed to be processed. Default is a directory under the Dataflow job's temp location. The default value is enough under most conditions.
+* **bigtableChangeStreamMetadataInstanceId** : The Cloud Bigtable instance to use for the change streams connector metadata table. Defaults to empty.
+* **bigtableChangeStreamMetadataTableTableId** : The Cloud Bigtable change streams connector metadata table ID to use. If not provided, a Cloud Bigtable change streams connector metadata table will automatically be created during the pipeline flow. Defaults to empty.
+* **bigtableChangeStreamCharset** : Bigtable change streams charset name when reading values and column qualifiers. Default is UTF-8.
+* **bigtableChangeStreamStartTimestamp** : The starting DateTime, inclusive, to use for reading change streams (https://tools.ietf.org/html/rfc3339). For example, 2022-05-05T07:59:59Z. Defaults to the timestamp when the pipeline starts.
+* **bigtableChangeStreamIgnoreColumnFamilies** : A comma-separated list of column family names whose changes won't be captured. Defaults to empty.
+* **bigtableChangeStreamIgnoreColumns** : A comma-separated list of column names whose changes won't be captured. Defaults to empty.
+* **bigtableChangeStreamName** : Allows resuming processing from the point where a previously running pipeline stopped.
+* **bigtableChangeStreamResume** : When set to true, a new pipeline will resume processing from the point at which a previously running pipeline with the same bigtableChangeStreamName stopped. If a pipeline with the given bigtableChangeStreamName never ran in the past, a new pipeline will fail to start. When set to false, a new pipeline will be started. If a pipeline with the same bigtableChangeStreamName already ran in the past for the given source, a new pipeline will fail to start. Defaults to false.
+* **bigtableReadProjectId** : Project to read Cloud Bigtable data from. The default for this parameter is the project where the Dataflow pipeline is running.
+
+
+
+## Getting Started
+
+### Requirements
+
+* Java 11
+* Maven
+* [gcloud CLI](https://cloud.google.com/sdk/gcloud), and execution of the
+  following commands:
+  * `gcloud auth login`
+  * `gcloud auth application-default login`
+
+:star2: Those dependencies are pre-installed if you use Google Cloud Shell!
+
+[![Open in Cloud Shell](http://gstatic.com/cloudssh/images/open-btn.svg)](https://console.cloud.google.com/cloudshell/editor?cloudshell_git_repo=https%3A%2F%2Fgithub.com%2FGoogleCloudPlatform%2FDataflowTemplates.git&cloudshell_open_in_editor=v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearch.java)
+
+### Templates Plugin
+
+This README provides instructions using
+the [Templates Plugin](https://github.com/GoogleCloudPlatform/DataflowTemplates#templates-plugin).
+
+### Building Template
+
+This template is a Flex Template, meaning that the pipeline code will be
+containerized and the container will be executed on Dataflow. Please
+check [Use Flex Templates](https://cloud.google.com/dataflow/docs/guides/templates/using-flex-templates)
+and [Configure Flex Templates](https://cloud.google.com/dataflow/docs/guides/templates/configuring-flex-templates)
+for more information.
+
+#### Staging the Template
+
+If the plan is to just stage the template (i.e., make it available to use) by
+the `gcloud` command or Dataflow "Create job from template" UI,
+the `-PtemplatesStage` profile should be used:
+
+```shell
+export PROJECT=
+export BUCKET_NAME=
+
+mvn clean package -PtemplatesStage \
+-DskipTests \
+-DprojectId="$PROJECT" \
+-DbucketName="$BUCKET_NAME" \
+-DstagePrefix="templates" \
+-DtemplateName="Bigtable_Change_Streams_to_Vector_Search" \
+-f v2/googlecloud-to-googlecloud
+```
+
+
+The command should build and save the template to Google Cloud, and then print
+the complete location on Cloud Storage:
+
+```
+Flex Template was staged! gs:///templates/flex/Bigtable_Change_Streams_to_Vector_Search
+```
+
+The specific path should be copied as it will be used in the following steps.
+
+#### Running the Template
+
+**Using the staged template**:
+
+You can use the path above to run the template (or share it with others for execution).
+
+To start a job with the template at any time using `gcloud`, you are going to
+need valid resources for the required parameters.
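+
+Before launching, an optional sanity check is to confirm that the template spec file
+exists at the staged path. This is a sketch that assumes `BUCKET_NAME` is still exported
+from the staging step above and that your gcloud SDK includes the `gcloud storage` commands:
+
+```shell
+# List the staged Flex Template spec file; the path mirrors TEMPLATE_SPEC_GCSPATH below
+gcloud storage ls "gs://$BUCKET_NAME/templates/flex/Bigtable_Change_Streams_to_Vector_Search"
+```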
+ +Provided that, the following command line can be used: + +```shell +export PROJECT= +export BUCKET_NAME= +export REGION=us-central1 +export TEMPLATE_SPEC_GCSPATH="gs://$BUCKET_NAME/templates/flex/Bigtable_Change_Streams_to_Vector_Search" + +### Required +export EMBEDDING_COLUMN= +export EMBEDDING_BYTE_SIZE=4 +export VECTOR_SEARCH_INDEX= +export BIGTABLE_CHANGE_STREAM_APP_PROFILE= +export BIGTABLE_READ_INSTANCE_ID= +export BIGTABLE_READ_TABLE_ID= + +### Optional +export BIGTABLE_METADATA_TABLE_TABLE_ID= +export CROWDING_TAG_COLUMN= +export ALLOW_RESTRICTS_MAPPINGS= +export DENY_RESTRICTS_MAPPINGS= +export INT_NUMERIC_RESTRICTS_MAPPINGS= +export FLOAT_NUMERIC_RESTRICTS_MAPPINGS= +export DOUBLE_NUMERIC_RESTRICTS_MAPPINGS= +export UPSERT_MAX_BATCH_SIZE= +export UPSERT_MAX_BUFFER_DURATION=10s +export DELETE_MAX_BATCH_SIZE= +export DELETE_MAX_BUFFER_DURATION=10s +export DLQ_DIRECTORY="" +export BIGTABLE_CHANGE_STREAM_METADATA_INSTANCE_ID="" +export BIGTABLE_CHANGE_STREAM_METADATA_TABLE_TABLE_ID="" +export BIGTABLE_CHANGE_STREAM_CHARSET=UTF-8 +export BIGTABLE_CHANGE_STREAM_START_TIMESTAMP="" +export BIGTABLE_CHANGE_STREAM_IGNORE_COLUMN_FAMILIES="" +export BIGTABLE_CHANGE_STREAM_IGNORE_COLUMNS="" +export BIGTABLE_CHANGE_STREAM_NAME= +export BIGTABLE_CHANGE_STREAM_RESUME=false +export BIGTABLE_READ_PROJECT_ID="" + +gcloud dataflow flex-template run "bigtable-change-streams-to-vector-search-job" \ + --project "$PROJECT" \ + --region "$REGION" \ + --template-file-gcs-location "$TEMPLATE_SPEC_GCSPATH" \ + --parameters "bigtableMetadataTableTableId=$BIGTABLE_METADATA_TABLE_TABLE_ID" \ + --parameters "embeddingColumn=$EMBEDDING_COLUMN" \ + --parameters "crowdingTagColumn=$CROWDING_TAG_COLUMN" \ + --parameters "embeddingByteSize=$EMBEDDING_BYTE_SIZE" \ + --parameters "allowRestrictsMappings=$ALLOW_RESTRICTS_MAPPINGS" \ + --parameters "denyRestrictsMappings=$DENY_RESTRICTS_MAPPINGS" \ + --parameters "intNumericRestrictsMappings=$INT_NUMERIC_RESTRICTS_MAPPINGS" \ + --parameters "floatNumericRestrictsMappings=$FLOAT_NUMERIC_RESTRICTS_MAPPINGS" \ + --parameters "doubleNumericRestrictsMappings=$DOUBLE_NUMERIC_RESTRICTS_MAPPINGS" \ + --parameters "upsertMaxBatchSize=$UPSERT_MAX_BATCH_SIZE" \ + --parameters "upsertMaxBufferDuration=$UPSERT_MAX_BUFFER_DURATION" \ + --parameters "deleteMaxBatchSize=$DELETE_MAX_BATCH_SIZE" \ + --parameters "deleteMaxBufferDuration=$DELETE_MAX_BUFFER_DURATION" \ + --parameters "vectorSearchIndex=$VECTOR_SEARCH_INDEX" \ + --parameters "dlqDirectory=$DLQ_DIRECTORY" \ + --parameters "bigtableChangeStreamMetadataInstanceId=$BIGTABLE_CHANGE_STREAM_METADATA_INSTANCE_ID" \ + --parameters "bigtableChangeStreamMetadataTableTableId=$BIGTABLE_CHANGE_STREAM_METADATA_TABLE_TABLE_ID" \ + --parameters "bigtableChangeStreamAppProfile=$BIGTABLE_CHANGE_STREAM_APP_PROFILE" \ + --parameters "bigtableChangeStreamCharset=$BIGTABLE_CHANGE_STREAM_CHARSET" \ + --parameters "bigtableChangeStreamStartTimestamp=$BIGTABLE_CHANGE_STREAM_START_TIMESTAMP" \ + --parameters "bigtableChangeStreamIgnoreColumnFamilies=$BIGTABLE_CHANGE_STREAM_IGNORE_COLUMN_FAMILIES" \ + --parameters "bigtableChangeStreamIgnoreColumns=$BIGTABLE_CHANGE_STREAM_IGNORE_COLUMNS" \ + --parameters "bigtableChangeStreamName=$BIGTABLE_CHANGE_STREAM_NAME" \ + --parameters "bigtableChangeStreamResume=$BIGTABLE_CHANGE_STREAM_RESUME" \ + --parameters "bigtableReadInstanceId=$BIGTABLE_READ_INSTANCE_ID" \ + --parameters "bigtableReadTableId=$BIGTABLE_READ_TABLE_ID" \ + --parameters "bigtableReadProjectId=$BIGTABLE_READ_PROJECT_ID" +``` + +For more 
information about the command, please check: +https://cloud.google.com/sdk/gcloud/reference/dataflow/flex-template/run + + +**Using the plugin**: + +Instead of just generating the template in the folder, it is possible to stage +and run the template in a single command. This may be useful for testing when +changing the templates. + +```shell +export PROJECT= +export BUCKET_NAME= +export REGION=us-central1 + +### Required +export EMBEDDING_COLUMN= +export EMBEDDING_BYTE_SIZE=4 +export VECTOR_SEARCH_INDEX= +export BIGTABLE_CHANGE_STREAM_APP_PROFILE= +export BIGTABLE_READ_INSTANCE_ID= +export BIGTABLE_READ_TABLE_ID= + +### Optional +export BIGTABLE_METADATA_TABLE_TABLE_ID= +export CROWDING_TAG_COLUMN= +export ALLOW_RESTRICTS_MAPPINGS= +export DENY_RESTRICTS_MAPPINGS= +export INT_NUMERIC_RESTRICTS_MAPPINGS= +export FLOAT_NUMERIC_RESTRICTS_MAPPINGS= +export DOUBLE_NUMERIC_RESTRICTS_MAPPINGS= +export UPSERT_MAX_BATCH_SIZE= +export UPSERT_MAX_BUFFER_DURATION=10s +export DELETE_MAX_BATCH_SIZE= +export DELETE_MAX_BUFFER_DURATION=10s +export DLQ_DIRECTORY="" +export BIGTABLE_CHANGE_STREAM_METADATA_INSTANCE_ID="" +export BIGTABLE_CHANGE_STREAM_METADATA_TABLE_TABLE_ID="" +export BIGTABLE_CHANGE_STREAM_CHARSET=UTF-8 +export BIGTABLE_CHANGE_STREAM_START_TIMESTAMP="" +export BIGTABLE_CHANGE_STREAM_IGNORE_COLUMN_FAMILIES="" +export BIGTABLE_CHANGE_STREAM_IGNORE_COLUMNS="" +export BIGTABLE_CHANGE_STREAM_NAME= +export BIGTABLE_CHANGE_STREAM_RESUME=false +export BIGTABLE_READ_PROJECT_ID="" + +mvn clean package -PtemplatesRun \ +-DskipTests \ +-DprojectId="$PROJECT" \ +-DbucketName="$BUCKET_NAME" \ +-Dregion="$REGION" \ +-DjobName="bigtable-change-streams-to-vector-search-job" \ +-DtemplateName="Bigtable_Change_Streams_to_Vector_Search" \ +-Dparameters="bigtableMetadataTableTableId=$BIGTABLE_METADATA_TABLE_TABLE_ID,embeddingColumn=$EMBEDDING_COLUMN,crowdingTagColumn=$CROWDING_TAG_COLUMN,embeddingByteSize=$EMBEDDING_BYTE_SIZE,allowRestrictsMappings=$ALLOW_RESTRICTS_MAPPINGS,denyRestrictsMappings=$DENY_RESTRICTS_MAPPINGS,intNumericRestrictsMappings=$INT_NUMERIC_RESTRICTS_MAPPINGS,floatNumericRestrictsMappings=$FLOAT_NUMERIC_RESTRICTS_MAPPINGS,doubleNumericRestrictsMappings=$DOUBLE_NUMERIC_RESTRICTS_MAPPINGS,upsertMaxBatchSize=$UPSERT_MAX_BATCH_SIZE,upsertMaxBufferDuration=$UPSERT_MAX_BUFFER_DURATION,deleteMaxBatchSize=$DELETE_MAX_BATCH_SIZE,deleteMaxBufferDuration=$DELETE_MAX_BUFFER_DURATION,vectorSearchIndex=$VECTOR_SEARCH_INDEX,dlqDirectory=$DLQ_DIRECTORY,bigtableChangeStreamMetadataInstanceId=$BIGTABLE_CHANGE_STREAM_METADATA_INSTANCE_ID,bigtableChangeStreamMetadataTableTableId=$BIGTABLE_CHANGE_STREAM_METADATA_TABLE_TABLE_ID,bigtableChangeStreamAppProfile=$BIGTABLE_CHANGE_STREAM_APP_PROFILE,bigtableChangeStreamCharset=$BIGTABLE_CHANGE_STREAM_CHARSET,bigtableChangeStreamStartTimestamp=$BIGTABLE_CHANGE_STREAM_START_TIMESTAMP,bigtableChangeStreamIgnoreColumnFamilies=$BIGTABLE_CHANGE_STREAM_IGNORE_COLUMN_FAMILIES,bigtableChangeStreamIgnoreColumns=$BIGTABLE_CHANGE_STREAM_IGNORE_COLUMNS,bigtableChangeStreamName=$BIGTABLE_CHANGE_STREAM_NAME,bigtableChangeStreamResume=$BIGTABLE_CHANGE_STREAM_RESUME,bigtableReadInstanceId=$BIGTABLE_READ_INSTANCE_ID,bigtableReadTableId=$BIGTABLE_READ_TABLE_ID,bigtableReadProjectId=$BIGTABLE_READ_PROJECT_ID" \ +-f v2/googlecloud-to-googlecloud +``` + +## Terraform + +Dataflow supports the utilization of Terraform to manage template jobs, +see [dataflow_flex_template_job](https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/dataflow_flex_template_job). 
+ +Terraform modules have been generated for most templates in this repository. This includes the relevant parameters +specific to the template. If available, they may be used instead of +[dataflow_flex_template_job](https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/dataflow_flex_template_job) +directly. + +To use the autogenerated module, execute the standard +[terraform workflow](https://developer.hashicorp.com/terraform/intro/core-workflow): + +```shell +cd v2/googlecloud-to-googlecloud/terraform/Bigtable_Change_Streams_to_Vector_Search +terraform init +terraform apply +``` + +To use +[dataflow_flex_template_job](https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/dataflow_flex_template_job) +directly: + +```terraform +provider "google-beta" { + project = var.project +} +variable "project" { + default = "" +} +variable "region" { + default = "us-central1" +} + +resource "google_dataflow_flex_template_job" "bigtable_change_streams_to_vector_search" { + + provider = google-beta + container_spec_gcs_path = "gs://dataflow-templates-${var.region}/latest/flex/Bigtable_Change_Streams_to_Vector_Search" + name = "bigtable-change-streams-to-vector-search" + region = var.region + parameters = { + embeddingColumn = "" + embeddingByteSize = "4" + vectorSearchIndex = "projects/123/locations/us-east1/indexes/456" + bigtableChangeStreamAppProfile = "" + bigtableReadInstanceId = "" + bigtableReadTableId = "" + # bigtableMetadataTableTableId = "" + # crowdingTagColumn = "" + # allowRestrictsMappings = "" + # denyRestrictsMappings = "" + # intNumericRestrictsMappings = "" + # floatNumericRestrictsMappings = "" + # doubleNumericRestrictsMappings = "" + # upsertMaxBatchSize = "10" + # upsertMaxBufferDuration = "10s" + # deleteMaxBatchSize = "10" + # deleteMaxBufferDuration = "10s" + # dlqDirectory = "" + # bigtableChangeStreamMetadataInstanceId = "" + # bigtableChangeStreamMetadataTableTableId = "" + # bigtableChangeStreamCharset = "UTF-8" + # bigtableChangeStreamStartTimestamp = "" + # bigtableChangeStreamIgnoreColumnFamilies = "" + # bigtableChangeStreamIgnoreColumns = "" + # bigtableChangeStreamName = "" + # bigtableChangeStreamResume = "false" + # bigtableReadProjectId = "" + } +} +``` diff --git a/v2/googlecloud-to-googlecloud/pom.xml b/v2/googlecloud-to-googlecloud/pom.xml index 0b5f228b2c..31108deac6 100644 --- a/v2/googlecloud-to-googlecloud/pom.xml +++ b/v2/googlecloud-to-googlecloud/pom.xml @@ -257,5 +257,9 @@ org.apache.beam beam-sdks-java-extensions-python + + com.google.cloud + google-cloud-aiplatform + diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/options/BigtableChangeStreamsToVectorSearchOptions.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/options/BigtableChangeStreamsToVectorSearchOptions.java new file mode 100644 index 0000000000..8590d15373 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/options/BigtableChangeStreamsToVectorSearchOptions.java @@ -0,0 +1,216 @@ +/* + * Copyright (C) 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. 
You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.options; + +import com.google.cloud.teleport.metadata.TemplateParameter; +import com.google.cloud.teleport.v2.bigtable.options.BigtableCommonOptions.ReadChangeStreamOptions; +import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.sdk.options.Default; + +/** + * The {@link BigtableChangeStreamsToVectorSearchOptions} class provides the custom execution + * options passed by the executor at the command-line. + */ +public interface BigtableChangeStreamsToVectorSearchOptions + extends DataflowPipelineOptions, ReadChangeStreamOptions { + + @TemplateParameter.Text( + order = 1, + optional = true, + description = "Bigtable Metadata Table Id", + helpText = "Table ID used for creating the metadata table.") + String getBigtableMetadataTableTableId(); + + void setBigtableMetadataTableTableId(String bigtableMetadataTableTableId); + + @TemplateParameter.Text( + order = 2, + description = "Embedding column", + helpText = + "The fully qualified column name where the embeddings are stored. In the format cf:col.") + String getEmbeddingColumn(); + + @SuppressWarnings("unused") + void setEmbeddingColumn(String value); + + @TemplateParameter.Text( + order = 3, + optional = true, + description = "Crowding tag column", + helpText = + "The fully qualified column name where the crowding tag is stored. In the format cf:col.") + String getCrowdingTagColumn(); + + @SuppressWarnings("unused") + void setCrowdingTagColumn(String value); + + @TemplateParameter.Integer( + order = 4, + description = "The byte size of the embeddings array. Can be 4 or 8.", + helpText = + "The byte size of each entry in the embeddings array. Use 4 for Float, and 8 for Double.") + @Default.Integer(4) + Integer getEmbeddingByteSize(); + + @SuppressWarnings("unused") + void setEmbeddingByteSize(Integer value); + + @TemplateParameter.Text( + order = 5, + optional = true, + description = "Allow restricts mappings", + helpText = + "The comma separated fully qualified column names of the columns that should be used as the `allow` restricts, with their alias. In the format cf:col->alias.") + String getAllowRestrictsMappings(); + + @SuppressWarnings("unused") + void setAllowRestrictsMappings(String value); + + @TemplateParameter.Text( + order = 6, + optional = true, + description = "Deny restricts mappings", + helpText = + "The comma separated fully qualified column names of the columns that should be used as the `deny` restricts, with their alias. In the format cf:col->alias.") + String getDenyRestrictsMappings(); + + @SuppressWarnings("unused") + void setDenyRestrictsMappings(String value); + + @TemplateParameter.Text( + order = 7, + optional = true, + description = "Integer numeric restricts mappings", + helpText = + "The comma separated fully qualified column names of the columns that should be used as integer `numeric_restricts`, with their alias. 
In the format cf:col->alias.") + String getIntNumericRestrictsMappings(); + + @SuppressWarnings("unused") + void setIntNumericRestrictsMappings(String value); + + @TemplateParameter.Text( + order = 8, + optional = true, + description = "Float numeric restricts mappings", + helpText = + "The comma separated fully qualified column names of the columns that should be used as float (4 bytes) `numeric_restricts`, with their alias. In the format cf:col->alias.") + String getFloatNumericRestrictsMappings(); + + @SuppressWarnings("unused") + void setFloatNumericRestrictsMappings(String value); + + @TemplateParameter.Text( + order = 9, + optional = true, + description = "Double numeric restricts mappings", + helpText = + "The comma separated fully qualified column names of the columns that should be used as double (8 bytes) `numeric_restricts`, with their alias. In the format cf:col->alias.") + String getDoubleNumericRestrictsMappings(); + + @SuppressWarnings("unused") + void setDoubleNumericRestrictsMappings(String value); + + @TemplateParameter.Integer( + order = 10, + optional = true, + description = "Maximum batch size for upserts for Vector Search", + helpText = + "The maximum number of upserts to buffer before upserting the batch to the Vector Search Index. " + + "Batches will be sent when there are either upsertBatchSize records ready, or any record has been " + + "waiting upsertBatchDelay time has passed.", + example = "10") + @Default.Integer(10) + int getUpsertMaxBatchSize(); + + @SuppressWarnings("unused") + void setUpsertMaxBatchSize(int batchSize); + + @TemplateParameter.Duration( + order = 11, + optional = true, + description = + "Maximum duration an upsert can wait in a buffer before its batch is submitted, regardless of batch size", + helpText = + "The maximum delay before a batch of upserts is sent to Vector Search." + + "Batches will be sent when there are either upsertBatchSize records ready, or any record has been " + + "waiting upsertBatchDelay time has passed. " + + "Allowed formats are: Ns (for seconds, example: 5s), Nm (for minutes, example: 12m), Nh (for hours, example: 2h).", + example = "10s") + @Default.String("10s") + String getUpsertMaxBufferDuration(); + + @SuppressWarnings("unused") + void setUpsertMaxBufferDuration(String maxBufferDuration); + + @TemplateParameter.Integer( + order = 12, + optional = true, + description = "Maximum batch size for deletes for Vector Search", + helpText = + "The maximum number of deletes to buffer before deleting the batch from the Vector Search Index. " + + "Batches will be sent when there are either deleteBatchSize records ready, or any record has been " + + "waiting deleteBatchDelay time has passed.", + example = "10") + @Default.Integer(10) + int getDeleteMaxBatchSize(); + + @SuppressWarnings("unused") + void setDeleteMaxBatchSize(int batchSize); + + @TemplateParameter.Duration( + order = 13, + optional = true, + description = + "Maximum duration a delete can wait in a buffer before its batch is submitted, regardless of batch size", + helpText = + "The maximum delay before a batch of deletes is sent to Vector Search." + + "Batches will be sent when there are either deleteBatchSize records ready, or any record has been " + + "waiting deleteBatchDelay time has passed. 
" + + "Allowed formats are: Ns (for seconds, example: 5s), Nm (for minutes, example: 12m), Nh (for hours, example: 2h).", + example = "10s") + @Default.String("10s") + String getDeleteMaxBufferDuration(); + + @SuppressWarnings("unused") + void setDeleteMaxBufferDuration(String maxBufferDuration); + + @TemplateParameter.Text( + order = 14, + optional = false, + description = "Vector Search Index Path", + helpText = + "The Vector Search Index where changes will be streamed, in the format 'projects/{projectID}/locations/{region}/indexes/{indexID}' (no leading or trailing spaces)", + example = "projects/123/locations/us-east1/indexes/456") + String getVectorSearchIndex(); + + @SuppressWarnings("unused") + void setVectorSearchIndex(String value); + + @TemplateParameter.GcsWriteFolder( + order = 15, + optional = true, + description = "Dead letter queue directory to store any unpublished change record.", + helpText = + "The path to store any unprocessed records with" + + " the reason they failed to be processed. " + + "Default is a directory under the Dataflow job's temp location. " + + "The default value is enough under most conditions.") + @Default.String("") + String getDlqDirectory(); + + @SuppressWarnings("unused") + void setDlqDirectory(String value); +} diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearch.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearch.java new file mode 100644 index 0000000000..5f726d3de1 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearch.java @@ -0,0 +1,271 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import com.google.cloud.Timestamp; +import com.google.cloud.aiplatform.v1.IndexDatapoint; +import com.google.cloud.teleport.metadata.Template; +import com.google.cloud.teleport.metadata.TemplateCategory; +import com.google.cloud.teleport.v2.bigtable.options.BigtableCommonOptions.ReadChangeStreamOptions; +import com.google.cloud.teleport.v2.bigtable.options.BigtableCommonOptions.ReadOptions; +import com.google.cloud.teleport.v2.cdc.dlq.DeadLetterQueueManager; +import com.google.cloud.teleport.v2.options.BigtableChangeStreamsToVectorSearchOptions; +import com.google.cloud.teleport.v2.transforms.DLQWriteTransform; +import com.google.cloud.teleport.v2.utils.DurationUtils; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.io.gcp.bigtable.BigtableIO; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.GroupIntoBatches; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Values; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTagList; +import org.apache.commons.lang3.StringUtils; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@Template( + name = "Bigtable_Change_Streams_to_Vector_Search", + category = TemplateCategory.STREAMING, + displayName = "Bigtable Change Streams to Vector Search", + description = + "Streaming pipeline. Streams Bigtable data change records and writes them into Vertex AI Vector Search using Dataflow Runner V2.", + optionsClass = BigtableChangeStreamsToVectorSearchOptions.class, + optionsOrder = { + BigtableChangeStreamsToVectorSearchOptions.class, + ReadChangeStreamOptions.class, + ReadOptions.class + }, + skipOptions = { + "bigtableReadAppProfile", + "bigtableAdditionalRetryCodes", + "bigtableRpcAttemptTimeoutMs", + "bigtableRpcTimeoutMs" + }, + documentation = + "https://cloud.google.com/dataflow/docs/guides/templates/provided/bigtable-change-streams-to-vector-search", + flexContainerName = "bigtable-changestreams-to-vector-search", + contactInformation = "https://cloud.google.com/support", + streaming = true) +public final class BigtableChangeStreamsToVectorSearch { + private static final Logger LOG = + LoggerFactory.getLogger(BigtableChangeStreamsToVectorSearch.class); + + private static final String USE_RUNNER_V2_EXPERIMENT = "use_runner_v2"; + + /** + * Main entry point for executing the pipeline. + * + * @param args The command-line arguments to the pipeline. 
+ */ + public static void main(String[] args) throws Exception { + LOG.info("Starting replication from Cloud Bigtable Change Streams to Vector Search"); + + BigtableChangeStreamsToVectorSearchOptions options = + PipelineOptionsFactory.fromArgs(args) + .withValidation() + .as(BigtableChangeStreamsToVectorSearchOptions.class); + + run(options); + } + + public static PipelineResult run(BigtableChangeStreamsToVectorSearchOptions options) + throws IOException { + options.setStreaming(true); + options.setEnableStreamingEngine(true); + + List experiments = options.getExperiments(); + if (experiments == null) { + experiments = new ArrayList<>(); + } + boolean hasUseRunnerV2 = false; + for (String experiment : experiments) { + if (experiment.equalsIgnoreCase(USE_RUNNER_V2_EXPERIMENT)) { + hasUseRunnerV2 = true; + break; + } + } + if (!hasUseRunnerV2) { + experiments.add(USE_RUNNER_V2_EXPERIMENT); + } + options.setExperiments(experiments); + + Instant startTimestamp = + options.getBigtableChangeStreamStartTimestamp().isEmpty() + ? Instant.now() + : toInstant(Timestamp.parseTimestamp(options.getBigtableChangeStreamStartTimestamp())); + + String bigtableProjectId = getBigtableProjectId(options); + + LOG.info(" - startTimestamp {}", startTimestamp); + LOG.info(" - bigtableReadInstanceId {}", options.getBigtableReadInstanceId()); + LOG.info(" - bigtableReadTableId {}", options.getBigtableReadTableId()); + LOG.info(" - bigtableChangeStreamAppProfile {}", options.getBigtableChangeStreamAppProfile()); + LOG.info(" - embeddingColumn {}", options.getEmbeddingColumn()); + LOG.info(" - crowdingTagColumn {}", options.getCrowdingTagColumn()); + LOG.info(" - project {}", options.getProject()); + LOG.info(" - indexName {}", options.getVectorSearchIndex()); + + String indexName = options.getVectorSearchIndex(); + + String vertexRegion = Utils.extractRegionFromIndexName(indexName); + String vertexEndpoint = vertexRegion + "-aiplatform.googleapis.com:443"; + + final Pipeline pipeline = Pipeline.create(options); + + DeadLetterQueueManager dlqManager = buildDlqManager(options); + + BigtableIO.ReadChangeStream readChangeStream = + BigtableIO.readChangeStream() + .withChangeStreamName(options.getBigtableChangeStreamName()) + .withExistingPipelineOptions( + options.getBigtableChangeStreamResume() + ? 
BigtableIO.ExistingPipelineOptions.RESUME_OR_FAIL
+                    : BigtableIO.ExistingPipelineOptions.FAIL_IF_EXISTS)
+            .withProjectId(bigtableProjectId)
+            .withAppProfileId(options.getBigtableChangeStreamAppProfile())
+            .withInstanceId(options.getBigtableReadInstanceId())
+            .withTableId(options.getBigtableReadTableId())
+            .withMetadataTableInstanceId(options.getBigtableChangeStreamMetadataInstanceId())
+            .withMetadataTableTableId(options.getBigtableMetadataTableTableId())
+            .withStartTime(startTimestamp);
+
+    PCollectionTuple results =
+        pipeline
+            .apply("Read from Cloud Bigtable Change Streams", readChangeStream)
+            .apply("Create Values", Values.create())
+            .apply(
+                "Converting to Vector Search Datapoints",
+                ParDo.of(
+                        new ChangeStreamMutationToDatapointOperationFn(
+                            options.getEmbeddingColumn(),
+                            options.getEmbeddingByteSize(),
+                            options.getCrowdingTagColumn(),
+                            Utils.parseColumnMapping(options.getAllowRestrictsMappings()),
+                            Utils.parseColumnMapping(options.getDenyRestrictsMappings()),
+                            Utils.parseColumnMapping(options.getIntNumericRestrictsMappings()),
+                            Utils.parseColumnMapping(options.getFloatNumericRestrictsMappings()),
+                            Utils.parseColumnMapping(options.getDoubleNumericRestrictsMappings())))
+                    .withOutputTags(
+                        ChangeStreamMutationToDatapointOperationFn.UPSERT_DATAPOINT_TAG,
+                        TupleTagList.of(
+                            ChangeStreamMutationToDatapointOperationFn.REMOVE_DATAPOINT_TAG)));
+    results
+        .get(ChangeStreamMutationToDatapointOperationFn.UPSERT_DATAPOINT_TAG)
+        .apply("Add placeholder keys", WithKeys.of("placeholder"))
+        .apply(
+            "Batch Contents",
+            GroupIntoBatches.<String, IndexDatapoint>ofSize(
+                    bufferSizeOption(options.getUpsertMaxBatchSize()))
+                .withMaxBufferingDuration(
+                    bufferDurationOption(options.getUpsertMaxBufferDuration())))
+        .apply("Map to Values", Values.create())
+        .apply(
+            "Upsert Datapoints to VectorSearch",
+            ParDo.of(new UpsertDatapointsFn(vertexEndpoint, indexName)))
+        .apply(
+            "Write errors to DLQ",
+            DLQWriteTransform.WriteDLQ.newBuilder()
+                .withDlqDirectory(dlqManager.getSevereDlqDirectory() + "YYYY/MM/dd/HH/mm/")
+                .withTmpDirectory(dlqManager.getSevereDlqDirectory() + "tmp/")
+                .setIncludePaneInfo(true)
+                .build());
+
+    results
+        .get(ChangeStreamMutationToDatapointOperationFn.REMOVE_DATAPOINT_TAG)
+        .apply("Add placeholder keys", WithKeys.of("placeholder"))
+        .apply(
+            "Batch Contents",
+            GroupIntoBatches.<String, String>ofSize(
+                    bufferSizeOption(options.getDeleteMaxBatchSize()))
+                .withMaxBufferingDuration(
+                    bufferDurationOption(options.getDeleteMaxBufferDuration())))
+        .apply("Map to Values", Values.create())
+        .apply(
+            "Remove Datapoints From VectorSearch",
+            ParDo.of(new RemoveDatapointsFn(vertexEndpoint, indexName)))
+        .apply(
+            "Write errors to DLQ",
+            DLQWriteTransform.WriteDLQ.newBuilder()
+                .withDlqDirectory(dlqManager.getSevereDlqDirectory() + "YYYY/MM/dd/HH/mm/")
+                .withTmpDirectory(dlqManager.getSevereDlqDirectory() + "tmp/")
+                .setIncludePaneInfo(true)
+                .build());
+
+    return pipeline.run();
+  }
+
+  private static String getBigtableProjectId(BigtableChangeStreamsToVectorSearchOptions options) {
+    return StringUtils.isEmpty(options.getBigtableReadProjectId())
+        ?
options.getProject() + : options.getBigtableReadProjectId(); + } + + private static Instant toInstant(Timestamp timestamp) { + if (timestamp == null) { + return null; + } else { + return Instant.ofEpochMilli(timestamp.getSeconds() * 1000 + timestamp.getNanos() / 1000000); + } + } + + private static int bufferSizeOption(int size) { + if (size < 1) { + size = 1; + } + + return size; + } + + private static Duration bufferDurationOption(String duration) { + if (duration.isEmpty()) { + return Duration.standardSeconds(1); + } + + return DurationUtils.parseDuration(duration); + } + + private static DeadLetterQueueManager buildDlqManager( + BigtableChangeStreamsToVectorSearchOptions options) { + String dlqDirectory = options.getDlqDirectory(); + if (dlqDirectory.isEmpty()) { + LOG.info("Falling back to temp dir for DLQ"); + + String tempLocation = options.as(DataflowPipelineOptions.class).getTempLocation(); + + LOG.info("Have temp location {}", tempLocation); + if (tempLocation == null || tempLocation.isEmpty()) { + tempLocation = "/"; + } else if (!tempLocation.endsWith("/")) { + tempLocation += "/"; + } + + dlqDirectory = tempLocation + "dlq"; + } + + LOG.info("Writing dead letter queue to: {}", dlqDirectory); + + return DeadLetterQueueManager.create(dlqDirectory, 1); + } +} diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/ChangeStreamMutationToDatapointOperationFn.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/ChangeStreamMutationToDatapointOperationFn.java new file mode 100644 index 0000000000..83734366b6 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/ChangeStreamMutationToDatapointOperationFn.java @@ -0,0 +1,228 @@ +/* + * Copyright (C) 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import com.google.cloud.aiplatform.v1.IndexDatapoint; +import com.google.cloud.bigtable.data.v2.models.ChangeStreamMutation; +import com.google.cloud.bigtable.data.v2.models.DeleteCells; +import com.google.cloud.bigtable.data.v2.models.DeleteFamily; +import com.google.cloud.bigtable.data.v2.models.Entry; +import com.google.cloud.bigtable.data.v2.models.SetCell; +import java.util.Map; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.hadoop.hbase.util.Bytes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The {@link ChangeStreamMutationToDatapointOperationFn} class is a {@link DoFn} that takes in a + * Bigtable ChangeStreamMutation and converts it to either an IndexDatapoint (to be added to the + * index) or a String representing a Datapoint ID to be removed from the index. 
+ */ +public class ChangeStreamMutationToDatapointOperationFn + extends DoFn { + + public static final TupleTag UPSERT_DATAPOINT_TAG = + new TupleTag() {}; + public static final TupleTag REMOVE_DATAPOINT_TAG = new TupleTag() {}; + + private static final Logger LOG = + LoggerFactory.getLogger(ChangeStreamMutationToDatapointOperationFn.class); + + private String embeddingsColumn; // "family_name:qualifier" + private String embeddingsColumnFamilyName; // "family_name" extracted from embeddingsColumn + private int embeddingsByteSize; // 4 or 8 + private String crowdingTagColumn; + private Map allowRestrictsMappings; + private Map denyRestrictsMappings; + private Map intNumericRestrictsMappings; + private Map floatNumericRestrictsMappings; + private Map doubleNumericRestrictsMappings; + + public ChangeStreamMutationToDatapointOperationFn( + String embeddingsColumn, + int embeddingsByteSize, + String crowdingTagColumn, + Map allowRestrictsMappings, + Map denyRestrictsMappings, + Map intNumericRestrictsMappings, + Map floatNumericRestrictsMappings, + Map doubleNumericRestrictsMappings) { + + { + String[] parts = embeddingsColumn.split(":", 2); + if (parts.length != 2) { + throw new IllegalArgumentException( + "Invalid embeddingsColumn - should be in the form \"family:qualifier\""); + } + + this.embeddingsColumn = embeddingsColumn; + this.embeddingsColumnFamilyName = parts[0]; + } + + this.embeddingsByteSize = embeddingsByteSize; + + if (this.embeddingsByteSize != 4 && this.embeddingsByteSize != 8) { + throw new IllegalArgumentException("Embeddings byte size must be 4 or 8"); + } + + this.crowdingTagColumn = crowdingTagColumn; + this.allowRestrictsMappings = allowRestrictsMappings; + this.denyRestrictsMappings = denyRestrictsMappings; + this.intNumericRestrictsMappings = intNumericRestrictsMappings; + this.floatNumericRestrictsMappings = floatNumericRestrictsMappings; + this.doubleNumericRestrictsMappings = doubleNumericRestrictsMappings; + } + + @ProcessElement + public void processElement(@Element ChangeStreamMutation mutation, MultiOutputReceiver output) { + + // Mutations should contain one or more setCells, *or* a DeleteCells *or* a DeleteFamily, or + // other mods that we're not interested in. 
Depending on what we find, dispatch to the correct + // handler + for (Entry entry : mutation.getEntries()) { + if (entry instanceof SetCell) { + processInsert(mutation, output); + return; + } else if (entry instanceof DeleteCells || entry instanceof DeleteFamily) { + processDelete(mutation, output); + return; + } + } + } + + private void processInsert(ChangeStreamMutation mutation, MultiOutputReceiver output) { + IndexDatapoint.Builder datapointBuilder = IndexDatapoint.newBuilder(); + var datapointId = mutation.getRowKey().toStringUtf8(); + if (datapointId.isEmpty()) { + LOG.info("Have a mutation with no rowkey"); + return; + } + + datapointBuilder.setDatapointId(datapointId); + + for (Entry entry : mutation.getEntries()) { + LOG.info("Processing {}", entry); + + // We're only interested in SetCell mutations; everything else should be ignored + if (!(entry instanceof SetCell)) { + continue; + } + + SetCell m = (SetCell) entry; + LOG.info("Have value {}", m.getValue()); + + var family = m.getFamilyName(); + var qualifier = m.getQualifier().toStringUtf8(); + var col = family + ":" + qualifier; + + String mappedColumn; + + if (col.equals(embeddingsColumn)) { + var floats = Utils.bytesToFloats(m.getValue(), embeddingsByteSize == 8); + + datapointBuilder.addAllFeatureVector(floats); + } else if (col.equals(crowdingTagColumn)) { + LOG.info("Setting crowding tag {}", m.getValue().toStringUtf8()); + datapointBuilder + .getCrowdingTagBuilder() + .setCrowdingAttribute(m.getValue().toStringUtf8()) + .build(); + } else if ((mappedColumn = allowRestrictsMappings.get(col)) != null) { + datapointBuilder + .addRestrictsBuilder() + .setNamespace(mappedColumn) + .addAllowListBytes(m.getValue()) + .build(); + } else if ((mappedColumn = denyRestrictsMappings.get(col)) != null) { + datapointBuilder + .addRestrictsBuilder() + .setNamespace(mappedColumn) + .addDenyListBytes(m.getValue()) + .build(); + } else if ((mappedColumn = intNumericRestrictsMappings.get(col)) != null) { + int i = Bytes.toInt(m.getValue().toByteArray()); + datapointBuilder + .addNumericRestrictsBuilder() + .setNamespace(mappedColumn) + .setValueInt(i) + .build(); + } else if ((mappedColumn = floatNumericRestrictsMappings.get(col)) != null) { + float f = Bytes.toFloat(m.getValue().toByteArray()); + datapointBuilder + .addNumericRestrictsBuilder() + .setNamespace(mappedColumn) + .setValueFloat(f) + .build(); + } else if ((mappedColumn = doubleNumericRestrictsMappings.get(col)) != null) { + double d = Bytes.toDouble(m.getValue().toByteArray()); + datapointBuilder + .addNumericRestrictsBuilder() + .setNamespace(mappedColumn) + .setValueDouble(d) + .build(); + } + } + + LOG.info("Emitting an upsert datapoint"); + output.get(UPSERT_DATAPOINT_TAG).output(datapointBuilder.build()); + } + + private void processDelete(ChangeStreamMutation mutation, MultiOutputReceiver output) { + LOG.info("Handling mutation as a deletion"); + + Boolean isDelete = + mutation.getEntries().stream() + .anyMatch( + (entry) -> { + // Each deletion may come in as one or more DeleteCells mutations, or one more or + // DeleteFamily mutations + // As soon as we find a DeleteCells that covers the fully qualified embeddings + // column, _or_ a DeleteFamily that + // covers the embeddings column's family, we treat the mutation as a deletion of + // the Datapoint. 
+ if (entry instanceof DeleteCells) { + LOG.info("Have a DeleteCells"); + DeleteCells m = (DeleteCells) entry; + LOG.info("Have embeddings col {}", this.embeddingsColumn); + LOG.info("Have computed {}", m.getFamilyName() + ":" + m.getQualifier()); + + Boolean match = + (m.getFamilyName() + ":" + m.getQualifier()).matches(this.embeddingsColumn); + LOG.info("Match: {}", match); + return match; + } else if (entry instanceof DeleteFamily) { + LOG.info("Have a DeleteFamily"); + DeleteFamily m = (DeleteFamily) entry; + LOG.info("Have family name {}", m.getFamilyName()); + LOG.info("have stored family name {}", this.embeddingsColumnFamilyName); + Boolean match = m.getFamilyName().matches(this.embeddingsColumnFamilyName); + LOG.info("Have match {}", match); + return match; + } + + return false; + }); + + LOG.info("Have isDeleted {}", isDelete); + if (isDelete) { + String rowkey = mutation.getRowKey().toStringUtf8(); + LOG.info("Emitting a remove datapoint: {}", rowkey); + output.get(REMOVE_DATAPOINT_TAG).output(rowkey); + } + } +} diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/DatapointOperationFn.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/DatapointOperationFn.java new file mode 100644 index 0000000000..277c2c2295 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/DatapointOperationFn.java @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import com.google.cloud.aiplatform.v1.IndexServiceClient; +import com.google.cloud.aiplatform.v1.IndexServiceSettings; +import java.io.IOException; +import org.apache.beam.sdk.transforms.DoFn; +import org.slf4j.Logger; + +public abstract class DatapointOperationFn extends DoFn { + private String endpoint; + protected String indexName; + + protected transient IndexServiceClient client; + + protected abstract Logger logger(); + + public DatapointOperationFn(String endpoint, String indexName) { + this.indexName = indexName; + this.endpoint = endpoint; + } + + @Setup + public void setup() { + logger().info("Connecting to vector search endpoint {}", endpoint); + logger().info("Using index {}", indexName); + + try { + client = + IndexServiceClient.create( + IndexServiceSettings.newBuilder().setEndpoint(endpoint).build()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/RemoveDatapointsFn.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/RemoveDatapointsFn.java new file mode 100644 index 0000000000..3eca3f5abd --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/RemoveDatapointsFn.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import com.google.cloud.aiplatform.v1.RemoveDatapointsRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RemoveDatapointsFn extends DatapointOperationFn> { + + private static final Logger LOG = LoggerFactory.getLogger(RemoveDatapointsFn.class); + + protected Logger logger() { + return LOG; + } + + public RemoveDatapointsFn(String endpoint, String indexName) { + super(endpoint, indexName); + } + + @ProcessElement + public void processElement(ProcessContext c) { + var datapointIds = c.element(); + LOG.info("Deleting datapoints: {}", datapointIds); + + // Appears to work, even when some datapoints don't exist + RemoveDatapointsRequest request = + RemoveDatapointsRequest.newBuilder() + .addAllDatapointIds(datapointIds) + .setIndex(indexName) + .build(); + + try { + client.removeDatapoints(request); + } catch (io.grpc.StatusRuntimeException e) { + LOG.info("Failed to remove datapoints: {}", e.getLocalizedMessage()); + c.output("Error deleting datapoint: " + e.getLocalizedMessage()); + } + + LOG.info("Done"); + } +} diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UpsertDatapointsFn.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UpsertDatapointsFn.java new file mode 100644 index 0000000000..bc471381c8 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UpsertDatapointsFn.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import com.google.cloud.aiplatform.v1.IndexDatapoint; +import com.google.cloud.aiplatform.v1.UpsertDatapointsRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class UpsertDatapointsFn extends DatapointOperationFn> { + + private static final Logger LOG = LoggerFactory.getLogger(UpsertDatapointsFn.class); + + protected Logger logger() { + return LOG; + } + + public UpsertDatapointsFn(String endpoint, String indexName) { + super(endpoint, indexName); + } + + @ProcessElement + public void processElement(ProcessContext c) { + var datapoints = c.element(); + LOG.info("Upserting datapoints: {}", datapoints); + LOG.info("Using index name {}", indexName); + UpsertDatapointsRequest request = + UpsertDatapointsRequest.newBuilder() + .addAllDatapoints(datapoints) + .setIndex(indexName) + .build(); + + try { + client.upsertDatapoints(request); + } catch (Exception e) { + LOG.info("Failed to upsert datapoints: {}", e.getLocalizedMessage()); + c.output("Error writing to vector search:" + e.getLocalizedMessage()); + } + + LOG.info("Done"); + } +} diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/Utils.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/Utils.java new file mode 100644 index 0000000000..d0be9ebedc --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/Utils.java @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
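The datapoints handed to UpsertDatapointsFn are assembled upstream (in ChangeStreamMutationToDatapointOperationFn, added earlier in this patch) from the configured embedding, crowding-tag and restrict columns. As a hedged sketch with placeholder IDs and values, one such IndexDatapoint has roughly this shape:

```java
// Placeholder values only — real datapoints are built from Bigtable cell contents.
IndexDatapoint datapoint =
    IndexDatapoint.newBuilder()
        .setDatapointId("row-key-1")
        .addAllFeatureVector(List.of(0.01f, 0.02f, 0.03f))
        .setCrowdingTag(
            IndexDatapoint.CrowdingTag.newBuilder().setCrowdingAttribute("crowd-1").build())
        .addRestricts(
            IndexDatapoint.Restriction.newBuilder()
                .setNamespace("allowtag")
                .addAllowList("some-allowed-value")
                .build())
        .build();
```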
+ */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import com.google.protobuf.ByteString; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.hbase.util.Bytes; + +public class Utils { + + public static String extractRegionFromIndexName(String indexName) { + Pattern p = Pattern.compile("^?projects\\/\\d+/locations\\/([^/]+)\\/indexes/\\d+?$"); + Matcher matcher = p.matcher(indexName); + if (matcher.find()) { + String region = matcher.group(1); + return region; + } + + throw new IllegalArgumentException("Invalid IndexName"); + } + + // Split "cf1:foo1->bar1,cf1:foo2->bar2" into a map of { "cf1:foo1": "bar1", "cf1:foo2": "bar2" } + public static Map parseColumnMapping(String mapstr) { + Map columnsWithAliases = new HashMap<>(); + if (StringUtils.isBlank(mapstr)) { + return columnsWithAliases; + } + String[] columnsList = mapstr.split(","); + + for (String columnsWithAlias : columnsList) { + String[] columnWithAlias = columnsWithAlias.split("->"); + if (columnWithAlias.length == 2 + && columnWithAlias[0].length() >= 1 + && columnWithAlias[1].length() >= 1) { + columnsWithAliases.put(columnWithAlias[0], columnWithAlias[1]); + } else { + throw new IllegalArgumentException( + String.format("Malformed column mapping pair %s", columnsList)); + } + } + return columnsWithAliases; + } + + // Convert a ByteString into an array of 4 byte single-precision floats + // If parseDoubles is true, interpret the ByteString as containing 8 byte double-precision floats, + // which are narrowed to 4 byte floats + // If parseDoubles is false, interpret the ByteString as containing 4 byte floats + public static ArrayList bytesToFloats(ByteString value, Boolean parseDoubles) { + byte[] bytes = value.toByteArray(); + int bytesPerFloat = (parseDoubles ? 8 : 4); + + if (bytes.length % bytesPerFloat != 0) { + throw new RuntimeException( + String.format( + "Invalid ByteStream length %d (should be a multiple of %d)", + bytes.length, bytesPerFloat)); + } + + var embeddings = new ArrayList(); + for (int i = 0; i < bytes.length; i += bytesPerFloat) { + if (parseDoubles) { + embeddings.add((float) Bytes.toDouble(bytes, i)); + } else { + embeddings.add(Bytes.toFloat(bytes, i)); + } + } + + return embeddings; + } +} diff --git a/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/package-info.java b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/package-info.java new file mode 100644 index 0000000000..54b367aa01 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright (C) 2023 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
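The unit tests later in this patch exercise these helpers; as a quick reference, a usage sketch with illustrative inputs (and assuming java.nio.ByteBuffer and protobuf ByteString are imported) looks like this:

```java
// Illustrative inputs, mirroring the UtilsTest cases below.
Map<String, String> mapping = Utils.parseColumnMapping("cf:allow->allowtag,cf:deny->denytag");
// => {cf:allow=allowtag, cf:deny=denytag}

String region = Utils.extractRegionFromIndexName("projects/123/locations/us-east1/indexes/456");
// => "us-east1"

// bytesToFloats expects big-endian IEEE 754 values; ByteBuffer writes big-endian by default.
byte[] encoded = ByteBuffer.allocate(8).putFloat(3.14f).putFloat(2.7182f).array();
List<Float> decoded = Utils.bytesToFloats(ByteString.copyFrom(encoded), false);
// => [3.14, 2.7182]
```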
+ */ + +/** Classes for streaming Bigtable change stream mutations into Vertex AI Vector Search. */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; diff --git a/v2/googlecloud-to-googlecloud/src/main/resources/bigtable-changestreams-to-vector-embeddings-command-spec.json b/v2/googlecloud-to-googlecloud/src/main/resources/bigtable-changestreams-to-vector-embeddings-command-spec.json new file mode 100644 index 0000000000..4ea0e1178b --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/main/resources/bigtable-changestreams-to-vector-embeddings-command-spec.json @@ -0,0 +1,7 @@ +{ + "mainClass": "com.google.cloud.teleport.v2.templates.BigtableChangeStreamsToVectorSearch", + "classPath": "/template/bigtable-changestreams-to-vector-search/*:/template/bigtable-changestreams-to-vector-search/libs/*:/template/bigtable-changestreams-to-vector-search/classes", + "defaultParameterValues": { + "labels": "{\"goog-dataflow-provided-template-type\":\"flex\", \"goog-dataflow-provided-template-name\":\"bigtable-changestreams-to-vector-search\"}" + } +} diff --git a/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearchIT.java b/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearchIT.java new file mode 100644 index 0000000000..676a56f7ae --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/BigtableChangeStreamsToVectorSearchIT.java @@ -0,0 +1,534 @@ +/* + * Copyright (C) 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatPipeline; +import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatResult; +import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric; +import static org.awaitility.Awaitility.await; +import static org.junit.Assert.assertEquals; + +import com.google.api.gax.paging.Page; +import com.google.cloud.bigtable.admin.v2.models.StorageType; +import com.google.cloud.bigtable.data.v2.models.RowMutation; +import com.google.cloud.storage.Blob; +import com.google.cloud.storage.Storage; +import com.google.cloud.storage.Storage.BlobListOption; +import com.google.cloud.storage.StorageOptions; +import com.google.cloud.teleport.metadata.TemplateIntegrationTest; +import com.google.common.collect.Lists; +import com.google.common.primitives.Floats; +import com.google.protobuf.ByteString; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.it.common.PipelineLauncher; +import org.apache.beam.it.common.PipelineOperator; +import org.apache.beam.it.common.utils.ResourceManagerUtils; +import org.apache.beam.it.gcp.TemplateTestBase; +import org.apache.beam.it.gcp.bigtable.BigtableResourceManager; +import org.apache.beam.it.gcp.bigtable.BigtableResourceManagerCluster; +import org.apache.beam.it.gcp.bigtable.BigtableTableSpec; +import org.apache.commons.lang3.RandomUtils; +import org.jetbrains.annotations.NotNull; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Integration test for {@link BigtableChangeStreamsToVectorSearch}. 
*/ +@Category({TemplateIntegrationTest.class}) +@TemplateIntegrationTest(BigtableChangeStreamsToVectorSearch.class) +@RunWith(JUnit4.class) +@Ignore("Tests are flaky as indexes take forever to be created causing timeouts.") +public final class BigtableChangeStreamsToVectorSearchIT extends TemplateTestBase { + + private static final Logger LOG = + LoggerFactory.getLogger(BigtableChangeStreamsToVectorSearchIT.class); + + private static final String TEST_ZONE = "us-central1-b"; + private static final String PROJECT_NUMBER = "269744978479"; // cloud-teleport-testing + private BigtableResourceManager bigtableResourceManager; + private static VectorSearchResourceManager vectorSearchResourceManager; + + private static final String clusterName = "teleport-c1"; + private String appProfileId; + private String srcTable; + + // Fixture embeddings for tests, which match the dimensionality of the test index + private static final List GOOD_EMBEDDINGS = + Floats.asList(0.01f, 0.02f, 0.03f, 0.04f, 0.05f, 0.06f, 0.07f, 0.08f, 0.09f, 0.10f); + + // The GOOD_EMBEDDINGS list of floats converted to a byte array of single-precision (4 byte) + // big-endian floats, as they would be sent to CBT + private static final byte[] GOOD_EMBEDDING_BYTES = { + (byte) 0x3c, (byte) 0x23, (byte) 0xd7, (byte) 0x0a, // 0.01 + (byte) 0x3c, (byte) 0xa3, (byte) 0xd7, (byte) 0x0a, // 0.02 + (byte) 0x3c, (byte) 0xf5, (byte) 0xc2, (byte) 0x8f, // 0.03 + (byte) 0x3d, (byte) 0x23, (byte) 0xd7, (byte) 0x0a, // 0.04 + (byte) 0x3d, (byte) 0x4c, (byte) 0xcc, (byte) 0xcd, // 0.05 + (byte) 0x3d, (byte) 0x75, (byte) 0xc2, (byte) 0x8f, // 0.06 + (byte) 0x3d, (byte) 0x8f, (byte) 0x5c, (byte) 0x29, // 0.07 + (byte) 0x3d, (byte) 0xa3, (byte) 0xd7, (byte) 0x0a, // 0.08 + (byte) 0x3d, (byte) 0xb8, (byte) 0x51, (byte) 0xec, // 0.09 + (byte) 0x3d, (byte) 0xcc, (byte) 0xcc, (byte) 0xcd, // 0.10 + }; + + // A list of embeddings that is too short, and should produce an error on sync + private static final List BAD_EMBEDDINGS = + Floats.asList(0.01f, 0.02f, 0.03f, 0.04f, 0.05f); + + // The BAD_EMBEDDINGS list of floats converted to a byte array of single-precision (4 byte) + // big-endian floats, as they would be sent to CBT + private static final byte[] BAD_EMBEDDING_BYTES = { + (byte) 0x3c, (byte) 0x23, (byte) 0xd7, (byte) 0x0a, // 0.01 + (byte) 0x3c, (byte) 0xa3, (byte) 0xd7, (byte) 0x0a, // 0.02 + (byte) 0x3c, (byte) 0xf5, (byte) 0xc2, (byte) 0x8f, // 0.03 + (byte) 0x3d, (byte) 0x23, (byte) 0xd7, (byte) 0x0a, // 0.04 + (byte) 0x3d, (byte) 0x4c, (byte) 0xcc, (byte) 0xcd, // 0.05 + }; + + private static final String EMBEDDING_BYTE_SIZE = "4"; + + // Columns for writing to CBT and their ByteString equivalent, so we don't continually have to + // pass them through ByteString.copyFrom(...) 
in tests + private static final String SOURCE_COLUMN_FAMILY = "cf"; + private static final String EMBEDDING_COLUMN_NAME = "embeddings"; + private static final String CROWDING_TAG_COLUMN_NAME = "crowding_tag"; + private static final String ALLOW_RESTRICTS_COLUMN_NAME = "allow"; + private static final String DENY_RESTRICTS_COLUMN_NAME = "deny"; + private static final String INT_RESTRICTS_COLUMN_NAME = "int-restrict"; + private static final String FLOAT_RESTRICTS_COLUMN_NAME = "float-restrict"; + private static final String DOUBLE_RESTRICTS_COLUMN_NAME = "double-restrict"; + + private static final ByteString EMBEDDING_COLUMN = + ByteString.copyFrom(EMBEDDING_COLUMN_NAME, Charset.defaultCharset()); + private static final ByteString CROWDING_TAG_COLUMN = + ByteString.copyFrom(CROWDING_TAG_COLUMN_NAME, Charset.defaultCharset()); + private static final ByteString ALLOW_RESTRICTS_COLUMN = + ByteString.copyFrom(ALLOW_RESTRICTS_COLUMN_NAME, Charset.defaultCharset()); + private static final ByteString DENY_RESTRICTS_COLUMN = + ByteString.copyFrom(DENY_RESTRICTS_COLUMN_NAME, Charset.defaultCharset()); + private static final ByteString INT_RESTRICTS_COLUMN = + ByteString.copyFrom(INT_RESTRICTS_COLUMN_NAME, Charset.defaultCharset()); + private static final ByteString FLOAT_RESTRICTS_COLUMN = + ByteString.copyFrom(FLOAT_RESTRICTS_COLUMN_NAME, Charset.defaultCharset()); + private static final ByteString DOUBLE_RESTRICTS_COLUMN = + ByteString.copyFrom(DOUBLE_RESTRICTS_COLUMN_NAME, Charset.defaultCharset()); + + // Tags we'll read from in a datapoint + private static final String ALLOW_RESTRICTS_TAG = "allowtag"; + private static final String DENY_RESTRICTS_TAG = "denytag"; + private static final String INT_RESTRICTS_TAG = "inttag"; + private static final String FLOAT_RESTRICTS_TAG = "floattag"; + private static final String DOUBLE_RESTRICTS_TAG = "doubletag"; + + @BeforeClass + public static void setupClass() throws Exception { + vectorSearchResourceManager = + VectorSearchResourceManager.findOrCreateTestInfra(PROJECT_NUMBER, REGION); + } + + @Before + public void setup() throws Exception { + // REGION and PROJECT are available, but we need the project number, not its name + LOG.info("Have project number {}", PROJECT_NUMBER); + LOG.info("Have project {}", PROJECT); + LOG.info("Have region number {}", REGION); + + bigtableResourceManager = + BigtableResourceManager.builder( + removeUnsafeCharacters(testName), PROJECT, credentialsProvider) + .maybeUseStaticInstance() + .build(); + + appProfileId = generateAppProfileId(); + srcTable = generateTableName(); + + List clusters = new ArrayList<>(); + clusters.add(BigtableResourceManagerCluster.create(clusterName, TEST_ZONE, 1, StorageType.HDD)); + bigtableResourceManager.createInstance(clusters); + + bigtableResourceManager.createAppProfile( + appProfileId, true, bigtableResourceManager.getClusterNames()); + + BigtableTableSpec cdcTableSpec = new BigtableTableSpec(); + cdcTableSpec.setCdcEnabled(true); + cdcTableSpec.setColumnFamilies(Lists.asList(SOURCE_COLUMN_FAMILY, new String[] {})); + bigtableResourceManager.createTable(srcTable, cdcTableSpec); + + LOG.info("Cluster names: {}", bigtableResourceManager.getClusterNames()); + LOG.info("Have instance {}", bigtableResourceManager.getInstanceId()); + } + + @After + public void tearDown() { + ResourceManagerUtils.cleanResources(bigtableResourceManager, vectorSearchResourceManager); + } + + private PipelineLauncher.LaunchConfig.Builder defaultLaunchConfig() { + return PipelineLauncher.LaunchConfig.builder("test-job", 
specPath) + // Working configuration required by every test + .addParameter("bigtableReadTableId", srcTable) + .addParameter("bigtableReadInstanceId", bigtableResourceManager.getInstanceId()) + .addParameter("bigtableChangeStreamAppProfile", appProfileId) + .addParameter("bigtableChangeStreamCharset", "KOI8-R") + .addParameter("vectorSearchIndex", vectorSearchResourceManager.getTestIndex().getName()) + // Optional configuration that some tests may with to override + .addParameter("embeddingColumn", SOURCE_COLUMN_FAMILY + ":" + EMBEDDING_COLUMN_NAME) + .addParameter("embeddingByteSize", EMBEDDING_BYTE_SIZE) + .addParameter("crowdingTagColumn", SOURCE_COLUMN_FAMILY + ":" + CROWDING_TAG_COLUMN_NAME) + .addParameter( + "allowRestrictsMappings", + SOURCE_COLUMN_FAMILY + ":" + ALLOW_RESTRICTS_COLUMN_NAME + "->" + ALLOW_RESTRICTS_TAG) + .addParameter("upsertMaxBatchSize", "1") + .addParameter("upsertMaxBufferDuration", "1s") + .addParameter("deleteMaxBatchSize", "1") + .addParameter("deleteMaxBufferDuration", "1s") + .addParameter( + "denyRestrictsMappings", + SOURCE_COLUMN_FAMILY + ":" + DENY_RESTRICTS_COLUMN_NAME + "->" + DENY_RESTRICTS_TAG) + .addParameter( + "intNumericRestrictsMappings", + SOURCE_COLUMN_FAMILY + ":" + INT_RESTRICTS_COLUMN_NAME + "->" + INT_RESTRICTS_TAG) + .addParameter( + "floatNumericRestrictsMappings", + SOURCE_COLUMN_FAMILY + ":" + FLOAT_RESTRICTS_COLUMN_NAME + "->" + FLOAT_RESTRICTS_TAG) + .addParameter( + "doubleNumericRestrictsMappings", + SOURCE_COLUMN_FAMILY + + ":" + + DOUBLE_RESTRICTS_COLUMN_NAME + + "->" + + DOUBLE_RESTRICTS_TAG); + } + + @Test + public void testRowMutationsThatAddEmbeddingsAreSyncedAsUpserts() throws Exception { + LOG.info("Testname: {}", testName); + LOG.info("specPath: {}", specPath); + LOG.info("srcTable: {}", srcTable); + // LOG.info("bigtableResourceManagerInstanceId: {}", + // bigtableResourceManager.getInstanceId()); + // LOG.info("appProfileId: {}", appProfileId); + PipelineLauncher.LaunchInfo launchInfo = launchTemplate(defaultLaunchConfig()); + LOG.info("Pipeline launched: {}", launchInfo.pipelineName()); + + assertThatPipeline(launchInfo).isRunning(); + LOG.info("Pipeline is running"); + String rowkey = vectorSearchResourceManager.makeDatapointId(); + LOG.info("Writing rowkey {}", rowkey); + + long timestamp = 12000L; + RowMutation rowMutation = + RowMutation.create(srcTable, rowkey) + .setCell( + SOURCE_COLUMN_FAMILY, + EMBEDDING_COLUMN, + timestamp, + ByteString.copyFrom(GOOD_EMBEDDING_BYTES)); + + LOG.info("Writing row {}", rowkey); + bigtableResourceManager.write(rowMutation); + + PipelineOperator.Result result = + pipelineOperator() + .waitForConditionAndCancel( + createConfig(launchInfo, Duration.ofMinutes(30)), + () -> { + LOG.info("Looking for new datapoint"); + var datapoint = vectorSearchResourceManager.findDatapoint(rowkey); + if (datapoint == null) { + LOG.info("No result"); + + return false; + } else { + LOG.info("Found result: {}", datapoint.getFeatureVectorList()); + assertEqualVectors(GOOD_EMBEDDINGS, datapoint.getFeatureVectorList()); + return true; + } + }); + + assertThatResult(result).meetsConditions(); + } + + @Test + public void testBadEmbeddingsAreRejected() throws Exception { + LOG.info("Testname: {}", testName); + LOG.info("specPath: {}", specPath); + LOG.info("srcTable: {}", srcTable); + + PipelineLauncher.LaunchInfo launchInfo = + launchTemplate(defaultLaunchConfig().addParameter("dlqDirectory", getGcsPath("dlq"))); + + assertThatPipeline(launchInfo).isRunning(); + LOG.info("Pipeline launched: {}", 
launchInfo.pipelineName()); + + String rowkey = vectorSearchResourceManager.makeDatapointId(); + LOG.info("Writing rowkey {}", rowkey); + + long timestamp = 12000L; + + // This row mutation should fail, the Index API should reject it due to the incorrect + // dimensionality of the embeddings vector + RowMutation rowMutation = + RowMutation.create(srcTable, rowkey) + .setCell( + SOURCE_COLUMN_FAMILY, + EMBEDDING_COLUMN, + timestamp, + ByteString.copyFrom(BAD_EMBEDDING_BYTES)); + + LOG.info("Writing row {}", rowkey); + bigtableResourceManager.write(rowMutation); + + Storage storage = StorageOptions.newBuilder().setProjectId(PROJECT).build().getService(); + + String filterPrefix = + String.join("/", getClass().getSimpleName(), gcsClient.runId(), "dlq", "severe"); + LOG.info("Looking for files with a prefix: {}", filterPrefix); + + await("The failed message was not found in DLQ") + .atMost(Duration.ofMinutes(30)) + .pollInterval(Duration.ofSeconds(1)) + .until( + () -> { + Page blobs = + storage.list(artifactBucketName, BlobListOption.prefix(filterPrefix)); + + for (Blob blob : blobs.iterateAll()) { + // Ignore temp files + if (blob.getName().contains(".temp-beam/")) { + continue; + } + + byte[] content = storage.readAllBytes(blob.getBlobId()); + var errorMessage = new String(content); + + LOG.info("Have message '{}'", errorMessage); + String wantMessage = + String.format( + "Error writing to vector search:io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Incorrect dimensionality. Expected 10, got 5. Datapoint ID: %s.\n", + rowkey); + LOG.info("Want message '{}'", wantMessage); + assertEquals(errorMessage, wantMessage); + + return true; + } + + return false; + }); + } + + @Test + public void testRowMutationsThatIncludeOptionalFieldsAreSynced() throws Exception { + LOG.info("Testname: {}", testName); + LOG.info("specPath: {}", specPath); + LOG.info("srcTable: {}", srcTable); + + PipelineLauncher.LaunchInfo launchInfo = launchTemplate(defaultLaunchConfig()); + + LOG.info("Pipeline launched1"); + assertThatPipeline(launchInfo).isRunning(); + LOG.info("Pipeline launched2"); + LOG.info("Writing mutation"); + String rowkey = vectorSearchResourceManager.makeDatapointId(); + + LOG.info("Writing rowkey {}", rowkey); + + final String crowdingTag = randomAlphanumeric(10); + final String allowTag = randomAlphanumeric(10); + final String denyTag = randomAlphanumeric(10); + final int intRestrict = RandomUtils.nextInt(); + final byte[] intBytes = ByteBuffer.allocate(4).putInt(intRestrict).array(); + final float floatRestrict = RandomUtils.nextFloat(); + final byte[] floatBytes = ByteBuffer.allocate(4).putFloat(floatRestrict).array(); + final double doubleRestrict = RandomUtils.nextDouble(); + final byte[] doubleBytes = ByteBuffer.allocate(8).putDouble(doubleRestrict).array(); + + long timestamp = 12000L; + RowMutation rowMutation = + RowMutation.create(srcTable, rowkey) + .setCell( + SOURCE_COLUMN_FAMILY, + EMBEDDING_COLUMN, + timestamp, + ByteString.copyFrom(GOOD_EMBEDDING_BYTES)) + .setCell( + SOURCE_COLUMN_FAMILY, + CROWDING_TAG_COLUMN, + timestamp, + ByteString.copyFrom(crowdingTag, Charset.defaultCharset())) + .setCell( + SOURCE_COLUMN_FAMILY, + ALLOW_RESTRICTS_COLUMN, + timestamp, + ByteString.copyFrom(allowTag, Charset.defaultCharset())) + .setCell( + SOURCE_COLUMN_FAMILY, + DENY_RESTRICTS_COLUMN, + timestamp, + ByteString.copyFrom(denyTag, Charset.defaultCharset())) + .setCell( + SOURCE_COLUMN_FAMILY, + INT_RESTRICTS_COLUMN, + timestamp, + ByteString.copyFrom(intBytes)) + .setCell( + 
SOURCE_COLUMN_FAMILY, + FLOAT_RESTRICTS_COLUMN, + timestamp, + ByteString.copyFrom(floatBytes)) + .setCell( + SOURCE_COLUMN_FAMILY, + DOUBLE_RESTRICTS_COLUMN, + timestamp, + ByteString.copyFrom(doubleBytes)); + + bigtableResourceManager.write(rowMutation); + + PipelineOperator.Result result = + pipelineOperator() + .waitForConditionAndCancel( + createConfig(launchInfo, Duration.ofMinutes(30)), + () -> { + LOG.info("Looking for new datapoint"); + var datapoint = vectorSearchResourceManager.findDatapoint(rowkey); + if (datapoint == null) { + LOG.info("No result"); + return false; + } + + LOG.info("Found result: {}", datapoint); + + assertEquals(2, datapoint.getRestrictsCount()); + + assertEquals(GOOD_EMBEDDINGS.size(), datapoint.getFeatureVectorCount()); + assertEqualVectors(GOOD_EMBEDDINGS, datapoint.getFeatureVectorList()); + + for (var r : datapoint.getRestrictsList()) { + if (r.getNamespace().equals(ALLOW_RESTRICTS_TAG)) { + // It's our allow-restrict, verify tag matches + assertEquals(1, r.getAllowListCount()); + // LOG("Have allow list {}", r.getAllowListList()); + assertEquals(allowTag, r.getAllowList(0)); + } else { + // it's necessarily our deny-restrict, verify tag matches + assertEquals(DENY_RESTRICTS_TAG, r.getNamespace()); + assertEquals(1, r.getDenyListCount()); + assertEquals(denyTag, r.getDenyList(0)); + } + } + + assertEquals(3, datapoint.getNumericRestrictsCount()); + + for (var r : datapoint.getNumericRestrictsList()) { + if (r.getNamespace().equals(INT_RESTRICTS_TAG)) { + assertEquals(intRestrict, r.getValueInt()); + } else if (r.getNamespace().equals(FLOAT_RESTRICTS_TAG)) { + assertEquals(floatRestrict, r.getValueFloat(), 0.001); + } else if (r.getNamespace().equals(DOUBLE_RESTRICTS_TAG)) { + assertEquals(doubleRestrict, r.getValueDouble(), 0.001); + } else { + throw new RuntimeException("Unexpected numeric restrict"); + } + } + + return true; + }); + + assertThatResult(result).meetsConditions(); + } + + @Test + public void testDeleteFamilyMutationsThatDeletetEmbeddingsColumnAreSyncedAsDeletes() + throws Exception { + + var datapointId = vectorSearchResourceManager.makeDatapointId(); + vectorSearchResourceManager.addDatapoint(datapointId, GOOD_EMBEDDINGS); + + PipelineLauncher.LaunchInfo launchInfo = launchTemplate(defaultLaunchConfig()); + + assertThatPipeline(launchInfo).isRunning(); + + LOG.info("Waiting for datapoint to become queryable"); + PipelineOperator.Result result = + pipelineOperator() + .waitForCondition( + createConfig(launchInfo, Duration.ofMinutes(30)), + () -> { + // Make sure the datapoint exists and is findable + var dp = vectorSearchResourceManager.findDatapoint(datapointId); + LOG.info(dp == null ? "DP does not yet exist" : "DP exists"); + return dp != null; + }); + + assertThatResult(result).meetsConditions(); + + RowMutation rowMutation = + RowMutation.create(srcTable, datapointId).deleteFamily(SOURCE_COLUMN_FAMILY); + + bigtableResourceManager.write(rowMutation); + + LOG.info("Waiting for row to be deleted to become queryable"); + PipelineOperator.Result result2 = + pipelineOperator() + .waitForConditionAndCancel( + createConfig(launchInfo, Duration.ofMinutes(30)), + () -> { + // Make sure the datapoint exists and is findable + var dp = vectorSearchResourceManager.findDatapoint(datapointId); + LOG.info(dp == null ? 
"DP has been deleted" : "DP still exists"); + return dp == null; + }); + + assertThatResult(result).meetsConditions(); + } + + @NotNull + public static Boolean assertEqualVectors(Iterable expected, Iterable actual) { + var i = actual.iterator(); + var j = actual.iterator(); + + while (i.hasNext() && j.hasNext()) { + assertEquals(i.next(), j.next(), 0.0001); + } + + return !i.hasNext() && !j.hasNext(); + } + + @NotNull + private static String generateAppProfileId() { + return "cdc_app_profile_" + randomAlphanumeric(8).toLowerCase() + "_" + System.nanoTime(); + } + + @NotNull + private String generateTableName() { + return "table_" + randomAlphanumeric(8).toLowerCase() + "_" + System.nanoTime(); + } + + private String removeUnsafeCharacters(String testName) { + return testName.replaceAll("[\\[\\]]", "-"); + } +} diff --git a/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UtilsTest.java b/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UtilsTest.java new file mode 100644 index 0000000000..85c7cb34a8 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/UtilsTest.java @@ -0,0 +1,140 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import com.google.protobuf.ByteString; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test cases for the {@link Utils} class. 
*/ +@RunWith(JUnit4.class) +public class UtilsTest { + + @Test + public void testExtractRegionFromIndexName() { + assertEquals( + "us-east1", + Utils.extractRegionFromIndexName("projects/123/locations/us-east1/indexes/456")); + + var badIndexNames = + List.of( + "foo", + "/projects/123/locations/us-east1/indexes", + "projects/123/locations/us-east1/indexes/", + "/projects/123/locations/us-east1/indexes/"); + + for (var indexName : badIndexNames) { + var ex = + assertThrows( + IllegalArgumentException.class, + () -> { + Utils.extractRegionFromIndexName(indexName); + }); + assertEquals("Invalid IndexName", ex.getMessage()); + } + } + + @Test + public void testParseColumnMappings() { + var input = "cf1:foo1->bar1,cf2:foo2->bar2"; + + var got = Utils.parseColumnMapping(input); + var want = new HashMap(); + want.put("cf1:foo1", "bar1"); + want.put("cf2:foo2", "bar2"); + + assertEquals(want, got); + } + + @Test + public void testParseColumnMappingWithBadInput() { + var badMapping = "cf1:foo->bar->baz"; + + var ex = + assertThrows( + IllegalArgumentException.class, + () -> { + Utils.parseColumnMapping(badMapping); + }); + + assertEquals("Malformed column mapping pair cf1:foo->bar->baz", ex.getMessage()); + } + + @Test + public void testFloatEmbeddingsAreDecoded() { + byte[] bytes = { + // 4 byte single precision big-endian IEEE 754 float 3.14 + (byte) 64, (byte) 72, (byte) 245, (byte) 195, + // 4 byte single precision big-endian IEEE 754 float 2.1782 + (byte) 64, (byte) 45, (byte) 246, (byte) 253 + }; + var bs = ByteString.copyFrom(bytes); + + var want = new ArrayList(Arrays.asList(3.14f, 2.7182f)); + var got = Utils.bytesToFloats(bs, false); + + assertEquals(want, got); + } + + @Test + public void testEmptyEmbeddingsDecodeCorrectly() { + byte[] bytes = {}; + var bs = ByteString.copyFrom(bytes); + + var want = new ArrayList(); + var got = Utils.bytesToFloats(bs, false); + assertEquals(want, got); + } + + @Test + public void testInvalidLengthEmbeddingsProducesException() { + byte[] bytes = {(byte) 1, (byte) 2, (byte) 3}; + var bs = ByteString.copyFrom(bytes); + + var ex = + assertThrows( + RuntimeException.class, + () -> { + Utils.bytesToFloats(bs, false); + }); + + assertEquals("Invalid ByteStream length 3 (should be a multiple of 4)", ex.getMessage()); + } + + @Test + public void testDoubleLengthFloatEncodingsAreDecoded() { + // Test that 8 byte doubles are correctly decoded into 4 byte floats + byte[] bytes = { + // 8 byte double precision big-endian IEEE 754 float 3.14 + (byte) 64, (byte) 9, (byte) 30, (byte) 184, (byte) 96, (byte) 0, (byte) 0, (byte) 0, + // 8 byte double precision big-endian IEEE 754 float 2.1782 + (byte) 64, (byte) 5, (byte) 190, (byte) 223, (byte) 160, (byte) 0, (byte) 0, (byte) 0, + }; + var bs = ByteString.copyFrom(bytes); + + var want = new ArrayList(Arrays.asList(3.14f, 2.7182f)); + var got = Utils.bytesToFloats(bs, true); + assertEquals(want, got); + } +} diff --git a/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/VectorSearchResourceManager.java b/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/VectorSearchResourceManager.java new file mode 100644 index 0000000000..8d39afc456 --- /dev/null +++ b/v2/googlecloud-to-googlecloud/src/test/java/com/google/cloud/teleport/v2/templates/bigtablechangestreamstovectorsearch/VectorSearchResourceManager.java @@ -0,0 +1,362 @@ +/* + * Copyright (C) 2024 Google LLC + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.google.cloud.teleport.v2.templates.bigtablechangestreamstovectorsearch; + +import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric; + +import com.google.api.gax.longrunning.OperationTimedPollAlgorithm; +import com.google.api.gax.retrying.RetrySettings; +import com.google.cloud.aiplatform.v1.CreateIndexRequest; +import com.google.cloud.aiplatform.v1.DeployIndexRequest; +import com.google.cloud.aiplatform.v1.DeployedIndex; +import com.google.cloud.aiplatform.v1.Index; +import com.google.cloud.aiplatform.v1.IndexDatapoint; +import com.google.cloud.aiplatform.v1.IndexEndpoint; +import com.google.cloud.aiplatform.v1.IndexEndpointServiceClient; +import com.google.cloud.aiplatform.v1.IndexEndpointServiceSettings; +import com.google.cloud.aiplatform.v1.IndexServiceClient; +import com.google.cloud.aiplatform.v1.IndexServiceSettings; +import com.google.cloud.aiplatform.v1.MatchServiceClient; +import com.google.cloud.aiplatform.v1.MatchServiceSettings; +import com.google.cloud.aiplatform.v1.ReadIndexDatapointsRequest; +import com.google.cloud.aiplatform.v1.ReadIndexDatapointsResponse; +import com.google.cloud.aiplatform.v1.RemoveDatapointsRequest; +import com.google.cloud.aiplatform.v1.UpsertDatapointsRequest; +import com.google.protobuf.TextFormat; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.apache.beam.it.common.ResourceManager; +import org.eclipse.jetty.util.ConcurrentHashSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.threeten.bp.Duration; + +/** + * VectorSearchProvisioner looks up testing infra structure in the given project and region, and + * creates it if it doesn't already exist. Creation can take up to 30 minutes, so infra is left + * intact after tests run, so it can be reused in subsequent test runs. + * + *

Specifically, it creates: - A new index, named defined by TEST_INDEX_NAME, with a + * dimensionality of 10, accessible via `getTestIndex()` - A new endpoint, name defined by + * TEST_ENDPOINT_NAME, accessible vai `getTestEndpoint()` - A new index deployment deploying the + * test index to the test endpoint + */ +public class VectorSearchResourceManager implements ResourceManager { + private static final Logger LOG = LoggerFactory.getLogger(VectorSearchResourceManager.class); + + // Hard-coded; if this class winds up being used by other templates, we'll probably want to + // add a Builder class, and allow these to be configured. + // Note - in the cloud-teleport-testing project, a "nokill' suffix prevents the test + // infrastructure from being automatically cleaned up. + public static final int TEST_INDEX_DIMENSIONS = 10; + public static final String TEST_INDEX_NAME = "bt-vs-index1-nokill"; + public static final String TEST_ENDPOINT_NAME = "bt-vs-endpoint1-nokill"; + public static final String TEST_DEPLOYED_INDEX_NAME = "bt_deploy1_nokill"; + + /** + * Factory method for creating a new resource manager, and provisioning the test index/endpoint. + * The returned resource manager will have a ready-to-use index, endpoint and deployed endpoint. + * + * @param projectNumber The _numeric_ project ID (ex "269744978479") + * @param region The region in which to find-or-create the infra (ex "us-east1") + * @return A new VectorSearchResourceManagerr instance + * @throws Exception + */ + // Factory method + public static VectorSearchResourceManager findOrCreateTestInfra( + String projectNumber, String region) throws Exception { + var c = new VectorSearchResourceManager(projectNumber, region); + c.findOrCreateTestInfra(); + return c; + } + + public Index getTestIndex() { + return testIndex; + } + + public IndexEndpoint getTestEndpoint() { + return testEndpoint; + } + + private Index testIndex = null; + private IndexEndpoint testEndpoint = null; + private DeployedIndex testDeployedIndex = null; + + private String region; + private String projectNumber; + private String host; // ie "us-east1-aiplatform.googleapis.com:443"; (port required) + private String parent; // ie "projects/123/locations/us-east1" (no leading/trailing slashes) + + private IndexServiceClient indexClient; + private IndexEndpointServiceClient endpointClient; + private MatchServiceClient matchClient; + + // Each datapoint ID we generate goes into this set; when we tear down, we send a + // `removeDatapoints` request to the index, to prevent data living beyond the test. + // This isn't perfect - if a test crashes and cleanup doesn't run, or the cleanup request times + // the datapoints will be left in the index, but the datapoint IDs are sufficiently long and + // random that this shouldn't cause collisions between test runs. It could mean that the index + // will slowly grow over time, so we should figure out how to periodically purge the index, but + // that could cause running instances of this test to fail, so it's not perfect. 
+ // Using ConcurrentHashSet instead of HashSet to support parallelized test cases + private ConcurrentHashSet pendingDatapoints = new ConcurrentHashSet<>(); + + private VectorSearchResourceManager(String projectNumber, String region) throws Exception { + this.projectNumber = projectNumber; + this.region = region; + + this.host = String.format("%s-aiplatform.googleapis.com:443", region); + this.parent = String.format("projects/%s/locations/%s", projectNumber, region); + } + + // Returns a random alphanumeric string, suitable for use as a Datapoint ID or a CBT row key, + // and records it for later deletion. + public String makeDatapointId() { + var id = randomAlphanumeric(20); + pendingDatapoints.add(id); + return id; + } + + /** + * Load an IndexDatapoint by its ID, returning null if the datapoint does not exist. + * + * @param datapointID The datapoint ID to load + * @return An IndexDatapoint, or null + */ + public IndexDatapoint findDatapoint(String datapointID) { + LOG.debug("Finding datapoint {}", datapointID); + + ReadIndexDatapointsRequest request = + ReadIndexDatapointsRequest.newBuilder() + .setIndexEndpoint(testEndpoint.getName()) + .setDeployedIndexId(testDeployedIndex.getId()) + .addAllIds(List.of(datapointID)) + .build(); + + ReadIndexDatapointsResponse response; + + try { + response = matchClient.readIndexDatapoints(request); + } catch (com.google.api.gax.rpc.NotFoundException e) { + LOG.debug("Datapoint {} does not exist (NotFoundException)", datapointID); + return null; + } + + for (var i : response.getDatapointsList()) { + if (i.getDatapointId().equals(datapointID)) { + LOG.debug("Datapoint {} Exists", datapointID); + return i; + } + } + + // If we reach this point, we received a response, but it doesn't have the datapoint we asked + // for, which is probably + // a bug? + LOG.error("Datapoint {} not found in response - probably a bug", datapointID); + + return null; + } + + @Override + public void cleanupAll() { + // Cleanup any datapoints that may have been left by failing tests + deleteDatapoints(this.pendingDatapoints); + + this.matchClient.close(); + this.endpointClient.close(); + this.indexClient.close(); + } + + /** + * Add a datapoint directly to the index. + * + * @param datapointId The datapoint ID + * @param vector Float embeddings. The length of the array must match the dimension of the index + */ + public void addDatapoint(String datapointId, Iterable vector) { + LOG.debug( + "Adding datapoint {} directly to index {} with floats", + datapointId, + testIndex.getName(), + vector); + var dp = + IndexDatapoint.newBuilder().setDatapointId(datapointId).addAllFeatureVector(vector).build(); + + LOG.debug("Doing thing"); + indexClient.upsertDatapoints( + UpsertDatapointsRequest.newBuilder() + .setIndex(testIndex.getName()) + .addAllDatapoints(List.of(dp)) + .build()); + + LOG.debug("Done thing"); + // LOG.debug("Update Mask", dpr.hasUpdateMask()); + + } + + // Cleanup a set of datapoints. Datapoint IDs that don't exist in the index are ignored, so it's + // safe to run this to remove a set of datapoints larger than the set of those that exist. 
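Taken together, the manager supports the write-then-poll pattern used by the integration tests above. A condensed usage sketch — the project number and region below are placeholders, and the vector length must match the test index's dimensionality of 10 — looks like this:

```java
// Placeholder project number and region; the vector must match TEST_INDEX_DIMENSIONS (10).
var infra = VectorSearchResourceManager.findOrCreateTestInfra("123456789012", "us-east1");
String id = infra.makeDatapointId();
infra.addDatapoint(
    id, List.of(0.01f, 0.02f, 0.03f, 0.04f, 0.05f, 0.06f, 0.07f, 0.08f, 0.09f, 0.10f));
IndexDatapoint dp = infra.findDatapoint(id); // null until the deployed index serves it
infra.cleanupAll(); // removes the datapoint IDs handed out by makeDatapointId()
```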
+ private void deleteDatapoints(Iterable datapointIds) { + var request = + RemoveDatapointsRequest.newBuilder() + .addAllDatapointIds(datapointIds) + .setIndex(testIndex.getName()) + .build(); + + this.indexClient.removeDatapoints(request); + } + + private void findOrCreateTestInfra() throws Exception { + // Used to poll long-running operations; it can take up to 30 minutes to create an index and + // endpoint and then + // deploy the index to the endpoint. Most of the time is taken on the deploy step. + var poll = + OperationTimedPollAlgorithm.create( + RetrySettings.newBuilder() + .setInitialRetryDelay(Duration.ofMinutes(1)) + .setRetryDelayMultiplier(1.0) + .setMaxRetryDelay(Duration.ofMinutes(1)) + .setInitialRpcTimeout(Duration.ZERO) + .setRpcTimeoutMultiplier(1.0) + .setMaxRpcTimeout(Duration.ZERO) + .setTotalTimeout(Duration.ofMinutes(120)) + .build()); + + var indexSettings = IndexServiceSettings.newBuilder().setEndpoint(host); + indexSettings.createIndexOperationSettings().setPollingAlgorithm(poll); + indexClient = IndexServiceClient.create(indexSettings.build()); + + var endpointSettings = IndexEndpointServiceSettings.newBuilder().setEndpoint(host); + endpointSettings.deployIndexOperationSettings().setPollingAlgorithm(poll); + endpointClient = IndexEndpointServiceClient.create(endpointSettings.build()); + + // All three (index, endpoint, deployment) seem to have to exist before the public domain on the + // endpoint is + // necessarily created? Sometimes it is empty if we try to retrieve it too quickly. + testIndex = findOrCreateIndex(); + testEndpoint = findOrCreateEndpoint(); + testDeployedIndex = findOrCreateDeployedIndex(); + + LOG.debug( + "Creating match client with endpoint {}", + testEndpoint.getPublicEndpointDomainName() + ":443"); + + var matchSettings = + MatchServiceSettings.newBuilder() + .setEndpoint(testEndpoint.getPublicEndpointDomainName() + ":443"); + + matchClient = MatchServiceClient.create(matchSettings.build()); + } + + private Index findOrCreateIndex() throws Exception { + LOG.debug("Doing find-or-create for test index {}", TEST_INDEX_NAME); + for (var i : indexClient.listIndexes(parent).iterateAll()) { + if (i.getDisplayName().equals(TEST_INDEX_NAME)) { + LOG.debug("Using existing index: {}:{}", i.getDisplayName(), i.getName()); + return i; + } else { + LOG.debug("Ignoring index: {}:{}", i.getDisplayName(), i.getName()); + } + } + + LOG.debug("Index {} does not exist, creating", TEST_INDEX_NAME); + return createIndex(TEST_INDEX_NAME); + } + + private IndexEndpoint findOrCreateEndpoint() throws Exception { + LOG.debug("Doing find-or-create for test endpoint {}", TEST_ENDPOINT_NAME); + + for (var e : endpointClient.listIndexEndpoints(parent).iterateAll()) { + if (e.getDisplayName().equals(TEST_ENDPOINT_NAME)) { + LOG.debug("Using existing endpoint: {}:{}", e.getDisplayName(), e.getName()); + return e; + } else { + LOG.debug("Ignoring endpoint {}:{}", e.getDisplayName(), e.getName()); + } + } + + LOG.debug("Endpoint {} does not exist, creating", TEST_ENDPOINT_NAME); + return createEndpoint(TEST_ENDPOINT_NAME); + } + + private DeployedIndex findOrCreateDeployedIndex() throws Exception { + LOG.debug("Doing find-or-create for test index deployment {}", TEST_DEPLOYED_INDEX_NAME); + + for (var d : testEndpoint.getDeployedIndexesList()) { + if (d.getId().equals(TEST_DEPLOYED_INDEX_NAME)) { + LOG.debug("Using existing deployment: {}:{}", d.getDisplayName(), d.getId()); + return d; + } else { + LOG.debug("Ignoring deployment {}:{}", d.getDisplayName(), d.getId()); + 
} + } + + LOG.debug("DeployedIndex {} does not exist, creating", TEST_DEPLOYED_INDEX_NAME); + return deployIndexToEndpoint(testIndex.getName(), testEndpoint.getName()); + } + + private Index createIndex(String indexName) throws Exception { + // This is somewhat of a black box, copied from an index in a good state. + // The resulting index will have a dimensionality of 10 + final CharSequence indexSchema = + "struct_value { fields { key: \"config\" value { struct_value { fields { key: \"algorithmConfig\" value { struct_value { fields { key: \"treeAhConfig\" value { struct_value { fields { key: \"fractionLeafNodesToSearch\" value { number_value: 0.05 } } fields { key: \"leafNodeEmbeddingCount\" value { string_value: \"1000\" } } } } } } } } fields { key: \"approximateNeighborsCount\" value { number_value: 1.0 } } fields { key: \"dimensions\" value { number_value: 10.0 } } fields { key: \"distanceMeasureType\" value { string_value: \"DOT_PRODUCT_DISTANCE\" } } fields { key: \"featureNormType\" value { string_value: \"NONE\" } } fields { key: \"shardSize\" value { string_value: \"SHARD_SIZE_SMALL\" } } } } } }"; + + var v = com.google.protobuf.Value.newBuilder(); + TextFormat.merge(indexSchema, v); + + var index = + Index.newBuilder() + .setIndexUpdateMethod(Index.IndexUpdateMethod.STREAM_UPDATE) + .setDisplayName(indexName) + .setDescription( + "Used in integration tests by the Bigtable Change Streams to Vector Search template") + .setMetadataSchemaUri( + "gs://google-cloud-aiplatform/schema/matchingengine/metadata/nearest_neighbor_search_1.0.0.yaml") + .setMetadata(v) + .build(); + + var request = CreateIndexRequest.newBuilder().setParent(parent).setIndex(index).build(); + + return indexClient.createIndexAsync(request).get(30, TimeUnit.MINUTES); + } + + private IndexEndpoint createEndpoint(String endpointName) throws Exception { + var endpoint = + IndexEndpoint.newBuilder() + .setDisplayName(endpointName) + .setDescription( + "Test endpoint for Bigtable Change Streams to Vector Search Dataflow Template") + .build(); + + return endpointClient.createIndexEndpointAsync(parent, endpoint).get(30, TimeUnit.MINUTES); + } + + private DeployedIndex deployIndexToEndpoint(String indexID, String endpointID) throws Exception { + DeployIndexRequest request = + DeployIndexRequest.newBuilder() + .setIndexEndpoint(endpointID) + .setDeployedIndex( + DeployedIndex.newBuilder() + .setIndex(indexID) + .setDisplayName("Integration tests") + .setId(TEST_DEPLOYED_INDEX_NAME) + // manually delete them? 
+ .build()) + .build(); + + return endpointClient.deployIndexAsync(request).get().getDeployedIndex(); + } +} diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/DataTypesIt.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/DataTypesIt.java index 4c634b6334..bc753821d7 100644 --- a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/DataTypesIt.java +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/DataTypesIt.java @@ -15,6 +15,7 @@ */ package com.google.cloud.teleport.v2.templates; +import static com.google.common.truth.Truth.assertThat; import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatResult; import com.google.cloud.spanner.Struct; @@ -87,6 +88,8 @@ public void allTypesTest() throws Exception { null); PipelineOperator.Result result = pipelineOperator().waitUntilDone(createConfig(jobInfo)); assertThatResult(result).isLaunchFinished(); + + // Validate supported data types. Map>> expectedData = getExpectedData(); for (Map.Entry>> entry : expectedData.entrySet()) { String type = entry.getKey(); @@ -106,6 +109,21 @@ public void allTypesTest() throws Exception { SpannerAsserts.assertThatStructs(rows) .hasRecordsUnorderedCaseInsensitiveColumns(entry.getValue()); } + + // Validate unsupported types. + List unsupportedTypeTables = + List.of( + "spatial_linestring", + "spatial_multilinestring", + "spatial_multipoint", + "spatial_multipolygon", + "spatial_point", + "spatial_polygon"); + + for (String table : unsupportedTypeTables) { + // Unsupported rows should still be migrated. Each source table has 1 row. + assertThat(spannerResourceManager.getRowCount(table)).isEqualTo(1L); + } } private List> createRows(String colPrefix, Object... 
values) { diff --git a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/SourceDbToSpannerITBase.java b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/SourceDbToSpannerITBase.java index 3fcc30b072..6c94e897a9 100644 --- a/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/SourceDbToSpannerITBase.java +++ b/v2/sourcedb-to-spanner/src/test/java/com/google/cloud/teleport/v2/templates/SourceDbToSpannerITBase.java @@ -84,6 +84,7 @@ protected void loadSQLToJdbcResourceManager(JDBCResourceManager jdbcResourceMana if (!stmt.trim().isEmpty()) { // Skip SELECT statements if (!stmt.trim().toUpperCase().startsWith("SELECT")) { + LOG.info("Executing statement: {}", stmt); statement.executeUpdate(stmt); } } diff --git a/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/data-types.sql b/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/data-types.sql index 7f66575242..ce458adfda 100644 --- a/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/data-types.sql +++ b/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/data-types.sql @@ -278,5 +278,53 @@ INSERT INTO `year_table` (`year_col`) VALUES (NULL); INSERT INTO set_table (set_col) VALUES (NULL); +CREATE TABLE IF NOT EXISTS spatial_point ( + id INT AUTO_INCREMENT PRIMARY KEY, + location POINT +); + +INSERT INTO spatial_point (location) VALUES (POINT(77.5946, 12.9716)); + + +CREATE TABLE IF NOT EXISTS spatial_linestring ( + id INT AUTO_INCREMENT PRIMARY KEY, + path LINESTRING +); + +INSERT INTO spatial_linestring (path) +VALUES (LineString(Point(77.5946, 12.9716), Point(77.6100, 12.9600))); + +CREATE TABLE IF NOT EXISTS spatial_polygon ( + id INT AUTO_INCREMENT PRIMARY KEY, + area POLYGON +); + +INSERT INTO spatial_polygon (area) +VALUES (Polygon(LineString(Point(77.5946, 12.9716), Point(77.6100, 12.9600), Point(77.6000, 12.9500), Point(77.5946, 12.9716)))); + +CREATE TABLE IF NOT EXISTS spatial_multipoint ( + id INT AUTO_INCREMENT PRIMARY KEY, + points MULTIPOINT +); + +INSERT INTO spatial_multipoint (points) VALUES (MultiPoint(Point(77.5946, 12.9716), Point(77.6100, 12.9600))); + +CREATE TABLE IF NOT EXISTS spatial_multilinestring ( + id INT AUTO_INCREMENT PRIMARY KEY, + paths MULTILINESTRING +); + +INSERT INTO spatial_multilinestring (paths) +VALUES (MultiLineString(LineString(Point(77.5946, 12.9716), Point(77.6100, 12.9600)), LineString(Point(77.6000, 12.9500), Point(77.6200, 12.9400)))); + +CREATE TABLE IF NOT EXISTS spatial_multipolygon ( + id INT AUTO_INCREMENT PRIMARY KEY, + areas MULTIPOLYGON +); + +INSERT INTO spatial_multipolygon (areas) +VALUES (MultiPolygon(Polygon(LineString(Point(77.5946, 12.9716), Point(77.6100, 12.9600), Point(77.6000, 12.9500), Point(77.5946, 12.9716))), + Polygon(LineString(Point(77.6200, 12.9400), Point(77.6300, 12.9300), Point(77.6400, 12.9450), Point(77.6200, 12.9400))))); + diff --git a/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/spanner-schema.sql b/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/spanner-schema.sql index b75eb0f2c6..ceecbde5c1 100644 --- a/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/spanner-schema.sql +++ b/v2/sourcedb-to-spanner/src/test/resources/DataTypesIt/spanner-schema.sql @@ -146,4 +146,34 @@ CREATE TABLE varchar_table ( CREATE TABLE year_table ( id INT64 NOT NULL, year_col STRING(MAX), +) PRIMARY KEY(id); + +CREATE TABLE spatial_linestring ( + id INT64 NOT NULL, + path STRING(MAX), +) PRIMARY KEY(id); + +CREATE TABLE spatial_multilinestring ( + id INT64 NOT NULL, + paths 
STRING(MAX), +) PRIMARY KEY(id); + +CREATE TABLE spatial_multipoint ( + id INT64 NOT NULL, + points STRING(MAX), +) PRIMARY KEY(id); + +CREATE TABLE spatial_multipolygon ( + id INT64 NOT NULL, + areas STRING(MAX), +) PRIMARY KEY(id); + +CREATE TABLE spatial_point ( + id INT64 NOT NULL, + location STRING(MAX), +) PRIMARY KEY(id); + +CREATE TABLE spatial_polygon ( + id INT64 NOT NULL, + area STRING(MAX), ) PRIMARY KEY(id); \ No newline at end of file diff --git a/v2/sourcedb-to-spanner/terraform/samples/README.md b/v2/sourcedb-to-spanner/terraform/samples/README.md new file mode 100644 index 0000000000..8555597ec8 --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/README.md @@ -0,0 +1,20 @@ +## Terraform samples for bulk migration + +This repository provides samples for common scenarios users might have while trying to run a bulk migration to Spanner. + +Pick a sample that is closest to your use-case, and use it as a starting point, tailoring it to your own specific needs. + +## List of examples + +1. [Launching multiple bulk migration jobs](multiple-jobs/README.md) + +## How to add a new sample + +It is strongly recommended to copy an existing sample and modify it according to the scenario you are trying to cover. +This ensures uniformity in the style in which terraform samples are written. + +```shell +mkdir my-new-sample +cd my-new-sample +cp -r multiple-jobs/ +``` \ No newline at end of file diff --git a/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/README.md b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/README.md new file mode 100644 index 0000000000..b8f31eeed4 --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/README.md @@ -0,0 +1,72 @@ +## Scenario + +This Terraform example illustrates launching multiple bulk migration Datafllow jobs for a MySQL to Spanner migration with the following assumptions - + +1. MySQL source can establish network connectivity with Dataflow. +2. Appropriate permissions are added to the service account running Terraform to allow resource creation. +3. Appropriate permissions are provided to the service account running Dataflow to write to Spanner. +4. A GCS bucket has been provided to write the DLQ records to. + +Given these assumptions, it copies data from multiple source MySQL databases to the configured Spanner database(s). + +## Description + +This sample contains the following files - + +1. `main.tf` - This contains the Terraform resources which will be created. +2. `outputs.tf` - This declares the outputs that will be output as part of running this terraform example. +3. `variables.tf` - This declares the input variables that are required to configure the resources. +4. `terraform.tf` - This contains the required providers for this sample. +5. `terraform.tfvars` - This contains the dummy inputs that need to be populated to run this example. + +## How to run + +1. Clone this repository or the sample locally. +2. Edit the `terraform.tfvars` file and replace the dummy variables with real values. Extend the configuration to meet your needs. +3. Run the following commands - + +### Initialise Terraform + +```shell +# Initialise terraform - You only need to do this once for a directory. +terraform init +``` + +### Run `plan` and `apply` + +Validate the terraform files with - + +```shell +terraform plan +``` + +Run the terraform script with - + +```shell +terraform apply +``` + +This will launch the configured jobs and produce an output like below - + +```shell +Apply complete! 
Resources: 1 added, 0 changed, 0 destroyed. + +Outputs: + +dataflow_job_ids = [ + "2024-06-05_00_41_11-4759981257849547781", +] +dataflow_job_urls = [ + "https://console.cloud.google.com/dataflow/jobs/us-central1/2024-06-05_00_41_11-4759981257849547781", +] +``` + +**Note:** Each of the jobs will have a random suffix added to it to prevent name collisions. + +### Cleanup + +Once the jobs have finished running, you can cleanup by running - + +```shell +terraform destroy +``` diff --git a/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/main.tf b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/main.tf new file mode 100644 index 0000000000..4f3c5d10b2 --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/main.tf @@ -0,0 +1,63 @@ +provider "google-beta" { + project = var.common_params.project + region = var.common_params.region +} + +resource "google_project_service" "required" { + service = "dataflow.googleapis.com" + project = var.common_params.project + disable_on_destroy = false +} + +resource "random_pet" "job_name_suffix" { + count = length(var.jobs) +} + +resource "google_dataflow_flex_template_job" "generated" { + count = length(var.jobs) + depends_on = [google_project_service.required] + provider = google-beta + container_spec_gcs_path = "gs://dataflow-templates-${var.common_params.region}/latest/flex/Sourcedb_to_Spanner_Flex" + + parameters = { + jdbcDriverJars = var.common_params.jdbcDriverJars + jdbcDriverClassName = var.common_params.jdbcDriverClassName + sourceDbURL = var.jobs[count.index].sourceDbURL + username = var.jobs[count.index].username + password = var.jobs[count.index].password + tables = var.jobs[count.index].tables + numPartitions = tostring(var.jobs[count.index].numPartitions) + instanceId = var.jobs[count.index].instanceId + databaseId = var.jobs[count.index].databaseId + projectId = var.common_params.projectId + spannerHost = var.common_params.spannerHost + maxConnections = tostring(var.jobs[count.index].maxConnections) + sessionFilePath = var.common_params.sessionFilePath + DLQDirectory = var.jobs[count.index].DLQDirectory + disabledAlgorithms = var.common_params.disabledAlgorithms + extraFilesToStage = var.common_params.extraFilesToStage + defaultLogLevel = var.jobs[count.index].defaultLogLevel + } + + additional_experiments = var.common_params.additional_experiments + autoscaling_algorithm = var.common_params.autoscaling_algorithm + enable_streaming_engine = var.common_params.enable_streaming_engine + ip_configuration = var.jobs[count.index].ip_configuration + kms_key_name = var.jobs[count.index].kms_key_name + labels = var.jobs[count.index].labels + launcher_machine_type = var.jobs[count.index].launcher_machine_type + machine_type = var.jobs[count.index].machine_type + max_workers = var.jobs[count.index].max_workers + name = "${var.jobs[count.index].name}-${random_pet.job_name_suffix[count.index].id}" + network = var.common_params.network + subnetwork = var.common_params.subnetwork + num_workers = var.jobs[count.index].num_workers + on_delete = var.common_params.on_delete + project = var.common_params.project + region = var.common_params.region + sdk_container_image = var.common_params.sdk_container_image + service_account_email = var.common_params.service_account_email + skip_wait_on_job_termination = var.common_params.skip_wait_on_job_termination + staging_location = var.common_params.staging_location + temp_location = var.common_params.temp_location +} \ No newline at end of file diff --git 
a/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/outputs.tf b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/outputs.tf new file mode 100644 index 0000000000..ea62f8a37d --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/outputs.tf @@ -0,0 +1,10 @@ +output "dataflow_job_ids" { + value = [for job in google_dataflow_flex_template_job.generated : job.job_id] + description = "List of job IDs for the created Dataflow Flex Template jobs." +} + +output "dataflow_job_urls" { + value = [for job in google_dataflow_flex_template_job.generated : "https://console.cloud.google.com/dataflow/jobs/${var.common_params.region}/${job.job_id}"] + description = "List of URLs for the created Dataflow Flex Template jobs." +} + diff --git a/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tf b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tf new file mode 100644 index 0000000000..a05e43ce70 --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tf @@ -0,0 +1,13 @@ +terraform { + required_providers { + google = { + source = "hashicorp/google" + version = "~> 4.0" + } + random = { + source = "hashicorp/random" + version = "~> 3.0" # Or the latest compatible version + } + } + required_version = "~>1.2" +} \ No newline at end of file diff --git a/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tfvars b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tfvars new file mode 100644 index 0000000000..f2d57ffa3f --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform.tfvars @@ -0,0 +1,63 @@ +common_params = { + on_delete = "drain" # Or "cancel" if you prefer + project = "your-google-cloud-project-id" + region = "us-central1" # Or your desired region + jdbcDriverJars = "gs://your-bucket/driver_jar1.jar,gs://your-bucket/driver_jar2.jar" + jdbcDriverClassName = "com.mysql.jdbc.Driver" + projectId = "your-cloud-spanner-project-id" + spannerHost = "https://batch-spanner.googleapis.com" + sessionFilePath = "gs://your-bucket/session-file.json" + extraFilesToStage = "gs://your-bucket/extra-file.txt" + additional_experiments = ["enable_stackdriver_agent_metrics"] + autoscaling_algorithm = "THROUGHPUT_BASED" + enable_streaming_engine = true + network = "default" + subnetwork = "regions/us-central1/subnetworks/your-subnetwork" + sdk_container_image = "gcr.io/dataflow-templates/latest/flex/java11" + service_account_email = "your-service-account-email@your-project-id.iam.gserviceaccount.com" + skip_wait_on_job_termination = false + staging_location = "gs://your-staging-bucket" + temp_location = "gs://your-temp-bucket" +} + +jobs = [ + { + instanceId = "your-spanner-instance-id" + databaseId = "your-spanner-database-id" + sourceDbURL = "jdbc:mysql://127.0.0.1/my-db?autoReconnect=true&maxReconnects=10&unicode=true&characterEncoding=UTF-8" + username = "your-db-username" + password = "your-db-password" + tables = "table1,table2" + numPartitions = 200 + maxConnections = 50 + DLQDirectory = "gs://your-dlq-bucket/job1-dlq" + defaultLogLevel = "INFO" + ip_configuration = "WORKER_IP_PRIVATE" + kms_key_name = "projects/your-project-id/locations/global/keyRings/your-key-ring/cryptoKeys/your-key" + launcher_machine_type = "n1-standard-2" + machine_type = "n1-standard-2" + max_workers = 10 + name = "bulk-migration-job" + num_workers = 5 + }, + { + instanceId = "your-spanner-instance-id" + databaseId = "your-spanner-database-id" + sourceDbURL = 
"jdbc:mysql://another-db-host:3306/different-db" + username = "another-username" + password = "another-password" + tables = "table1,table2" + numPartitions = 200 + maxConnections = 25 + DLQDirectory = "gs://your-dlq-bucket/job2-dlq" + defaultLogLevel = "DEBUG" + ip_configuration = "WORKER_IP_PRIVATE" + kms_key_name = "projects/your-project-id/locations/global/keyRings/your-key-ring/cryptoKeys/your-key" + launcher_machine_type = "n1-standard-4" + machine_type = "n1-standard-4" + max_workers = 20 + name = "job2-orders-migration" + num_workers = 10 + } + # ... Add more job configurations as needed +] diff --git a/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform_simple.tfvars b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform_simple.tfvars new file mode 100644 index 0000000000..58e4efb5c1 --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/terraform_simple.tfvars @@ -0,0 +1,34 @@ +# Below is a simplified version of terraform.tfvars which only configures the +# most commonly set properties. +# It creates two bulk migration jobs with public IPs in the default network. + +common_params = { + on_delete = "cancel" # Or "cancel" if you prefer + project = "your-google-cloud-project-id" + region = "us-central1" # Or your desired region + projectId = "your-google-cloud-project-id" + service_account_email = "your-project-id-compute@developer.gserviceaccount.com" +} + +jobs = [ + { + instanceId = "your-spanner-instance-id" + databaseId = "your-spanner-database-id" + sourceDbURL = "jdbc:mysql://127.0.0.1/my-db?autoReconnect=true&maxReconnects=10&unicode=true&characterEncoding=UTF-8" + username = "your-db-username" + password = "your-db-password" + DLQDirectory = "gs://your-dlq-bucket/dlq1" + max_workers = 2 + num_workers = 2 + }, + { + instanceId = "your-spanner-instance-id" + databaseId = "your-spanner-database-id" + sourceDbURL = "jdbc:mysql://127.0.0.1/my-db?autoReconnect=true&maxReconnects=10&unicode=true&characterEncoding=UTF-8" + username = "your-db-username" + password = "your-db-password" + DLQDirectory = "gs://your-dlq-bucket/dlq2" + max_workers = 2 + num_workers = 2 + }, +] diff --git a/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/variables.tf b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/variables.tf new file mode 100644 index 0000000000..bb3121bfac --- /dev/null +++ b/v2/sourcedb-to-spanner/terraform/samples/multiple-jobs/variables.tf @@ -0,0 +1,49 @@ +variable "common_params" { + description = "Parameters which are common across jobs. Please refer to https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v2/sourcedb-to-spanner/README_Sourcedb_to_Spanner_Flex.md for the description of the parameters below." 
+ type = object({ + on_delete = optional(string, "drain") + project = string + region = string + jdbcDriverJars = optional(string) + jdbcDriverClassName = optional(string) + projectId = string + spannerHost = optional(string, "https://batch-spanner.googleapis.com") + sessionFilePath = optional(string) + disabledAlgorithms = optional(string) + extraFilesToStage = optional(string) + additional_experiments = optional(set(string)) + autoscaling_algorithm = optional(string) + enable_streaming_engine = optional(bool) + network = optional(string) + subnetwork = optional(string) + sdk_container_image = optional(string) + service_account_email = optional(string) + skip_wait_on_job_termination = optional(bool, true) + staging_location = optional(string) + temp_location = optional(string) + }) +} + +variable "jobs" { + description = "List of job configurations. Please refer to https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v2/sourcedb-to-spanner/README_Sourcedb_to_Spanner_Flex.md for the description of the parameters below." + type = list(object({ + instanceId = string + databaseId = string + sourceDbURL = string + username = string + password = string + tables = optional(string) + numPartitions = optional(string) + maxConnections = optional(number, 0) + DLQDirectory = string + defaultLogLevel = optional(string, "INFO") + ip_configuration = optional(string) + kms_key_name = optional(string) + labels = optional(map(string)) + launcher_machine_type = optional(string) + machine_type = optional(string) + max_workers = optional(number) + name = optional(string, "bulk-migration-job") + num_workers = optional(number) + })) +} diff --git a/v2/spanner-change-streams-to-sharded-file-sink/src/test/java/com/google/cloud/teleport/v2/templates/SpannerChangeStreamToGcsMultiShardIT.java b/v2/spanner-change-streams-to-sharded-file-sink/src/test/java/com/google/cloud/teleport/v2/templates/SpannerChangeStreamToGcsMultiShardIT.java index 882d7a48be..2e9d52d4f4 100644 --- a/v2/spanner-change-streams-to-sharded-file-sink/src/test/java/com/google/cloud/teleport/v2/templates/SpannerChangeStreamToGcsMultiShardIT.java +++ b/v2/spanner-change-streams-to-sharded-file-sink/src/test/java/com/google/cloud/teleport/v2/templates/SpannerChangeStreamToGcsMultiShardIT.java @@ -19,6 +19,8 @@ import static org.apache.beam.it.truthmatchers.PipelineAsserts.assertThatResult; import com.google.cloud.spanner.Mutation; +import com.google.cloud.spanner.Options; +import com.google.cloud.spanner.TransactionRunner.TransactionCallable; import com.google.cloud.teleport.metadata.SkipDirectRunnerTest; import com.google.cloud.teleport.metadata.TemplateIntegrationTest; import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard; @@ -40,6 +42,8 @@ import org.apache.beam.it.gcp.spanner.conditions.SpannerRowsCheck; import org.apache.beam.it.gcp.storage.GcsResourceManager; import org.apache.beam.it.gcp.storage.conditions.GCSArtifactsCheck; +import org.apache.beam.sdk.io.gcp.spanner.SpannerAccessor; +import org.apache.beam.sdk.io.gcp.spanner.SpannerConfig; import org.junit.AfterClass; import org.junit.Before; import org.junit.Test; @@ -219,4 +223,81 @@ private void assertFileContentsInGCSForMultipleShards() { assertThatArtifacts(artifactsShardC).hasContent("SingerId\\\":\\\"4"); assertThatArtifacts(artifactsShardC).hasContent("SingerId\\\":\\\"5"); } + + @Test + public void testForwardMigrationFiltered() throws IOException, java.lang.InterruptedException { + // Construct a ChainedConditionCheck with below stages. + // 1. 
Wait for the metadata table to have the start time of the reader job + // 2. Write 1 record per shard to Spanner with the transaction tag txBy= + // 3. Wait and check there are no files in GCS for that shard + ChainedConditionCheck conditionCheck = + ChainedConditionCheck.builder( + List.of( + SpannerRowsCheck.builder( + spannerMetadataResourceManager, "spanner_to_gcs_metadata") + .setMinRows(1) + .setMaxRows(1) + .build())) + .build(); + // Wait for conditions + PipelineOperator.Result result = + pipelineOperator() + .waitForCondition(createConfig(jobInfo, Duration.ofMinutes(10)), conditionCheck); + // Assert Conditions + assertThatResult(result).meetsConditions(); + // Perform writes to Spanner + writeSpannerDataForForwardMigration(7, "seven", "testShardD"); + // Assert that no file is present in GCS for this shard, since the tagged transaction is filtered out + assertFileContentsInGCSForFilteredRecords(); + } + + private void assertFileContentsInGCSForFilteredRecords() { + ChainedConditionCheck conditionCheck = + ChainedConditionCheck.builder( + List.of( + GCSArtifactsCheck.builder( + gcsResourceManager, "output/testShardD/", Pattern.compile(".*\\.txt$")) + .setMinSize(1) + .setMaxSize(1) + .build())) + .build(); + + PipelineOperator.Result result = + pipelineOperator() + .waitForCondition(createConfig(jobInfo, Duration.ofMinutes(6)), conditionCheck); + + // Assert Conditions + assertThatResult(result).hasTimedOut(); + } + + private void writeSpannerDataForForwardMigration(int singerId, String firstName, String shardId) { + // Write a single record to Spanner for the given logical shard + // Add the record with the transaction tag as txBy= + SpannerConfig spannerConfig = + SpannerConfig.create() + .withProjectId(PROJECT) + .withInstanceId(spannerResourceManager.getInstanceId()) + .withDatabaseId(spannerResourceManager.getDatabaseId()); + SpannerAccessor spannerAccessor = SpannerAccessor.getOrCreate(spannerConfig); + spannerAccessor + .getDatabaseClient() + .readWriteTransaction( + Options.tag("txBy=forwardMigration"), + Options.priority(spannerConfig.getRpcPriority().get())) + .run( + (TransactionCallable) + transaction -> { + Mutation m = + Mutation.newInsertOrUpdateBuilder("Singers") + .set("SingerId") + .to(singerId) + .set("FirstName") + .to(firstName) + .set("migration_shard_id") + .to(shardId) + .build(); + transaction.buffer(m); + return null; + }); + } } diff --git a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertor.java b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertor.java index 266ebe1dfd..127fd2c2fb 100644 --- a/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertor.java +++ b/v2/spanner-common/src/main/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertor.java @@ -153,6 +153,7 @@ static class CustomAvroTypes { public static final String VARCHAR = "varchar"; public static final String NUMBER = "number"; public static final String JSON = "json"; + public static final String UNSUPPORTED = "unsupported"; } /** Avro logical types are converted to an equivalent string type. 
*/ @@ -199,6 +200,9 @@ static String handleLogicalFieldType(String fieldName, Object recordValue, Schem } else if (fieldSchema.getLogicalType() != null && fieldSchema.getLogicalType().getName().equals(CustomAvroTypes.VARCHAR)) { return recordValue.toString(); + } else if (fieldSchema.getLogicalType() != null + && fieldSchema.getLogicalType().getName().equals(CustomAvroTypes.UNSUPPORTED)) { + return null; } else { LOG.error("Unknown field type {} for field {} in {}.", fieldSchema, fieldName, recordValue); throw new UnsupportedOperationException( diff --git a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertorTest.java b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertorTest.java index 87fc63ef94..802875683c 100644 --- a/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertorTest.java +++ b/v2/spanner-common/src/test/java/com/google/cloud/teleport/v2/spanner/migrations/avro/GenericRecordTypeConvertorTest.java @@ -65,6 +65,9 @@ public Schema getLogicalTypesSchema() { Schema varcharType = new LogicalType(GenericRecordTypeConvertor.CustomAvroTypes.VARCHAR) .addToSchema(SchemaBuilder.builder().stringType()); + Schema unsupportedType = + new LogicalType(GenericRecordTypeConvertor.CustomAvroTypes.UNSUPPORTED) + .addToSchema(SchemaBuilder.builder().nullType()); // Build the schema using the created types return SchemaBuilder.record("logicalTypes") @@ -97,6 +100,9 @@ public Schema getLogicalTypesSchema() { .name("varchar_col") .type(varcharType) .noDefault() + .name("unsupported_col") + .type(unsupportedType) + .noDefault() .endRecord(); } @@ -156,6 +162,7 @@ public void testHandleLogicalFieldType() { genericRecord.put("json_col", "{\"k1\":\"v1\"}"); genericRecord.put("number_col", "289452"); genericRecord.put("varchar_col", "Hellogcds"); + genericRecord.put("unsupported_col", null); String col = "date_col"; String result = @@ -210,6 +217,12 @@ public void testHandleLogicalFieldType() { GenericRecordTypeConvertor.handleLogicalFieldType( col, genericRecord.get(col), genericRecord.getSchema().getField(col).schema()); assertEquals("Test varchar_col conversion: ", "Hellogcds", result); + + col = "unsupported_col"; + result = + GenericRecordTypeConvertor.handleLogicalFieldType( + col, genericRecord.get(col), genericRecord.getSchema().getField(col).schema()); + assertEquals("Test unsupported_col conversion: ", null, result); } @Test