From a6af989a2dea85788aa65d2c43b72356740bc867 Mon Sep 17 00:00:00 2001 From: Sam Whittle Date: Fri, 20 Sep 2024 15:00:30 +0200 Subject: [PATCH] maintain OneOfType serialization compatability --- .../beam/sdk/schemas/SchemaTranslation.java | 59 +++++++++++++------ .../sdk/schemas/logicaltypes/OneOfType.java | 10 +++- .../sdk/schemas/SchemaTranslationTest.java | 47 ++++++++++++--- 3 files changed, 89 insertions(+), 27 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java index cec0226e0741..5253f82d15b9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java @@ -115,7 +115,12 @@ private static String getLogicalTypeUrn(String identifier) { .build(); public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLogicalType) { - String uuid = schema.getUUID() != null ? schema.getUUID().toString() : ""; + return schemaToProto(schema, serializeLogicalType, true); + } + + public static SchemaApi.Schema schemaToProto( + Schema schema, boolean serializeLogicalType, boolean serializeUUID) { + String uuid = schema.getUUID() != null && serializeUUID ? schema.getUUID().toString() : ""; SchemaApi.Schema.Builder builder = SchemaApi.Schema.newBuilder().setId(uuid); for (Field field : schema.getFields()) { SchemaApi.Field protoField = @@ -123,7 +128,8 @@ public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLog field, schema.indexOf(field.getName()), schema.getEncodingPositions().get(field.getName()), - serializeLogicalType); + serializeLogicalType, + serializeUUID); builder.addFields(protoField); } builder.addAllOptions(optionsToProto(schema.getOptions())); @@ -131,11 +137,11 @@ public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLog } private static SchemaApi.Field fieldToProto( - Field field, int fieldId, int position, boolean serializeLogicalType) { + Field field, int fieldId, int position, boolean serializeLogicalType, boolean serializeUUID) { return SchemaApi.Field.newBuilder() .setName(field.getName()) .setDescription(field.getDescription()) - .setType(fieldTypeToProto(field.getType(), serializeLogicalType)) + .setType(fieldTypeToProto(field.getType(), serializeLogicalType, serializeUUID)) .setId(fieldId) .setEncodingPosition(position) .addAllOptions(optionsToProto(field.getOptions())) @@ -143,34 +149,46 @@ private static SchemaApi.Field fieldToProto( } @VisibleForTesting - static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean serializeLogicalType) { + static SchemaApi.FieldType fieldTypeToProto( + FieldType fieldType, boolean serializeLogicalType, boolean serializeUUID) { SchemaApi.FieldType.Builder builder = SchemaApi.FieldType.newBuilder(); switch (fieldType.getTypeName()) { case ROW: builder.setRowType( SchemaApi.RowType.newBuilder() - .setSchema(schemaToProto(fieldType.getRowSchema(), serializeLogicalType))); + .setSchema( + schemaToProto(fieldType.getRowSchema(), serializeLogicalType, serializeUUID))); break; case ARRAY: builder.setArrayType( SchemaApi.ArrayType.newBuilder() .setElementType( - fieldTypeToProto(fieldType.getCollectionElementType(), serializeLogicalType))); + fieldTypeToProto( + fieldType.getCollectionElementType(), + serializeLogicalType, + serializeUUID))); break; case ITERABLE: builder.setIterableType( SchemaApi.IterableType.newBuilder() .setElementType( - fieldTypeToProto(fieldType.getCollectionElementType(), serializeLogicalType))); + fieldTypeToProto( + fieldType.getCollectionElementType(), + serializeLogicalType, + serializeUUID))); break; case MAP: builder.setMapType( SchemaApi.MapType.newBuilder() - .setKeyType(fieldTypeToProto(fieldType.getMapKeyType(), serializeLogicalType)) - .setValueType(fieldTypeToProto(fieldType.getMapValueType(), serializeLogicalType)) + .setKeyType( + fieldTypeToProto( + fieldType.getMapKeyType(), serializeLogicalType, serializeUUID)) + .setValueType( + fieldTypeToProto( + fieldType.getMapValueType(), serializeLogicalType, serializeUUID)) .build()); break; @@ -186,12 +204,14 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali .setUrn(logicalType.getIdentifier()) .setPayload(ByteString.copyFrom(((UnknownLogicalType) logicalType).getPayload())) .setRepresentation( - fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType)); + fieldTypeToProto( + logicalType.getBaseType(), serializeLogicalType, serializeUUID)); if (logicalType.getArgumentType() != null) { logicalTypeBuilder .setArgumentType( - fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType)) + fieldTypeToProto( + logicalType.getArgumentType(), serializeLogicalType, serializeUUID)) .setArgument( fieldValueToProto(logicalType.getArgumentType(), logicalType.getArgument())); } @@ -200,13 +220,15 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali logicalTypeBuilder = SchemaApi.LogicalType.newBuilder() .setRepresentation( - fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType)) + fieldTypeToProto( + logicalType.getBaseType(), serializeLogicalType, serializeUUID)) .setUrn(urn); if (logicalType.getArgumentType() != null) { logicalTypeBuilder = logicalTypeBuilder .setArgumentType( - fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType)) + fieldTypeToProto( + logicalType.getArgumentType(), serializeLogicalType, serializeUUID)) .setArgument( fieldValueToProto( logicalType.getArgumentType(), logicalType.getArgument())); @@ -226,7 +248,8 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali builder.setLogicalType( SchemaApi.LogicalType.newBuilder() .setUrn(URN_BEAM_LOGICAL_MILLIS_INSTANT) - .setRepresentation(fieldTypeToProto(FieldType.INT64, serializeLogicalType)) + .setRepresentation( + fieldTypeToProto(FieldType.INT64, serializeLogicalType, serializeUUID)) .build()); break; case DECIMAL: @@ -235,7 +258,8 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali builder.setLogicalType( SchemaApi.LogicalType.newBuilder() .setUrn(URN_BEAM_LOGICAL_DECIMAL) - .setRepresentation(fieldTypeToProto(FieldType.BYTES, serializeLogicalType)) + .setRepresentation( + fieldTypeToProto(FieldType.BYTES, serializeLogicalType, serializeUUID)) .build()); break; case BYTE: @@ -771,7 +795,8 @@ private static List optionsToProto(Schema.Options options) { protoOptions.add( SchemaApi.Option.newBuilder() .setName(name) - .setType(fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false)) + .setType( + fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false, false)) .setValue( fieldValueToProto( Objects.requireNonNull(options.getType(name)), options.getValue(name))) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java index b280f29bf701..31b6c8db2fed 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.values.Row; import org.checkerframework.checker.nullness.qual.Nullable; @@ -46,6 +47,7 @@ public class OneOfType implements LogicalType { private final Schema oneOfSchema; private final EnumerationType enumerationType; + private final byte[] schemaProtoRepresentation; private OneOfType(List fields) { this(fields, null); @@ -63,6 +65,8 @@ private OneOfType(List fields, @Nullable Map enumMap) { enumerationType = EnumerationType.create(enumValues); } oneOfSchema = Schema.builder().addFields(nullableFields).build(); + schemaProtoRepresentation = + SchemaTranslation.schemaToProto(oneOfSchema, false, false).toByteArray(); } /** Create an {@link OneOfType} logical type. */ @@ -100,12 +104,12 @@ public String getIdentifier() { @Override public FieldType getArgumentType() { - return FieldType.row(oneOfSchema); + return FieldType.BYTES; } @Override - public Row getArgument() { - return Row.nullRow(oneOfSchema); + public byte[] getArgument() { + return schemaProtoRepresentation; } @Override diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java index e330eb0144e1..b082e2bb68ee 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaTranslationTest.java @@ -214,6 +214,7 @@ public void toAndFromProto() throws Exception { public static class FromProtoToProtoTest { @Parameters(name = "{index}: {0}") public static Iterable data() { + ImmutableList.Builder listBuilder = ImmutableList.builder(); SchemaApi.Schema.Builder builder = SchemaApi.Schema.newBuilder(); // A go 'int' builder.addFields( @@ -232,6 +233,9 @@ public static Iterable data() { .setId(0) .setEncodingPosition(0) .build()); + SchemaApi.Schema singleFieldSchema = builder.build(); + listBuilder.add(singleFieldSchema); + // A pickled python object builder.addFields( SchemaApi.Field.newBuilder() @@ -294,22 +298,51 @@ public static Iterable data() { .setId(2) .setEncodingPosition(2) .build()); - SchemaApi.Schema unknownLogicalTypeSchema = builder.build(); + SchemaApi.Schema multipleFieldSchema = builder.build(); + listBuilder.add(multipleFieldSchema); - return ImmutableList.builder().add(unknownLogicalTypeSchema).build(); + builder.clear(); + builder.addFields( + SchemaApi.Field.newBuilder() + .setName("nested") + .setType( + SchemaApi.FieldType.newBuilder() + .setRowType( + SchemaApi.RowType.newBuilder().setSchema(singleFieldSchema).build()) + .build()) + .build()); + SchemaApi.Schema nestedSchema = builder.build(); + listBuilder.add(nestedSchema); + + return listBuilder.build(); } @Parameter(0) public SchemaApi.Schema schemaProto; + private void clearIds(SchemaApi.Schema.Builder builder) { + builder.clearId(); + for (SchemaApi.Field.Builder field : builder.getFieldsBuilderList()) { + if (field.hasType() + && field.getType().hasRowType() + && field.getType().getRowType().hasSchema()) { + clearIds(field.getTypeBuilder().getRowTypeBuilder().getSchemaBuilder()); + } + } + } + @Test public void fromProtoAndToProto() throws Exception { Schema decodedSchema = SchemaTranslation.schemaFromProto(schemaProto); SchemaApi.Schema reencodedSchemaProto = SchemaTranslation.schemaToProto(decodedSchema, true); - reencodedSchemaProto = reencodedSchemaProto.toBuilder().clearId().build(); + SchemaApi.Schema.Builder builder = reencodedSchemaProto.toBuilder(); + clearIds(builder); + assertThat(builder.build(), equalTo(schemaProto)); - assertThat(reencodedSchemaProto, equalTo(schemaProto)); + SchemaApi.Schema reencodedSchemaProtoWithoutUUID = + SchemaTranslation.schemaToProto(decodedSchema, true, false); + assertThat(reencodedSchemaProtoWithoutUUID, equalTo(schemaProto)); } } @@ -433,8 +466,8 @@ public static Iterable data() { public Schema.FieldType fieldType; @Test - public void testLogicalTypeSerializeDeserilizeCorrectly() { - SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, true); + public void testLogicalTypeSerializeDeserializeCorrectly() { + SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, true, false); Schema.FieldType translated = SchemaTranslation.fieldTypeFromProto(proto); assertThat( @@ -452,7 +485,7 @@ public void testLogicalTypeSerializeDeserilizeCorrectly() { @Test public void testLogicalTypeFromToProtoCorrectly() { - SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, false); + SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, false, false); Schema.FieldType translated = SchemaTranslation.fieldTypeFromProto(proto); if (STANDARD_LOGICAL_TYPES.containsKey(translated.getLogicalType().getIdentifier())) {