Skip to content

Commit

Permalink
maintain OneOfType serialization compatability
Browse files Browse the repository at this point in the history
  • Loading branch information
scwhittle committed Sep 20, 2024
1 parent 4c97fab commit a6af989
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,62 +115,80 @@ private static String getLogicalTypeUrn(String identifier) {
.build();

public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLogicalType) {
String uuid = schema.getUUID() != null ? schema.getUUID().toString() : "";
return schemaToProto(schema, serializeLogicalType, true);
}

public static SchemaApi.Schema schemaToProto(
Schema schema, boolean serializeLogicalType, boolean serializeUUID) {
String uuid = schema.getUUID() != null && serializeUUID ? schema.getUUID().toString() : "";
SchemaApi.Schema.Builder builder = SchemaApi.Schema.newBuilder().setId(uuid);
for (Field field : schema.getFields()) {
SchemaApi.Field protoField =
fieldToProto(
field,
schema.indexOf(field.getName()),
schema.getEncodingPositions().get(field.getName()),
serializeLogicalType);
serializeLogicalType,
serializeUUID);
builder.addFields(protoField);
}
builder.addAllOptions(optionsToProto(schema.getOptions()));
return builder.build();
}

private static SchemaApi.Field fieldToProto(
Field field, int fieldId, int position, boolean serializeLogicalType) {
Field field, int fieldId, int position, boolean serializeLogicalType, boolean serializeUUID) {
return SchemaApi.Field.newBuilder()
.setName(field.getName())
.setDescription(field.getDescription())
.setType(fieldTypeToProto(field.getType(), serializeLogicalType))
.setType(fieldTypeToProto(field.getType(), serializeLogicalType, serializeUUID))
.setId(fieldId)
.setEncodingPosition(position)
.addAllOptions(optionsToProto(field.getOptions()))
.build();
}

@VisibleForTesting
static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean serializeLogicalType) {
static SchemaApi.FieldType fieldTypeToProto(
FieldType fieldType, boolean serializeLogicalType, boolean serializeUUID) {
SchemaApi.FieldType.Builder builder = SchemaApi.FieldType.newBuilder();
switch (fieldType.getTypeName()) {
case ROW:
builder.setRowType(
SchemaApi.RowType.newBuilder()
.setSchema(schemaToProto(fieldType.getRowSchema(), serializeLogicalType)));
.setSchema(
schemaToProto(fieldType.getRowSchema(), serializeLogicalType, serializeUUID)));
break;

case ARRAY:
builder.setArrayType(
SchemaApi.ArrayType.newBuilder()
.setElementType(
fieldTypeToProto(fieldType.getCollectionElementType(), serializeLogicalType)));
fieldTypeToProto(
fieldType.getCollectionElementType(),
serializeLogicalType,
serializeUUID)));
break;

case ITERABLE:
builder.setIterableType(
SchemaApi.IterableType.newBuilder()
.setElementType(
fieldTypeToProto(fieldType.getCollectionElementType(), serializeLogicalType)));
fieldTypeToProto(
fieldType.getCollectionElementType(),
serializeLogicalType,
serializeUUID)));
break;

case MAP:
builder.setMapType(
SchemaApi.MapType.newBuilder()
.setKeyType(fieldTypeToProto(fieldType.getMapKeyType(), serializeLogicalType))
.setValueType(fieldTypeToProto(fieldType.getMapValueType(), serializeLogicalType))
.setKeyType(
fieldTypeToProto(
fieldType.getMapKeyType(), serializeLogicalType, serializeUUID))
.setValueType(
fieldTypeToProto(
fieldType.getMapValueType(), serializeLogicalType, serializeUUID))
.build());
break;

Expand All @@ -186,12 +204,14 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali
.setUrn(logicalType.getIdentifier())
.setPayload(ByteString.copyFrom(((UnknownLogicalType) logicalType).getPayload()))
.setRepresentation(
fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType));
fieldTypeToProto(
logicalType.getBaseType(), serializeLogicalType, serializeUUID));

