A schema transform implementation for SpannerIO.Write (#24278)
* A schema transform implementation for SpannerIO.Write

* fixup

* fixup

* fixup

* fixup

* fixup and comments

* fixup

* fixup
pabloem authored Nov 24, 2022
1 parent 3b7e181 commit 09606de
Showing 3 changed files with 218 additions and 1 deletion.
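In brief, the new provider wraps SpannerIO.write() as a SchemaTransform: it consumes a single PCollection of Rows tagged "input" and emits mutations that failed to apply as Rows tagged "failures". A minimal usage sketch, mirroring the integration test added below (`rows` is assumed to be an existing PCollection<Row> whose schema matches the table; the instance, database, and table IDs are illustrative):

    PCollectionRowTuple.of("input", rows)
        .apply(
            new SpannerWriteSchemaTransformProvider()
                .from(
                    SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration
                        .builder()
                        .setInstanceId("my-instance")
                        .setDatabaseId("my-database")
                        .setTableId("my-table")
                        .build())
                .buildTransform());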
MutationUtils.java (org.apache.beam.sdk.io.gcp.spanner)
@@ -88,7 +88,7 @@ private static Key createKeyFromBeamRow(Row row) {
     return builder.build();
   }
 
-  private static Mutation createMutationFromBeamRows(
+  public static Mutation createMutationFromBeamRows(
       Mutation.WriteBuilder mutationBuilder, Row row) {
     Schema schema = row.getSchema();
     List<String> columns = schema.getFieldNames();
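This visibility change lets the new schema transform (added below) reuse the existing Row-to-Mutation conversion. A sketch of the call as it is used later in this commit (the table name here is illustrative):

    Mutation mutation =
        MutationUtils.createMutationFromBeamRows(
            Mutation.newInsertOrUpdateBuilder("my-table"), row);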
SpannerWriteSchemaTransformProvider.java (new file)
@@ -0,0 +1,177 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.gcp.spanner;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import com.google.cloud.spanner.Mutation;
import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.transforms.FlatMapElements;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;

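/**
 * A {@link SchemaTransformProvider} for {@link SpannerIO#write()}: input Rows tagged "input" are
 * converted to insert-or-update {@link Mutation}s, and mutations that fail to apply are surfaced
 * as Rows tagged "failures".
 */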
@AutoService(SchemaTransformProvider.class)
public class SpannerWriteSchemaTransformProvider
    extends TypedSchemaTransformProvider<
        SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration> {

  @Override
  protected @UnknownKeyFor @NonNull @Initialized Class<SpannerWriteSchemaTransformConfiguration>
      configurationClass() {
    return SpannerWriteSchemaTransformConfiguration.class;
  }

  @Override
  protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from(
      SpannerWriteSchemaTransformConfiguration configuration) {
    return new SpannerSchemaTransformWrite(configuration);
  }

  static class SpannerSchemaTransformWrite implements SchemaTransform, Serializable {
    private final SpannerWriteSchemaTransformConfiguration configuration;

    SpannerSchemaTransformWrite(SpannerWriteSchemaTransformConfiguration configuration) {
      this.configuration = configuration;
    }

    @Override
    public @UnknownKeyFor @NonNull @Initialized PTransform<
            @UnknownKeyFor @NonNull @Initialized PCollectionRowTuple,
            @UnknownKeyFor @NonNull @Initialized PCollectionRowTuple>
        buildTransform() {
      // TODO: For now we are allowing ourselves to fail at runtime, but we could
      // perform validations here at expansion time. This TODO is to add a few
      // validations (e.g. table/database/instance existence, schema match, etc).
      return new PTransform<@NonNull PCollectionRowTuple, @NonNull PCollectionRowTuple>() {
        @Override
        public PCollectionRowTuple expand(@NonNull PCollectionRowTuple input) {
          SpannerWriteResult result =
              input
                  .get("input")
                  .apply(
                      MapElements.via(
                          new SimpleFunction<Row, Mutation>(
                              row ->
                                  MutationUtils.createMutationFromBeamRows(
                                      Mutation.newInsertOrUpdateBuilder(
                                          configuration.getTableId()),
                                      Objects.requireNonNull(row))) {}))
                  .apply(
                      SpannerIO.write()
                          .withDatabaseId(configuration.getDatabaseId())
                          .withInstanceId(configuration.getInstanceId())
                          .withFailureMode(SpannerIO.FailureMode.REPORT_FAILURES));
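          // Failed mutations are re-encoded as Rows with a fixed schema so they can be
          // surfaced on the "failures" output (and eventually routed to a dead-letter queue).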
          Schema failureSchema =
              Schema.builder()
                  .addStringField("operation")
                  .addStringField("instanceId")
                  .addStringField("databaseId")
                  .addStringField("tableId")
                  .addStringField("mutationData")
                  .build();
          PCollection<Row> failures =
              result
                  .getFailedMutations()
                  .apply(
                      FlatMapElements.into(TypeDescriptors.rows())
                          .via(
                              mtg ->
                                  Objects.requireNonNull(mtg).attached().stream()
                                      .map(
                                          mutation ->
                                              Row.withSchema(failureSchema)
                                                  .addValue(mutation.getOperation().toString())
                                                  .addValue(configuration.getInstanceId())
                                                  .addValue(configuration.getDatabaseId())
                                                  .addValue(mutation.getTable())
                                                  // TODO(pabloem): Figure out how to represent
                                                  // mutation contents in DLQ
                                                  .addValue(
                                                      Iterators.toString(
                                                          mutation.getValues().iterator()))
                                                  .build())
                                      .collect(Collectors.toList())))
                  .setRowSchema(failureSchema);
          return PCollectionRowTuple.of("failures", failures);
        }
      };
    }
  }

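  // The URN under which this transform is registered; discovery happens through the
  // @AutoService(SchemaTransformProvider.class) annotation above (e.g. by an expansion service).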
  @Override
  public @UnknownKeyFor @NonNull @Initialized String identifier() {
    return "beam:schematransform:org.apache.beam:spanner_write:v1";
  }

  @Override
  public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String>
      inputCollectionNames() {
    return Collections.singletonList("input");
  }

  @Override
  public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String>
      outputCollectionNames() {
    return Collections.singletonList("failures");
  }

  @AutoValue
  @DefaultSchema(AutoValueSchema.class)
  public abstract static class SpannerWriteSchemaTransformConfiguration implements Serializable {
    public abstract String getInstanceId();

    public abstract String getDatabaseId();

    public abstract String getTableId();

    public static Builder builder() {
      return new AutoValue_SpannerWriteSchemaTransformProvider_SpannerWriteSchemaTransformConfiguration
          .Builder();
    }

    @AutoValue.Builder
    public abstract static class Builder {
      public abstract Builder setInstanceId(String instanceId);

      public abstract Builder setDatabaseId(String databaseId);

      public abstract Builder setTableId(String tableId);

      public abstract SpannerWriteSchemaTransformConfiguration build();
    }
  }
}
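Handling of the failure output is left to the caller. A sketch of reading it back, assuming `writeResult` is the PCollectionRowTuple produced by applying the transform above:

    PCollection<Row> failures = writeResult.get("failures");
    // Each failure Row carries operation, instanceId, databaseId, tableId, and mutationData.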
Spanner write integration test
@@ -34,20 +34,26 @@
 import com.google.spanner.admin.database.v1.CreateDatabaseMetadata;
 import java.io.Serializable;
 import java.util.Collections;
+import java.util.Objects;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.options.Default;
 import org.apache.beam.sdk.options.Description;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestPipelineOptions;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.Wait;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptors;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicate;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
@@ -200,6 +206,40 @@ public void testWrite() throws Exception {
     assertThat(countNumberOfRecords(pgDatabaseName), equalTo((long) numRecords));
   }
 
+  @Test
+  public void testWriteViaSchemaTransform() throws Exception {
+    int numRecords = 100;
+    final Schema tableSchema =
+        Schema.builder().addInt64Field("Key").addStringField("Value").build();
+    PCollectionRowTuple.of(
+            "input",
+            p.apply("Init", GenerateSequence.from(0).to(numRecords))
+                .apply(
+                    MapElements.into(TypeDescriptors.rows())
+                        .via(
+                            seed ->
+                                Row.withSchema(tableSchema)
+                                    .addValue(seed)
+                                    .addValue(Objects.requireNonNull(seed).toString())
+                                    .build()))
+                .setRowSchema(tableSchema))
+        .apply(
+            new SpannerWriteSchemaTransformProvider()
+                .from(
+                    SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration
+                        .builder()
+                        .setDatabaseId(databaseName)
+                        .setInstanceId(options.getInstanceId())
+                        .setTableId(options.getTable())
+                        .build())
+                .buildTransform());
+
+    PipelineResult result = p.run();
+    result.waitUntilFinish();
+    assertThat(result.getState(), is(PipelineResult.State.DONE));
+    assertThat(countNumberOfRecords(databaseName), equalTo((long) numRecords));
+  }
+
   @Test
   public void testSequentialWrite() throws Exception {
     int numRecords = 100;
