diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index e309b23b54cc..42529ca9eefb 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -3042,8 +3042,9 @@ public Write withPropagateSuccessfulStorageApiWrites( } /** - * If set to true, then all successful writes will be propagated to {@link WriteResult} and - * accessible via the {@link WriteResult#getSuccessfulStorageApiInserts} method. + * If called, then all successful writes will be propagated to {@link WriteResult} and + * accessible via the {@link WriteResult#getSuccessfulStorageApiInserts} method. The predicate + * allows filtering out columns from appearing in the resulting PCollection. */ public Write withPropagateSuccessfulStorageApiWrites(Predicate columnsToPropagate) { return toBuilder() diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java index a18b52dd834c..8c2c035f2b19 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/TableRowToStorageApiProto.java @@ -1095,16 +1095,27 @@ private static long toEpochMicros(Instant timestamp) { @VisibleForTesting public static TableRow tableRowFromMessage( Message message, boolean includeCdcColumns, Predicate includeField) { + return tableRowFromMessage(message, includeCdcColumns, includeField, ""); + } + + public static TableRow tableRowFromMessage( + Message message, + boolean includeCdcColumns, + Predicate includeField, + String namePrefix) { // TODO: Would be more correct to generate TableRows using setF. TableRow tableRow = new TableRow(); for (Map.Entry field : message.getAllFields().entrySet()) { + StringBuilder fullName = new StringBuilder(); FieldDescriptor fieldDescriptor = field.getKey(); + fullName = fullName.append(namePrefix).append(fieldDescriptor.getName()); Object fieldValue = field.getValue(); - if ((includeCdcColumns || !StorageApiCDC.COLUMNS.contains(fieldDescriptor.getName())) + if ((includeCdcColumns || !StorageApiCDC.COLUMNS.contains(fullName.toString())) && includeField.test(fieldDescriptor.getName())) { tableRow.put( fieldDescriptor.getName(), - jsonValueFromMessageValue(fieldDescriptor, fieldValue, true, includeField)); + jsonValueFromMessageValue( + fieldDescriptor, fieldValue, true, includeField, fullName.append(".").toString())); } } return tableRow; @@ -1114,18 +1125,19 @@ public static Object jsonValueFromMessageValue( FieldDescriptor fieldDescriptor, Object fieldValue, boolean expandRepeated, - Predicate includeField) { + Predicate includeField, + String prefix) { if (expandRepeated && fieldDescriptor.isRepeated()) { List valueList = (List) fieldValue; return valueList.stream() - .map(v -> jsonValueFromMessageValue(fieldDescriptor, v, false, includeField)) + .map(v -> jsonValueFromMessageValue(fieldDescriptor, v, false, includeField, prefix)) .collect(toList()); } switch (fieldDescriptor.getType()) { case GROUP: case MESSAGE: - return tableRowFromMessage((Message) fieldValue, false, includeField); + return tableRowFromMessage((Message) fieldValue, false, includeField, prefix); case BYTES: return BaseEncoding.base64().encode(((ByteString) fieldValue).toByteArray()); case ENUM: diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index 2736ed7beb88..58b6746ea5c4 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -35,6 +35,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; import static org.junit.Assume.assumeTrue; import com.google.api.core.ApiFuture; @@ -80,6 +81,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.function.Function; import java.util.function.LongFunction; +import java.util.function.Predicate; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -1593,6 +1595,91 @@ public void testStreamingStorageApiWriteWithAutoShardingWithErrorHandling() thro storageWriteWithErrorHandling(true); } + private void storageWriteWithSuccessHandling(boolean columnSubset) throws Exception { + assumeTrue(useStorageApi); + if (!useStreaming) { + assumeFalse(useStorageApiApproximate); + } + List elements = + IntStream.range(0, 30) + .mapToObj(Integer::toString) + .map(i -> new TableRow().set("number", i).set("string", i)) + .collect(Collectors.toList()); + + List expectedSuccessElements = elements; + if (columnSubset) { + expectedSuccessElements = + elements.stream() + .map(tr -> new TableRow().set("number", tr.get("number"))) + .collect(Collectors.toList()); + } + + TableSchema tableSchema = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("number").setType("INTEGER"), + new TableFieldSchema().setName("string").setType("STRING"))); + + TestStream testStream = + TestStream.create(TableRowJsonCoder.of()) + .addElements( + elements.get(0), Iterables.toArray(elements.subList(1, 10), TableRow.class)) + .advanceProcessingTime(Duration.standardMinutes(1)) + .addElements( + elements.get(10), Iterables.toArray(elements.subList(11, 20), TableRow.class)) + .advanceProcessingTime(Duration.standardMinutes(1)) + .addElements( + elements.get(20), Iterables.toArray(elements.subList(21, 30), TableRow.class)) + .advanceWatermarkToInfinity(); + + BigQueryIO.Write write = + BigQueryIO.writeTableRows() + .to("project-id:dataset-id.table-id") + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withSchema(tableSchema) + .withMethod(Method.STORAGE_WRITE_API) + .withTestServices(fakeBqServices) + .withPropagateSuccessfulStorageApiWrites(true) + .withoutValidation(); + if (columnSubset) { + write = + write.withPropagateSuccessfulStorageApiWrites( + (Serializable & Predicate) s -> s.equals("number")); + } + if (useStreaming) { + if (useStorageApiApproximate) { + write = write.withMethod(Method.STORAGE_API_AT_LEAST_ONCE); + } else { + write = write.withAutoSharding(); + } + } + + PTransform> source = + useStreaming ? testStream : Create.of(elements).withCoder(TableRowJsonCoder.of()); + PCollection success = + p.apply(source).apply("WriteToBQ", write).getSuccessfulStorageApiInserts(); + + PAssert.that(success) + .containsInAnyOrder(Iterables.toArray(expectedSuccessElements, TableRow.class)); + + p.run().waitUntilFinish(); + + assertThat( + fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"), + containsInAnyOrder(Iterables.toArray(elements, TableRow.class))); + } + + @Test + public void testStorageApiWriteWithSuccessfulRows() throws Exception { + storageWriteWithSuccessHandling(false); + } + + @Test + public void testStorageApiWriteWithSuccessfulRowsColumnSubset() throws Exception { + storageWriteWithSuccessHandling(true); + } + @DefaultSchema(JavaFieldSchema.class) static class SchemaPojo { final String name;