if (logicalType.getArgumentType() != null) {
logicalTypeBuilder
.setArgumentType(
fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType))
fieldTypeToProto(
logicalType.getArgumentType(), serializeLogicalType, serializeUUID))
.setArgument(
fieldValueToProto(logicalType.getArgumentType(), logicalType.getArgument()));
}
Expand All @@ -200,13 +220,15 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali
logicalTypeBuilder =
SchemaApi.LogicalType.newBuilder()
.setRepresentation(
fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType))
fieldTypeToProto(
logicalType.getBaseType(), serializeLogicalType, serializeUUID))
.setUrn(urn);
if (logicalType.getArgumentType() != null) {
logicalTypeBuilder =
logicalTypeBuilder
.setArgumentType(
fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType))
fieldTypeToProto(
logicalType.getArgumentType(), serializeLogicalType, serializeUUID))
.setArgument(
fieldValueToProto(
logicalType.getArgumentType(), logicalType.getArgument()));
Expand All @@ -226,7 +248,8 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali
builder.setLogicalType(
SchemaApi.LogicalType.newBuilder()
.setUrn(URN_BEAM_LOGICAL_MILLIS_INSTANT)
.setRepresentation(fieldTypeToProto(FieldType.INT64, serializeLogicalType))
.setRepresentation(
fieldTypeToProto(FieldType.INT64, serializeLogicalType, serializeUUID))
.build());
break;
case DECIMAL:
Expand All @@ -235,7 +258,8 @@ static SchemaApi.FieldType fieldTypeToProto(FieldType fieldType, boolean seriali
builder.setLogicalType(
SchemaApi.LogicalType.newBuilder()
.setUrn(URN_BEAM_LOGICAL_DECIMAL)
.setRepresentation(fieldTypeToProto(FieldType.BYTES, serializeLogicalType))
.setRepresentation(
fieldTypeToProto(FieldType.BYTES, serializeLogicalType, serializeUUID))
.build());
break;
case BYTE:
Expand Down Expand Up @@ -771,7 +795,8 @@ private static List<SchemaApi.Option> optionsToProto(Schema.Options options) {
protoOptions.add(
SchemaApi.Option.newBuilder()
.setName(name)
.setType(fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false))
.setType(
fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false, false))
.setValue(
fieldValueToProto(
Objects.requireNonNull(options.getType(name)), options.getValue(name)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.values.Row;
import org.checkerframework.checker.nullness.qual.Nullable;

Expand All @@ -46,6 +47,7 @@ public class OneOfType implements LogicalType<OneOfType.Value, Row> {

private final Schema oneOfSchema;
private final EnumerationType enumerationType;
private final byte[] schemaProtoRepresentation;

private OneOfType(List<Field> fields) {
this(fields, null);
Expand All @@ -63,6 +65,8 @@ private OneOfType(List<Field> fields, @Nullable Map<String, Integer> enumMap) {
enumerationType = EnumerationType.create(enumValues);
}
oneOfSchema = Schema.builder().addFields(nullableFields).build();
schemaProtoRepresentation =
SchemaTranslation.schemaToProto(oneOfSchema, false, false).toByteArray();
}

/** Create an {@link OneOfType} logical type. */
Expand Down Expand Up @@ -100,12 +104,12 @@ public String getIdentifier() {

@Override
public FieldType getArgumentType() {
return FieldType.row(oneOfSchema);
return FieldType.BYTES;
}

@Override
public Row getArgument() {
return Row.nullRow(oneOfSchema);
public byte[] getArgument() {
return schemaProtoRepresentation;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ public void toAndFromProto() throws Exception {
public static class FromProtoToProtoTest {
@Parameters(name = "{index}: {0}")
public static Iterable<SchemaApi.Schema> data() {
ImmutableList.Builder<SchemaApi.Schema> listBuilder = ImmutableList.builder();
SchemaApi.Schema.Builder builder = SchemaApi.Schema.newBuilder();
// A go 'int'
builder.addFields(
Expand All @@ -232,6 +233,9 @@ public static Iterable<SchemaApi.Schema> data() {
.setId(0)
.setEncodingPosition(0)
.build());
SchemaApi.Schema singleFieldSchema = builder.build();
listBuilder.add(singleFieldSchema);

// A pickled python object
builder.addFields(
SchemaApi.Field.newBuilder()
Expand Down Expand Up @@ -294,22 +298,51 @@ public static Iterable<SchemaApi.Schema> data() {
.setId(2)
.setEncodingPosition(2)
.build());
SchemaApi.Schema unknownLogicalTypeSchema = builder.build();
SchemaApi.Schema multipleFieldSchema = builder.build();
listBuilder.add(multipleFieldSchema);

return ImmutableList.<SchemaApi.Schema>builder().add(unknownLogicalTypeSchema).build();
builder.clear();
builder.addFields(
SchemaApi.Field.newBuilder()
.setName("nested")
.setType(
SchemaApi.FieldType.newBuilder()
.setRowType(
SchemaApi.RowType.newBuilder().setSchema(singleFieldSchema).build())
.build())
.build());
SchemaApi.Schema nestedSchema = builder.build();
listBuilder.add(nestedSchema);

return listBuilder.build();
}

@Parameter(0)
public SchemaApi.Schema schemaProto;

private void clearIds(SchemaApi.Schema.Builder builder) {
builder.clearId();
for (SchemaApi.Field.Builder field : builder.getFieldsBuilderList()) {
if (field.hasType()
&& field.getType().hasRowType()
&& field.getType().getRowType().hasSchema()) {
clearIds(field.getTypeBuilder().getRowTypeBuilder().getSchemaBuilder());
}
}
}

@Test
public void fromProtoAndToProto() throws Exception {
Schema decodedSchema = SchemaTranslation.schemaFromProto(schemaProto);

SchemaApi.Schema reencodedSchemaProto = SchemaTranslation.schemaToProto(decodedSchema, true);
reencodedSchemaProto = reencodedSchemaProto.toBuilder().clearId().build();
SchemaApi.Schema.Builder builder = reencodedSchemaProto.toBuilder();
clearIds(builder);
assertThat(builder.build(), equalTo(schemaProto));

assertThat(reencodedSchemaProto, equalTo(schemaProto));
SchemaApi.Schema reencodedSchemaProtoWithoutUUID =
SchemaTranslation.schemaToProto(decodedSchema, true, false);
assertThat(reencodedSchemaProtoWithoutUUID, equalTo(schemaProto));
}
}

Expand Down Expand Up @@ -433,8 +466,8 @@ public static Iterable<Schema.FieldType> data() {
public Schema.FieldType fieldType;

@Test
public void testLogicalTypeSerializeDeserilizeCorrectly() {
SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, true);
public void testLogicalTypeSerializeDeserializeCorrectly() {
SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, true, false);
Schema.FieldType translated = SchemaTranslation.fieldTypeFromProto(proto);

assertThat(
Expand All @@ -452,7 +485,7 @@ public void testLogicalTypeSerializeDeserilizeCorrectly() {

@Test
public void testLogicalTypeFromToProtoCorrectly() {
SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, false);
SchemaApi.FieldType proto = SchemaTranslation.fieldTypeToProto(fieldType, false, false);
Schema.FieldType translated = SchemaTranslation.fieldTypeFromProto(proto);

if (STANDARD_LOGICAL_TYPES.containsKey(translated.getLogicalType().getIdentifier())) {
Expand Down

0 comments on commit a6af989

Please sign in to comment.