Skip to content

Commit

Permalink
Improve BQ avro handling
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Sep 18, 2024
1 parent 6514136 commit 2312dbe
Show file tree
Hide file tree
Showing 11 changed files with 647 additions and 516 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,7 @@ public class BigQueryIO {
GENERIC_DATUM_WRITER_FACTORY = schema -> new GenericDatumWriter<>();

private static final SerializableFunction<TableSchema, org.apache.avro.Schema>
DEFAULT_AVRO_SCHEMA_FACTORY =
(SerializableFunction<TableSchema, org.apache.avro.Schema>)
input -> BigQueryAvroUtils.toGenericAvroSchema("root", input.getFields());
DEFAULT_AVRO_SCHEMA_FACTORY = BigQueryAvroUtils::toGenericAvroSchema;

/**
* @deprecated Use {@link #read(SerializableFunction)} or {@link #readTableRows} instead. {@link
Expand All @@ -649,12 +647,12 @@ public static Read read() {
* domain-specific type, due to the overhead of converting the rows to {@link TableRow}.
*/
public static TypedRead<TableRow> readTableRows() {
return read(new TableRowParser()).withCoder(TableRowJsonCoder.of());
return read(TableRowParser.INSTANCE).withCoder(TableRowJsonCoder.of());
}

/** Like {@link #readTableRows()} but with {@link Schema} support. */
public static TypedRead<TableRow> readTableRowsWithSchema() {
return read(new TableRowParser())
return read(TableRowParser.INSTANCE)
.withCoder(TableRowJsonCoder.of())
.withBeamRowConverters(
TypeDescriptor.of(TableRow.class),
Expand Down Expand Up @@ -793,8 +791,7 @@ static class TableRowParser implements SerializableFunction<SchemaAndRecord, Tab

@Override
public TableRow apply(SchemaAndRecord schemaAndRecord) {
return BigQueryAvroUtils.convertGenericRecordToTableRow(
schemaAndRecord.getRecord(), schemaAndRecord.getTableSchema());
return BigQueryAvroUtils.convertGenericRecordToTableRow(schemaAndRecord.getRecord());
}
}

Expand Down Expand Up @@ -1275,8 +1272,12 @@ public PCollection<T> expand(PBegin input) {

Schema beamSchema = null;
if (getTypeDescriptor() != null && getToBeamRowFn() != null && getFromBeamRowFn() != null) {
beamSchema = sourceDef.getBeamSchema(bqOptions);
beamSchema = getFinalSchema(beamSchema, getSelectedFields());
TableSchema tableSchema = sourceDef.getTableSchema(bqOptions);
ValueProvider<List<String>> selectedFields = getSelectedFields();
if (selectedFields != null) {
tableSchema = BigQueryUtils.trimSchema(tableSchema, selectedFields.get());
}
beamSchema = BigQueryUtils.fromTableSchema(tableSchema);
}

final Coder<T> coder = inferCoder(p.getCoderRegistry());
Expand Down Expand Up @@ -1441,24 +1442,6 @@ void cleanup(PassThroughThenCleanup.ContextContainer c) throws Exception {
return rows;
}

/**
 * Restricts a Beam {@link Schema} to the requested field selection.
 *
 * @param beamSchema the full schema of the source
 * @param selectedFields the names of the fields to retain; when null, not yet accessible, or
 *     holding a null value, the schema is rebuilt with all of its fields
 * @return a schema containing only the surviving fields, in their original order
 */
private static Schema getFinalSchema(
    Schema beamSchema, ValueProvider<List<String>> selectedFields) {
  // No usable selection: keep every field.
  if (selectedFields == null || !selectedFields.isAccessible() || selectedFields.get() == null) {
    return Schema.builder().addFields(beamSchema.getFields()).build();
  }
  List<String> selected = selectedFields.get();
  List<Schema.Field> kept =
      beamSchema.getFields().stream()
          .filter(field -> selected.contains(field.getName()))
          .collect(Collectors.toList());
  return Schema.builder().addFields(kept).build();
}

private PCollection<T> expandForDirectRead(
PBegin input, Coder<T> outputCoder, Schema beamSchema, BigQueryOptions bqOptions) {
ValueProvider<TableReference> tableProvider = getTableProvider();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public <T> BigQuerySourceBase<T> toSource(

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
public TableSchema getTableSchema(BigQueryOptions bqOptions) {
try {
JobStatistics stats =
BigQueryQueryHelper.dryRunQueryIfNeeded(
Expand All @@ -189,14 +189,20 @@ public Schema getBeamSchema(BigQueryOptions bqOptions) {
flattenResults,
useLegacySql,
location);
TableSchema tableSchema = stats.getQuery().getSchema();
return BigQueryUtils.fromTableSchema(tableSchema);
return stats.getQuery().getSchema();
} catch (IOException | InterruptedException | NullPointerException e) {
throw new BigQuerySchemaRetrievalException(
"Exception while trying to retrieve schema of query", e);
}
}

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
  // The Beam schema is derived directly from this query source's BigQuery table schema.
  return BigQueryUtils.fromTableSchema(getTableSchema(bqOptions));
}

ValueProvider<String> getQuery() {
return query;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,7 @@ private List<ResourceId> executeExtract(
List<BoundedSource<T>> createSources(
List<ResourceId> files, TableSchema schema, @Nullable List<MatchResult.Metadata> metadata)
throws IOException, InterruptedException {
String avroSchema =
BigQueryAvroUtils.toGenericAvroSchema("root", schema.getFields()).toString();
String avroSchema = BigQueryAvroUtils.toGenericAvroSchema(schema).toString();

AvroSource.DatumReaderFactory<T> factory = readerFactory.apply(schema);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ <T> BigQuerySourceBase<T> toSource(
SerializableFunction<TableSchema, AvroSource.DatumReaderFactory<T>> readerFactory,
boolean useAvroLogicalTypes);

/**
 * Extract the BigQuery {@link TableSchema} corresponding to this source.
 *
 * @param bqOptions options used to reach the BigQuery service
 * @return the table schema of the source
 * @throws BigQuerySchemaRetrievalException if schema retrieval fails
 */
TableSchema getTableSchema(BigQueryOptions bqOptions);

/**
* Extract the Beam {@link Schema} corresponding to this source.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@
import com.google.cloud.bigquery.storage.v1.ReadStream;
import java.io.IOException;
import java.util.List;
import org.apache.avro.Schema;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.arrow.ArrowConversion;
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient;
import org.apache.beam.sdk.metrics.Lineage;
Expand Down Expand Up @@ -182,30 +179,17 @@ public List<BigQueryStorageStreamSource<T>> split(
LOG.info("Read session returned {} streams", readSession.getStreamsList().size());
}

Schema sessionSchema;
if (readSession.getDataFormat() == DataFormat.ARROW) {
org.apache.arrow.vector.types.pojo.Schema schema =
ArrowConversion.arrowSchemaFromInput(
readSession.getArrowSchema().getSerializedSchema().newInput());
org.apache.beam.sdk.schemas.Schema beamSchema =
ArrowConversion.ArrowSchemaTranslator.toBeamSchema(schema);
sessionSchema = AvroUtils.toAvroSchema(beamSchema);
} else if (readSession.getDataFormat() == DataFormat.AVRO) {
sessionSchema = new Schema.Parser().parse(readSession.getAvroSchema().getSchema());
} else {
throw new IllegalArgumentException(
"data is not in a supported dataFormat: " + readSession.getDataFormat());
// TODO: this is inconsistent with method above, where it can be null
Preconditions.checkStateNotNull(targetTable);
TableSchema tableSchema = targetTable.getSchema();
if (selectedFieldsProvider != null) {
tableSchema = BigQueryUtils.trimSchema(tableSchema, selectedFieldsProvider.get());
}

Preconditions.checkStateNotNull(
targetTable); // TODO: this is inconsistent with method above, where it can be null
TableSchema trimmedSchema =
BigQueryAvroUtils.trimBigQueryTableSchema(targetTable.getSchema(), sessionSchema);
List<BigQueryStorageStreamSource<T>> sources = Lists.newArrayList();
for (ReadStream readStream : readSession.getStreamsList()) {
sources.add(
BigQueryStorageStreamSource.create(
readSession, readStream, trimmedSchema, parseFn, outputCoder, bqServices));
readSession, readStream, tableSchema, parseFn, outputCoder, bqServices));
}

return ImmutableList.copyOf(sources);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,22 @@ public <T> BigQuerySourceBase<T> toSource(

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
public TableSchema getTableSchema(BigQueryOptions bqOptions) {
try {
try (DatasetService datasetService = bqServices.getDatasetService(bqOptions)) {
TableReference tableRef = getTableReference(bqOptions);
Table table = datasetService.getTable(tableRef);
TableSchema tableSchema = Preconditions.checkStateNotNull(table).getSchema();
return BigQueryUtils.fromTableSchema(tableSchema);
return Preconditions.checkStateNotNull(table).getSchema();
}
} catch (Exception e) {
throw new BigQuerySchemaRetrievalException("Exception while trying to retrieve schema", e);
}
}

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
  // Fetch the table's BigQuery schema, then translate it into the Beam schema model.
  TableSchema bqSchema = getTableSchema(bqOptions);
  return BigQueryUtils.fromTableSchema(bqSchema);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,21 @@ private static FieldType fromTableFieldSchemaType(
case "BYTES":
return FieldType.BYTES;
case "INT64":
case "INT":
case "SMALLINT":
case "INTEGER":
case "BIGINT":
case "TINYINT":
case "BYTEINT":
return FieldType.INT64;
case "FLOAT64":
case "FLOAT":
case "FLOAT": // even if not a valid BQ type, it is used in the schema
return FieldType.DOUBLE;
case "BOOL":
case "BOOLEAN":
return FieldType.BOOLEAN;
case "NUMERIC":
case "BIGNUMERIC":
return FieldType.DECIMAL;
case "TIMESTAMP":
return FieldType.DATETIME;
Expand All @@ -355,6 +361,10 @@ private static FieldType fromTableFieldSchemaType(

Schema rowSchema = fromTableFieldSchema(nestedFields, options);
return FieldType.row(rowSchema);
case "GEOGRAPHY":
case "JSON":
// TODO Add metadata for custom sql types
return FieldType.STRING;
default:
throw new UnsupportedOperationException(
"Converting BigQuery type " + typeName + " to Beam type is unsupported");
Expand Down Expand Up @@ -446,10 +456,27 @@ public static Schema fromTableSchema(TableSchema tableSchema, SchemaConversionOp
return fromTableFieldSchema(tableSchema.getFields(), options);
}

/** Convert a BigQuery {@link TableSchema} to an Avro {@link org.apache.avro.Schema}. */
public static org.apache.avro.Schema toGenericAvroSchema(TableSchema tableSchema) {
  return toGenericAvroSchema(tableSchema, false);
}

/**
 * Convert a BigQuery {@link TableSchema} to an Avro {@link org.apache.avro.Schema}.
 *
 * @param tableSchema the BigQuery schema to convert
 * @param stringLogicalTypes presumably controls whether logical types are rendered as plain
 *     strings in the Avro schema — confirm against BigQueryAvroUtils
 */
public static org.apache.avro.Schema toGenericAvroSchema(
    TableSchema tableSchema, Boolean stringLogicalTypes) {
  return toGenericAvroSchema("root", tableSchema.getFields(), stringLogicalTypes);
}

/** Convert a list of BigQuery {@link TableFieldSchema} to Avro {@link org.apache.avro.Schema}. */
public static org.apache.avro.Schema toGenericAvroSchema(
String schemaName, List<TableFieldSchema> fieldSchemas) {
return BigQueryAvroUtils.toGenericAvroSchema(schemaName, fieldSchemas);
return toGenericAvroSchema(schemaName, fieldSchemas, false);
}

/**
 * Convert a list of BigQuery {@link TableFieldSchema} to an Avro {@link org.apache.avro.Schema}.
 *
 * @param schemaName the name given to the generated Avro record schema
 * @param fieldSchemas the BigQuery fields to convert
 * @param stringLogicalTypes presumably controls whether logical types are rendered as plain
 *     strings in the Avro schema — confirm against BigQueryAvroUtils
 */
public static org.apache.avro.Schema toGenericAvroSchema(
    String schemaName, List<TableFieldSchema> fieldSchemas, Boolean stringLogicalTypes) {
  return BigQueryAvroUtils.toGenericAvroSchema(schemaName, fieldSchemas, stringLogicalTypes);
}

private static final BigQueryIO.TypedRead.ToBeamRowFunction<TableRow>
Expand Down Expand Up @@ -514,9 +541,20 @@ public static Row toBeamRow(GenericRecord record, Schema schema, ConversionOptio
return Row.withSchema(schema).addValues(valuesInOrder).build();
}

/**
* Convert generic record to Bq TableRow.
*
* @deprecated use {@link #convertGenericRecordToTableRow(GenericRecord)}
*/
@Deprecated
public static TableRow convertGenericRecordToTableRow(
GenericRecord record, TableSchema tableSchema) {
return BigQueryAvroUtils.convertGenericRecordToTableRow(record, tableSchema);
return convertGenericRecordToTableRow(record);
}

/** Convert an Avro {@link GenericRecord} to a BigQuery {@link TableRow}. */
public static TableRow convertGenericRecordToTableRow(GenericRecord record) {
  return BigQueryAvroUtils.convertGenericRecordToTableRow(record);
}

/** Convert a Beam Row to a BigQuery TableRow. */
Expand Down Expand Up @@ -1039,6 +1077,21 @@ private static Object convertAvroNumeric(Object value) {
return tableSpec;
}

/**
 * Trims a BigQuery {@link TableSchema} down to the given field selection.
 *
 * <p>Selected fields may be top-level names ({@code "field"}) or dot-separated paths into
 * RECORD fields ({@code "record.subfield"}). For a nested selection, the parent field is kept
 * and its sub-fields are recursively trimmed — previously such a path matched nothing and the
 * parent field was dropped entirely, producing a schema inconsistent with the rows actually
 * read. Flat selections behave exactly as before.
 *
 * @param schema the schema to trim
 * @param selectedFields the field names/paths to retain; null or empty keeps the schema as-is
 * @return a schema containing only the selected fields, in their original order
 */
static TableSchema trimSchema(TableSchema schema, @Nullable List<String> selectedFields) {
  if (selectedFields == null || selectedFields.isEmpty()) {
    return schema;
  }
  return new TableSchema().setFields(trimFields(schema.getFields(), selectedFields));
}

/** Recursive helper for {@link #trimSchema}: keeps the fields matching the selection. */
private static List<TableFieldSchema> trimFields(
    List<TableFieldSchema> fields, List<String> selectedFields) {
  List<TableFieldSchema> trimmedFields = new ArrayList<>();
  for (TableFieldSchema field : fields) {
    String name = field.getName();
    if (selectedFields.contains(name)) {
      // Field selected in full: keep it (including any sub-fields) untouched.
      trimmedFields.add(field);
      continue;
    }
    // Gather nested selections of the form "name.rest" addressing this field's sub-fields.
    String prefix = name + ".";
    List<String> nestedSelection = new ArrayList<>();
    for (String selected : selectedFields) {
      if (selected.startsWith(prefix)) {
        nestedSelection.add(selected.substring(prefix.length()));
      }
    }
    if (!nestedSelection.isEmpty() && field.getFields() != null) {
      // Clone so the caller's schema is not mutated, then trim the RECORD's sub-fields.
      trimmedFields.add(field.clone().setFields(trimFields(field.getFields(), nestedSelection)));
    }
  }
  return trimmedFields;
}

private static @Nullable ServiceCallMetric callMetricForMethod(
@Nullable TableReference tableReference, String method) {
if (tableReference != null) {
Expand Down
Loading

0 comments on commit 2312dbe

Please sign in to comment.