diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index 1eb60f6e4959..9e1d1e1b80dd 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
-  "modification": 3
+  "modification": 4
 }
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
index 382ae123a81d..ea98fb6b0bbd 100644
--- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py
@@ -171,6 +171,14 @@ def _execute_query(self, query: str):
     except RuntimeError as e:
       raise RuntimeError(f"Could not complete the query request: {query}. {e}")
 
+  def create_row_key(self, row: beam.Row):
+    if self.condition_value_fn:
+      return tuple(self.condition_value_fn(row))
+    if self.fields:
+      row_dict = row._asdict()
+      return (tuple(row_dict[field] for field in self.fields))
+    raise ValueError("Either fields or condition_value_fn must be specified")
+
   def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
     if isinstance(request, List):
       values = []
@@ -180,7 +188,7 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
       raw_query = self.query_template
       if batch_size > 1:
         batched_condition_template = ' or '.join(
-            [self.row_restriction_template] * batch_size)
+            [fr'({self.row_restriction_template})'] * batch_size)
         raw_query = self.query_template.replace(
             self.row_restriction_template, batched_condition_template)
       for req in request:
@@ -194,14 +202,15 @@ def __call__(self, request: Union[beam.Row, List[beam.Row]], *args, **kwargs):
             "Make sure the values passed in `fields` are the "
             "keys in the input `beam.Row`." + str(e))
         values.extend(current_values)
-        requests_map.update((val, req) for val in current_values)
+        requests_map[self.create_row_key(req)] = req
       query = raw_query.format(*values)
       responses_dict = self._execute_query(query)
       for response in responses_dict:
-        for value in response.values():
-          if value in requests_map:
-            responses.append((requests_map[value], beam.Row(**response)))
+        response_row = beam.Row(**response)
+        response_key = self.create_row_key(response_row)
+        if response_key in requests_map:
+          responses.append((requests_map[response_key], response_row))
 
       return responses
     else:
       request_dict = request._asdict()
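Why the two `bigquery.py` changes above go together: the batched query previously joined raw copies of `row_restriction_template` with `' or '` and then matched responses back to requests by comparing individual column values, which can pair a response with the wrong request as soon as values overlap across columns or the template uses more than one field. The rewrite parenthesizes each per-request clause and keys both sides on the tuple returned by `create_row_key`. A minimal standalone sketch of both ideas (plain Python with illustrative names, not the handler's internals verbatim):

```python
# Sketch: batched restriction assembly plus tuple-key matching.
row_restriction_template = "id = {} AND distribution_center_id = {}"
batch = [(1, 3), (2, 1)]  # condition values for two batched requests

# Parenthesize each clause so every request's restriction stays
# self-contained inside the ' or '-joined batch restriction.
batched = ' or '.join([f'({row_restriction_template})'] * len(batch))
flat = [value for pair in batch for value in pair]
print(batched.format(*flat))
# -> (id = 1 AND distribution_center_id = 3) or (id = 2 AND distribution_center_id = 1)

# Key requests and responses on the same tuple. Matching on bare column
# values could pair this response with a request whose id happened to
# equal some other request's distribution_center_id.
requests_map = {key: ('request', key) for key in batch}
response = {'id': 2, 'distribution_center_id': 1, 'name': 'B'}
response_key = tuple(response[f] for f in ('id', 'distribution_center_id'))
assert response_key in requests_map  # pairs only with the (2, 1) request
```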
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
index 0b8a384b934d..dd99e386555e 100644
--- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
+++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_it_test.py
@@ -14,7 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import functools
 import logging
+import secrets
+import time
 import unittest
 from unittest.mock import MagicMock
 
@@ -22,7 +25,11 @@
 import apache_beam as beam
 from apache_beam.coders import coders
+from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
+from apache_beam.io.gcp.internal.clients import bigquery
 from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
 
 # pylint: disable=ungrouped-imports
 try:
@@ -31,8 +38,7 @@
   from apache_beam.transforms.enrichment import Enrichment
   from apache_beam.transforms.enrichment_handlers.bigquery import \
       BigQueryEnrichmentHandler
-  from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store_it_test import \
-      ValidateResponse
+  from apitools.base.py.exceptions import HttpError
 except ImportError:
   raise unittest.SkipTest(
       'Google Cloud BigQuery dependencies are not installed.')
@@ -40,24 +46,101 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-def query_fn(row: beam.Row):
-  query = (
-      "SELECT * FROM "
-      "`apache-beam-testing.my_ecommerce.product_details`"
-      " WHERE id = '{}'".format(row.id))  # type: ignore[attr-defined]
-  return query
-
-
 def condition_value_fn(row: beam.Row):
   return [row.id]  # type: ignore[attr-defined]
 
 
+def query_fn(table, row: beam.Row):
+  return f"SELECT * FROM `{table}` WHERE id = {row.id}"  # type: ignore[attr-defined]
+
+
+@pytest.mark.uses_testcontainer
+class BigQueryEnrichmentIT(unittest.TestCase):
+  bigquery_dataset_id = 'python_enrichment_transform_read_table_'
+  project = "apache-beam-testing"
+
+  @classmethod
+  def setUpClass(cls):
+    cls.bigquery_client = BigQueryWrapper()
+    cls.dataset_id = '%s%d%s' % (
+        cls.bigquery_dataset_id, int(time.time()), secrets.token_hex(3))
+    cls.bigquery_client.get_or_create_dataset(cls.project, cls.dataset_id)
+    _LOGGER.info(
+        "Created dataset %s in project %s", cls.dataset_id, cls.project)
+
+  @classmethod
+  def tearDownClass(cls):
+    request = bigquery.BigqueryDatasetsDeleteRequest(
+        projectId=cls.project, datasetId=cls.dataset_id, deleteContents=True)
+    try:
+      _LOGGER.debug(
+          "Deleting dataset %s in project %s", cls.dataset_id, cls.project)
+      cls.bigquery_client.client.datasets.Delete(request)
+    except HttpError:
+      _LOGGER.warning(
+          'Failed to clean up dataset %s in project %s',
+          cls.dataset_id,
+          cls.project)
+
+
 @pytest.mark.uses_testcontainer
-class TestBigQueryEnrichmentIT(unittest.TestCase):
+class TestBigQueryEnrichmentIT(BigQueryEnrichmentIT):
+  table_data = [
+      {
+          "id": 1, "name": "A", 'quantity': 2, 'distribution_center_id': 3
+      },
+      {
+          "id": 2, "name": "B", 'quantity': 3, 'distribution_center_id': 1
+      },
+      {
+          "id": 3, "name": "C", 'quantity': 10, 'distribution_center_id': 4
+      },
+      {
+          "id": 4, "name": "D", 'quantity': 1, 'distribution_center_id': 3
+      },
+      {
+          "id": 5, "name": "C", 'quantity': 100, 'distribution_center_id': 4
+      },
+      {
+          "id": 6, "name": "D", 'quantity': 11, 'distribution_center_id': 3
+      },
+      {
+          "id": 7, "name": "C", 'quantity': 7, 'distribution_center_id': 1
+      },
+      {
+          "id": 8, "name": "D", 'quantity': 4, 'distribution_center_id': 1
+      },
+  ]
+
+  @classmethod
+  def create_table(cls, table_name):
+    fields = [('id', 'INTEGER'), ('name', 'STRING'), ('quantity', 'INTEGER'),
+              ('distribution_center_id', 'INTEGER')]
+    table_schema = bigquery.TableSchema()
+    for name, field_type in fields:
+      table_field = bigquery.TableFieldSchema()
+      table_field.name = name
+      table_field.type = field_type
+      table_schema.fields.append(table_field)
+    table = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId=cls.project, datasetId=cls.dataset_id,
+            tableId=table_name),
+        schema=table_schema)
+    request = bigquery.BigqueryTablesInsertRequest(
+        projectId=cls.project, datasetId=cls.dataset_id, table=table)
+    cls.bigquery_client.client.tables.Insert(request)
+    cls.bigquery_client.insert_rows(
+        cls.project, cls.dataset_id, table_name, cls.table_data)
+    cls.table_name = f"{cls.project}.{cls.dataset_id}.{table_name}"
+
+  @classmethod
+  def setUpClass(cls):
+    super(TestBigQueryEnrichmentIT, cls).setUpClass()
+    cls.create_table('product_details')
+
   def setUp(self) -> None:
-    self.project = 'apache-beam-testing'
-    self.condition_template = "id = '{}'"
-    self.table_name = "`apache-beam-testing.my_ecommerce.product_details`"
+    self.condition_template = "id = {}"
     self.retries = 3
     self._start_container()
@@ -82,123 +165,119 @@ def tearDown(self) -> None:
     self.client = None
 
   def test_bigquery_enrichment(self):
-    expected_fields = [
-        'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price'
+    expected_rows = [
+        beam.Row(id=1, name="A", quantity=2, distribution_center_id=3),
+        beam.Row(id=2, name="B", quantity=3, distribution_center_id=1)
     ]
     fields = ['id']
     requests = [
-        beam.Row(
-            id='13842',
-            name='low profile dyed cotton twill cap - navy w39s55d',
-            quantity=2),
-        beam.Row(
-            id='15816',
-            name='low profile dyed cotton twill cap - putty w39s55d',
-            quantity=1),
+        beam.Row(id=1, name='A'),
+        beam.Row(id=2, name='B'),
     ]
     handler = BigQueryEnrichmentHandler(
         project=self.project,
-        row_restriction_template=self.condition_template,
+        row_restriction_template="id = {}",
         table_name=self.table_name,
         fields=fields,
-        min_batch_size=2,
+        min_batch_size=1,
         max_batch_size=100,
     )
+
     with TestPipeline(is_integration_test=True) as test_pipeline:
-      _ = (
-          test_pipeline
-          | beam.Create(requests)
-          | Enrichment(handler)
-          | beam.ParDo(ValidateResponse(expected_fields)))
+      pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
 
-  def test_bigquery_enrichment_with_query_fn(self):
-    expected_fields = [
-        'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price'
+      assert_that(pcoll, equal_to(expected_rows))
+
+  def test_bigquery_enrichment_batched(self):
+    expected_rows = [
+        beam.Row(id=1, name="A", quantity=2, distribution_center_id=3),
+        beam.Row(id=2, name="B", quantity=3, distribution_center_id=1)
     ]
+    fields = ['id']
     requests = [
-        beam.Row(
-            id='13842',
-            name='low profile dyed cotton twill cap - navy w39s55d',
-            quantity=2),
-        beam.Row(
-            id='15816',
-            name='low profile dyed cotton twill cap - putty w39s55d',
-            quantity=1),
+        beam.Row(id=1, name='A'),
+        beam.Row(id=2, name='B'),
     ]
-    handler = BigQueryEnrichmentHandler(project=self.project, query_fn=query_fn)
+    handler = BigQueryEnrichmentHandler(
+        project=self.project,
+        row_restriction_template="id = {}",
+        table_name=self.table_name,
+        fields=fields,
+        min_batch_size=2,
+        max_batch_size=100,
+    )
+
     with TestPipeline(is_integration_test=True) as test_pipeline:
-      _ = (
-          test_pipeline
-          | beam.Create(requests)
-          | Enrichment(handler)
-          | beam.ParDo(ValidateResponse(expected_fields)))
+      pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
 
-  def test_bigquery_enrichment_with_condition_value_fn(self):
-    expected_fields = [
-        'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price'
+      assert_that(pcoll, equal_to(expected_rows))
+
+  def test_bigquery_enrichment_batched_multiple_fields(self):
+    expected_rows = [
+        beam.Row(id=1, distribution_center_id=3, name="A", quantity=2),
+        beam.Row(id=2, distribution_center_id=1, name="B", quantity=3)
     ]
+    fields = ['id', 'distribution_center_id']
     requests = [
-        beam.Row(
-            id='13842',
-            name='low profile dyed cotton twill cap - navy w39s55d',
-            quantity=2),
-        beam.Row(
-            id='15816',
-            name='low profile dyed cotton twill cap - putty w39s55d',
-            quantity=1),
+        beam.Row(id=1, distribution_center_id=3),
+        beam.Row(id=2, distribution_center_id=1),
     ]
     handler = BigQueryEnrichmentHandler(
         project=self.project,
-        row_restriction_template=self.condition_template,
+        row_restriction_template="id = {} AND distribution_center_id = {}",
        table_name=self.table_name,
-        condition_value_fn=condition_value_fn,
-        min_batch_size=2,
+        fields=fields,
+        min_batch_size=8,
         max_batch_size=100,
     )
+
     with TestPipeline(is_integration_test=True) as test_pipeline:
-      _ = (
-          test_pipeline
-          | beam.Create(requests)
-          | Enrichment(handler)
-          | beam.ParDo(ValidateResponse(expected_fields)))
+      pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
 
-  def test_bigquery_enrichment_with_condition_without_batch(self):
-    expected_fields = [
-        'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price'
+      assert_that(pcoll, equal_to(expected_rows))
+
+  def test_bigquery_enrichment_with_query_fn(self):
+    expected_rows = [
+        beam.Row(id=1, name="A", quantity=2, distribution_center_id=3),
+        beam.Row(id=2, name="B", quantity=3, distribution_center_id=1)
     ]
     requests = [
-        beam.Row(
-            id='13842',
-            name='low profile dyed cotton twill cap - navy w39s55d',
-            quantity=2),
-        beam.Row(
-            id='15816',
-            name='low profile dyed cotton twill cap - putty w39s55d',
-            quantity=1),
+        beam.Row(id=1, name='A'),
+        beam.Row(id=2, name='B'),
+    ]
+    fn = functools.partial(query_fn, self.table_name)
+    handler = BigQueryEnrichmentHandler(project=self.project, query_fn=fn)
+    with TestPipeline(is_integration_test=True) as test_pipeline:
+      pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
+
+      assert_that(pcoll, equal_to(expected_rows))
+
+  def test_bigquery_enrichment_with_condition_value_fn(self):
+    expected_rows = [
+        beam.Row(id=1, name="A", quantity=2, distribution_center_id=3),
+        beam.Row(id=2, name="B", quantity=3, distribution_center_id=1)
+    ]
+    requests = [
+        beam.Row(id=1, name='A'),
+        beam.Row(id=2, name='B'),
     ]
     handler = BigQueryEnrichmentHandler(
         project=self.project,
         row_restriction_template=self.condition_template,
         table_name=self.table_name,
         condition_value_fn=condition_value_fn,
+        min_batch_size=2,
+        max_batch_size=100,
     )
     with TestPipeline(is_integration_test=True) as test_pipeline:
-      _ = (
-          test_pipeline
-          | beam.Create(requests)
-          | Enrichment(handler)
-          | beam.ParDo(ValidateResponse(expected_fields)))
+      pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler))
+
+      assert_that(pcoll, equal_to(expected_rows))
 
   def test_bigquery_enrichment_bad_request(self):
     requests = [
-        beam.Row(
-            id='13842',
-            name='low profile dyed cotton twill cap - navy w39s55d',
-            quantity=2),
-        beam.Row(
-            id='15816',
-            name='low profile dyed cotton twill cap - putty w39s55d',
-            quantity=1),
+        beam.Row(id=1, name='A'),
+        beam.Row(id=2, name='B'),
     ]
     handler = BigQueryEnrichmentHandler(
         project=self.project,
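Since `query_fn` now takes the table as its first argument (the test table is created per-run under a generated dataset name), callers bind the table before handing the callable to the handler, exactly as `test_bigquery_enrichment_with_query_fn` does with `functools.partial`. A hedged usage sketch; the project and table names below are placeholders, not real resources:

```python
import functools

import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigquery import \
    BigQueryEnrichmentHandler


def query_fn(table, row: beam.Row):
  return f"SELECT * FROM `{table}` WHERE id = {row.id}"


# Bind the fully qualified table first; the handler then invokes the
# callable once per element with only the row argument.
fn = functools.partial(query_fn, "my-project.my_dataset.product_details")
handler = BigQueryEnrichmentHandler(project="my-project", query_fn=fn)

with beam.Pipeline() as p:
  _ = (p | beam.Create([beam.Row(id=1)]) | Enrichment(handler))
```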
@@ -231,18 +310,13 @@ def test_bigquery_enrichment_with_redis(self):
     requests. Since all requests are cached, it will return from there without
     making calls to the BigQuery service.
     """
-    expected_fields = [
-        'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price'
-    ]
     requests = [
-        beam.Row(
-            id='13842',
-            name='low profile dyed cotton twill cap - navy w39s55d',
-            quantity=2),
-        beam.Row(
-            id='15816',
-            name='low profile dyed cotton twill cap - putty w39s55d',
-            quantity=1),
+        beam.Row(id=1, name='A'),
+        beam.Row(id=2, name='B'),
+    ]
+    expected_rows = [
+        beam.Row(id=1, name="A", quantity=2, distribution_center_id=3),
+        beam.Row(id=2, name="B", quantity=3, distribution_center_id=1)
     ]
     handler = BigQueryEnrichmentHandler(
         project=self.project,
@@ -253,11 +327,12 @@ def test_bigquery_enrichment_with_redis(self):
         max_batch_size=100,
     )
     with TestPipeline(is_integration_test=True) as test_pipeline:
-      _ = (
+      pcoll_populate_cache = (
           test_pipeline
           | beam.Create(requests)
-          | Enrichment(handler).with_redis_cache(self.host, self.port)
-          | beam.ParDo(ValidateResponse(expected_fields)))
+          | Enrichment(handler).with_redis_cache(self.host, self.port))
+
+      assert_that(pcoll_populate_cache, equal_to(expected_rows))
 
     # manually check cache entry
     c = coders.StrUtf8Coder()
""" - expected_fields = [ - 'id', 'name', 'quantity', 'category', 'brand', 'cost', 'retail_price' - ] requests = [ - beam.Row( - id='13842', - name='low profile dyed cotton twill cap - navy w39s55d', - quantity=2), - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] handler = BigQueryEnrichmentHandler( project=self.project, @@ -253,11 +327,12 @@ def test_bigquery_enrichment_with_redis(self): max_batch_size=100, ) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_populate_cache = ( test_pipeline | beam.Create(requests) - | Enrichment(handler).with_redis_cache(self.host, self.port) - | beam.ParDo(ValidateResponse(expected_fields))) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_populate_cache, equal_to(expected_rows)) # manually check cache entry c = coders.StrUtf8Coder() @@ -268,20 +343,15 @@ def test_bigquery_enrichment_with_redis(self): raise ValueError("No cache entry found for %s" % key) actual = BigQueryEnrichmentHandler.__call__ - BigQueryEnrichmentHandler.__call__ = MagicMock( - return_value=( - beam.Row( - id='15816', - name='low profile dyed cotton twill cap - putty w39s55d', - quantity=1), - beam.Row())) + BigQueryEnrichmentHandler.__call__ = MagicMock(return_value=(beam.Row())) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_cached = ( test_pipeline | beam.Create(requests) - | Enrichment(handler).with_redis_cache(self.host, self.port) - | beam.ParDo(ValidateResponse(expected_fields))) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_cached, equal_to(expected_rows)) BigQueryEnrichmentHandler.__call__ = actual