Skip to content

Commit

Permalink
[YAML] Implement basic java mapping operations. (#28657)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Oct 31, 2023
1 parent 96d6f2d commit b513420
Show file tree
Hide file tree
Showing 9 changed files with 1,337 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.transforms.providers;

import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

/**
 * An implementation of {@link TypedSchemaTransformProvider} for MapToFields for the java language.
 *
 * <p>The produced transform reads rows from the {@code "input"} collection, evaluates one {@link
 * JavaRowUdf} per output field against every row, and publishes the results as the {@code
 * "output"} collection. When error handling is configured, rows whose evaluation throws are
 * published on a second, user-named collection instead of the exception being rethrown.
 *
 * <p><b>Internal only:</b> This class is actively being worked on, and it will likely change. We
 * provide no backwards compatibility guarantees, and it should not be implemented outside the Beam
 * repository.
 */
@SuppressWarnings({
  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
@AutoService(SchemaTransformProvider.class)
public class JavaMapToFieldsTransformProvider
    extends TypedSchemaTransformProvider<JavaMapToFieldsTransformProvider.Configuration> {
  /** Name of the single input collection this transform reads from. */
  protected static final String INPUT_ROWS_TAG = "input";

  /** Name of the output collection that carries the successfully mapped rows. */
  protected static final String OUTPUT_ROWS_TAG = "output";

  @Override
  protected Class<Configuration> configurationClass() {
    return Configuration.class;
  }

  @Override
  protected SchemaTransform from(Configuration configuration) {
    return new JavaMapToFieldsTransform(configuration);
  }

  @Override
  public String identifier() {
    // A plain literal: there are no format arguments, so String.format added nothing.
    return "beam:schematransform:org.apache.beam:yaml:map_to_fields-java:v1";
  }

  @Override
  public List<String> inputCollectionNames() {
    return Collections.singletonList(INPUT_ROWS_TAG);
  }

  @Override
  public List<String> outputCollectionNames() {
    return Collections.singletonList(OUTPUT_ROWS_TAG);
  }

  /**
   * Configuration of the MapToFields-java transform.
   *
   * <p>NOTE: the declarations below are intentionally left in their original order; the generated
   * AutoValue class and the inferred schema are derived from them.
   */
  @DefaultSchema(AutoValueSchema.class)
  @AutoValue
  public abstract static class Configuration {
    /**
     * The language of the field expressions. Unless this equals {@code "java"}, every entry of
     * {@link #getFields()} must have an expression that names an existing input field; anything
     * else is rejected when the transform is expanded.
     */
    @Nullable
    public abstract String getLanguage();

    /**
     * Whether the input fields are carried over to the output ahead of the mapped fields. {@code
     * null} is treated as {@code false}.
     */
    @Nullable
    public abstract Boolean getAppend();

    /**
     * Names of input fields to leave out when {@link #getAppend()} is true. {@code null} is treated
     * as an empty list. The list is not consulted when append is false.
     */
    @Nullable
    public abstract List<String> getDrop();

    /** Maps each output field name to the configuration of the UDF that computes its value. */
    public abstract Map<String, JavaRowUdf.Configuration> getFields();

    /**
     * Optional error handling. When present with a non-null output name, rows whose mapping throws
     * are emitted on that output rather than the exception being rethrown.
     */
    @Nullable
    public abstract ErrorHandling getErrorHandling();

    public static Builder builder() {
      return new AutoValue_JavaMapToFieldsTransformProvider_Configuration.Builder();
    }

    /** Builder for {@link Configuration}. */
    @AutoValue.Builder
    public abstract static class Builder {

      public abstract Builder setLanguage(String language);

      public abstract Builder setAppend(Boolean append);

      public abstract Builder setDrop(List<String> drop);

      public abstract Builder setFields(Map<String, JavaRowUdf.Configuration> fields);

      public abstract Builder setErrorHandling(ErrorHandling errorHandling);

      public abstract Configuration build();
    }

    /** Names the additional output that receives rows which could not be mapped. */
    @AutoValue
    public abstract static class ErrorHandling {
      // NOTE(review): "failed writes" reads like a copy from a write transform; this transform
      // emits rows whose mapping failed. The text is left unchanged here because it is part of
      // the published schema description - confirm before rewording.
      @SchemaFieldDescription("The name of the output PCollection containing failed writes.")
      public abstract String getOutput();

      public static Builder builder() {
        return new AutoValue_JavaMapToFieldsTransformProvider_Configuration_ErrorHandling.Builder();
      }

      /** Builder for {@link ErrorHandling}. */
      @AutoValue.Builder
      public abstract static class Builder {
        public abstract Builder setOutput(String output);

        public abstract ErrorHandling build();
      }
    }
  }

  /** A {@link SchemaTransform} for MapToFields-java. */
  protected static class JavaMapToFieldsTransform extends SchemaTransform {

    private final Configuration configuration;

    JavaMapToFieldsTransform(Configuration configuration) {
      this.configuration = configuration;
    }

    /**
     * Builds the output schema and the per-field UDFs from the configuration, then applies them to
     * the {@code "input"} collection.
     *
     * @param input tuple that must contain the {@code "input"} collection with a schema attached
     * @return a tuple holding the mapped rows under {@code "output"} and, when error handling is
     *     configured, the failed rows under the configured output name
     * @throws IllegalArgumentException if the language is not {@code "java"} and a field's
     *     expression is missing or does not name an input field
     * @throws RuntimeException wrapping any failure to construct a {@link JavaRowUdf}
     */
    @Override
    public PCollectionRowTuple expand(PCollectionRowTuple input) {
      Schema inputSchema = input.get(INPUT_ROWS_TAG).getSchema();
      Schema.Builder outputSchemaBuilder = new Schema.Builder();
      // TODO(yaml): Consider allowing the full java schema naming syntax
      // (perhaps as a different dialect/language).
      boolean append = configuration.getAppend() != null && configuration.getAppend();
      List<String> toDrop =
          configuration.getDrop() == null ? Collections.emptyList() : configuration.getDrop();
      List<JavaRowUdf> udfs = new ArrayList<>();

      // Pass-through fields come first so they precede the mapped fields in the output schema.
      if (append) {
        for (Schema.Field field : inputSchema.getFields()) {
          if (!toDrop.contains(field.getName())) {
            try {
              // A UDF whose expression is just the field name copies the value through.
              udfs.add(
                  new JavaRowUdf(
                      JavaRowUdf.Configuration.builder().setExpression(field.getName()).build(),
                      inputSchema));
            } catch (MalformedURLException
                | ReflectiveOperationException
                | StringCompiler.CompileException exn) {
              throw new RuntimeException(exn);
            }
            outputSchemaBuilder = outputSchemaBuilder.addField(field);
          }
        }
      }

      // The language applies to the whole configuration, so evaluate it once rather than per field.
      boolean isJava = "java".equals(configuration.getLanguage());
      for (Map.Entry<String, JavaRowUdf.Configuration> entry :
          configuration.getFields().entrySet()) {
        if (!isJava) {
          // Without an explicit java language only bare references to input fields are accepted.
          String expr = entry.getValue().getExpression();
          if (expr == null || !inputSchema.hasField(expr)) {
            throw new IllegalArgumentException(
                "Unknown field or missing language specification for '" + entry.getKey() + "'");
          }
        }
        try {
          JavaRowUdf udf = new JavaRowUdf(entry.getValue(), inputSchema);
          udfs.add(udf);
          outputSchemaBuilder = outputSchemaBuilder.addField(entry.getKey(), udf.getOutputType());
        } catch (MalformedURLException
            | ReflectiveOperationException
            | StringCompiler.CompileException exn) {
          throw new RuntimeException(exn);
        }
      }
      Schema outputSchema = outputSchemaBuilder.build();

      Configuration.ErrorHandling errorHandling = configuration.getErrorHandling();
      boolean handleErrors = errorHandling != null && errorHandling.getOutput() != null;
      Schema errorSchema =
          Schema.of(
              Schema.Field.of("failed_row", Schema.FieldType.row(inputSchema)),
              Schema.Field.of("error_message", Schema.FieldType.STRING));

      PCollectionTuple pcolls =
          input
              .get(INPUT_ROWS_TAG)
              .apply(
                  "MapToFields",
                  ParDo.of(createDoFn(udfs, outputSchema, errorSchema, handleErrors))
                      .withOutputTags(mappedValues, TupleTagList.of(errorValues)));
      pcolls.get(mappedValues).setRowSchema(outputSchema);
      pcolls.get(errorValues).setRowSchema(errorSchema);

      PCollectionRowTuple result =
          PCollectionRowTuple.of(OUTPUT_ROWS_TAG, pcolls.get(mappedValues));
      if (handleErrors) {
        // The error collection is only exposed to callers when they asked for it.
        result = result.and(errorHandling.getOutput(), pcolls.get(errorValues));
      }
      return result;
    }

    private static final TupleTag<Row> mappedValues = new TupleTag<Row>() {};
    private static final TupleTag<Row> errorValues = new TupleTag<Row>() {};

    /**
     * Creates the {@link DoFn} that evaluates every UDF against an input row.
     *
     * @param udfs the functions to apply, in output-schema field order
     * @param outputSchema schema of the rows emitted on the main output
     * @param errorSchema schema of the rows emitted on the error output
     * @param handleErrors if true, a failure is reported on the error output; otherwise it is
     *     rethrown wrapped in a {@link RuntimeException}
     */
    private static DoFn<Row, Row> createDoFn(
        List<JavaRowUdf> udfs, Schema outputSchema, Schema errorSchema, boolean handleErrors) {
      return new DoFn<Row, Row>() {
        @ProcessElement
        public void processElement(@Element Row inputRow, MultiOutputReceiver out) {
          try {
            Row.Builder outputRow = Row.withSchema(outputSchema);
            for (JavaRowUdf udf : udfs) {
              outputRow.addValue(udf.getFunction().apply(inputRow));
            }
            out.get(mappedValues).output(outputRow.build());
          } catch (Exception exn) {
            if (handleErrors) {
              // Throwable#getMessage may return null (e.g. a bare NullPointerException), while
              // error_message is declared with FieldType.STRING. Fall back to toString() so the
              // error row itself is always well formed.
              String message = exn.getMessage();
              if (message == null) {
                message = exn.toString();
              }
              out.get(errorValues)
                  .output(
                      Row.withSchema(errorSchema)
                          .withFieldValue("failed_row", inputRow)
                          .withFieldValue("error_message", message)
                          .build());
            } else {
              throw new RuntimeException(exn);
            }
          }
        }
      };
    }
  }
}
Loading

0 comments on commit b513420

Please sign in to comment.