Skip to content

Commit

Permalink
Patch datum factory for specific data in IOs
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Sep 1, 2023
1 parent dd3e9f3 commit 5cc77ba
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2023 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.spotify.scio.avro

import org.apache.avro.Schema
import org.apache.avro.io.{DatumReader, DatumWriter}
import org.apache.avro.reflect.{ReflectData, ReflectDatumReader, ReflectDatumWriter}
import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils

/**
* Custom AvroDatumFactory for avro AvroDatumFactory relying on avro reflect so that underlying
* CharSequence type is String
*/
private[scio] class SpecificRecordDatumFactory[T](recordType: Class[T])
extends AvroDatumFactory[T](recordType) {
override def apply(writer: Schema, reader: Schema): DatumReader[T] = {
val data = new ReflectData(recordType.getClassLoader)
AvroUtils.addLogicalTypeConversions(data)
new ReflectDatumReader[T](writer, reader, data)
}

override def apply(writer: Schema): DatumWriter[T] = {
val data = new ReflectData(recordType.getClassLoader)
AvroUtils.addLogicalTypeConversions(data)
new ReflectDatumWriter[T](writer, data)
}
}
5 changes: 4 additions & 1 deletion scio-avro/src/main/scala/com/spotify/scio/avro/AvroIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ final case class SpecificRecordIO[T <: SpecificRecord: ClassTag: Coder](path: St
val t = BAvroIO
.read(cls)
.from(filePattern)
.withDatumReaderFactory(new SpecificRecordDatumFactory[T](cls))
sc
.applyTransform(t)
.setCoder(coder)
Expand All @@ -194,7 +195,9 @@ final case class SpecificRecordIO[T <: SpecificRecord: ClassTag: Coder](path: St
*/
override protected def write(data: SCollection[T], params: WriteP): Tap[T] = {
val cls = ScioUtil.classOf[T]
val t = BAvroIO.write(cls)
val t = BAvroIO
.write(cls)
.withDatumWriterFactory(new SpecificRecordDatumFactory[T](cls))

data.applyInternal(
avroOut(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,18 @@ trait AvroCoders {
throw new RuntimeException(msg)
}

// same as SpecificRecordDatumFactory in scio-avro
val factory = new AvroDatumFactory(clazz) {
override def apply(writer: Schema, reader: Schema): DatumReader[T] = {
// create the datum writer using the schema api
// class API might be unsafe. See schemaForClass
val datumReader = new ReflectDatumReader[T](writer, reader, new ReflectData())
// for backward compat, add logical type support by default
AvroUtils.addLogicalTypeConversions(datumReader.getData)
datumReader
val data = new ReflectData(clazz.getClassLoader)
AvroUtils.addLogicalTypeConversions(data)
new ReflectDatumReader[T](writer, reader, data)
}

override def apply(writer: Schema): DatumWriter[T] = {
// create the datum writer using the schema api
// class API might be unsafe. See schemaForClass
val datumWriter = new ReflectDatumWriter[T](writer, new ReflectData())
// for backward compat, add logical type support by default
AvroUtils.addLogicalTypeConversions(datumWriter.getData)
datumWriter
val data = new ReflectData(clazz.getClassLoader)
AvroUtils.addLogicalTypeConversions(data)
new ReflectDatumWriter[T](writer, data)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.avro.io.DatumReader;
import org.apache.avro.reflect.ReflectData;
import org.apache.avro.reflect.ReflectDatumReader;
import org.apache.avro.reflect.ReflectDatumWriter;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;
import org.apache.beam.sdk.extensions.avro.io.AvroIO;
Expand Down Expand Up @@ -121,7 +122,16 @@ public GenericRecord formatRecord(ValueT element, Schema schema) {
}
})
.withCodec(codec.getCodec())
: AvroIO.sink(recordClass).withCodec(codec.getCodec());
: AvroIO.sink(recordClass)
.withCodec(codec.getCodec())
.withDatumWriterFactory(
(writer) -> {
// same as SpecificRecordDatumFactory in scio-avro
ReflectData data = new ReflectData(recordClass.getClassLoader());
org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils
.addLogicalTypeConversions(data);
return new ReflectDatumWriter<>(writer, data);
});

if (metadata != null) {
return sink.withMetadata(metadata);
Expand Down Expand Up @@ -193,10 +203,15 @@ private static class AvroReader<ValueT> extends FileOperations.Reader<ValueT> {
public void prepareRead(ReadableByteChannel channel) throws IOException {
final Schema schema = schemaSupplier.get();

DatumReader<ValueT> datumReader =
recordClass == null
? new GenericDatumReader<>(schema)
: new ReflectDatumReader<>(recordClass);
DatumReader<ValueT> datumReader;
if (recordClass == null) {
datumReader = new GenericDatumReader<>(schema);
} else {
// same as SpecificRecordDatumFactory in scio-avro
ReflectData data = new ReflectData(recordClass.getClassLoader());
org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils.addLogicalTypeConversions(data);
datumReader = new ReflectDatumReader<>(data);
}

reader = new DataFileStream<>(Channels.newInputStream(channel), datumReader);
}
Expand Down

0 comments on commit 5cc77ba

Please sign in to comment.