Skip to content

Commit

Permalink
Bulk Load CDK: Mapper Pipeline (#48371)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Nov 7, 2024
1 parent 87f514d commit 2e3023a
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,47 @@ import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Change
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange.Reason

open class AirbyteValueIdentityMapper(
val meta: DestinationRecord.Meta,
) {
interface AirbyteValueMapper {
val collectedChanges: List<DestinationRecord.Change>

fun map(
value: AirbyteValue,
schema: AirbyteType,
path: List<String> = emptyList(),
): AirbyteValue
}

/** An optimized identity mapper that just passes through. */
class AirbyteValueNoopMapper : AirbyteValueMapper {
override val collectedChanges: List<DestinationRecord.Change> = emptyList()

override fun map(
value: AirbyteValue,
schema: AirbyteType,
path: List<String>,
): AirbyteValue = value
}

open class AirbyteValueIdentityMapper : AirbyteValueMapper {
override val collectedChanges: List<DestinationRecord.Change>
get() = changes.toList().also { changes.clear() }

private val changes: MutableList<DestinationRecord.Change> = mutableListOf()

private fun collectFailure(
path: List<String>,
reason: Reason = Reason.DESTINATION_SERIALIZATION_ERROR
) {
meta.changes.add(DestinationRecord.Change(path.joinToString("."), Change.NULLED, reason))
val joined = path.joinToString(".")
if (changes.none { it.field == joined }) {
changes.add(DestinationRecord.Change(path.joinToString("."), Change.NULLED, reason))
}
}

fun map(
override fun map(
value: AirbyteValue,
schema: AirbyteType,
path: List<String> = emptyList()
path: List<String>,
): AirbyteValue =
try {
when (schema) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.DestinationRecord.Change

class MapperPipeline(
inputSchema: AirbyteType,
schemaValueMapperPairs: List<Pair<AirbyteSchemaMapper, AirbyteValueMapper>>,
) {
private val schemasWithMappers: List<Pair<AirbyteType, AirbyteValueMapper>>

val finalSchema: AirbyteType

init {
val (schemaMappers, valueMappers) = schemaValueMapperPairs.unzip()
val schemas =
schemaMappers.runningFold(inputSchema) { schema, mapper -> mapper.map(schema) }
schemasWithMappers = schemas.zip(valueMappers)
finalSchema = schemas.last()
}

fun map(data: AirbyteValue): Pair<AirbyteValue, List<Change>> {
val results =
schemasWithMappers.runningFold(data) { value, (schema, mapper) ->
mapper.map(value, schema)
}
val changesFlattened =
schemasWithMappers.flatMap { it.second.collectedChanges }.toSet().toList()
return results.last() to changesFlattened
}
}

interface MapperPipelineFactory {
fun create(stream: DestinationStream): MapperPipeline
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.data.json.toJson
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.util.serializeToString

class SchemalessTypesToJson : AirbyteSchemaIdentityMapper {
Expand All @@ -15,7 +14,7 @@ class SchemalessTypesToJson : AirbyteSchemaIdentityMapper {
override fun mapArrayWithoutSchema(schema: ArrayTypeWithoutSchema): AirbyteType = StringType
}

class SchemalessValuesToJson(meta: DestinationRecord.Meta) : AirbyteValueIdentityMapper(meta) {
class SchemalessValuesToJson : AirbyteValueIdentityMapper() {
override fun mapObjectWithoutSchema(
value: ObjectValue,
schema: ObjectTypeWithoutSchema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.data.json.toJson
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.util.serializeToString

class SchemalessTypesToJsonString : AirbyteSchemaIdentityMapper {
Expand All @@ -15,8 +14,7 @@ class SchemalessTypesToJsonString : AirbyteSchemaIdentityMapper {
override fun mapArrayWithoutSchema(schema: ArrayTypeWithoutSchema): AirbyteType = StringType
}

class SchemalessValuesToJsonString(meta: DestinationRecord.Meta) :
AirbyteValueIdentityMapper(meta) {
class SchemalessValuesToJsonString : AirbyteValueIdentityMapper() {
override fun mapObjectWithoutSchema(
value: ObjectValue,
schema: ObjectTypeWithoutSchema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.message.DestinationRecord
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.LocalTime
Expand All @@ -30,7 +29,7 @@ class TimeStringTypeToIntegerType : AirbyteSchemaIdentityMapper {
* NOTE: To keep parity with the old avro/parquet code, we will always first try to parse the value
* as with timezone, then fall back to without. But in theory we should be more strict.
*/
class TimeStringToInteger(meta: DestinationRecord.Meta) : AirbyteValueIdentityMapper(meta) {
class TimeStringToInteger : AirbyteValueIdentityMapper() {
companion object {
private val DATE_TIME_FORMATTER: DateTimeFormatter =
DateTimeFormatter.ofPattern(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.message.DestinationRecord

class UnionTypeToDisjointRecord : AirbyteSchemaIdentityMapper {
override fun mapUnion(schema: UnionType): AirbyteType {
if (schema.options.size < 2) {
Expand Down Expand Up @@ -46,7 +44,7 @@ class UnionTypeToDisjointRecord : AirbyteSchemaIdentityMapper {
}
}

class UnionValueToDisjointRecord(meta: DestinationRecord.Meta) : AirbyteValueIdentityMapper(meta) {
class UnionValueToDisjointRecord : AirbyteValueIdentityMapper() {
override fun mapUnion(
value: AirbyteValue,
schema: UnionType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import io.airbyte.cdk.load.test.util.ValueTestBuilder
Expand Down Expand Up @@ -39,10 +38,10 @@ class AirbyteValueIdentityMapperTest {
.endRecord()
.build()

val meta = DestinationRecord.Meta()
val values = AirbyteValueIdentityMapper(meta).map(inputValues, inputSchema)
val mapper = AirbyteValueIdentityMapper()
val values = mapper.map(inputValues, inputSchema)
Assertions.assertEquals(expectedValues, values)
Assertions.assertTrue(meta.changes.isEmpty())
Assertions.assertTrue(mapper.collectedChanges.isEmpty())
}

@Test
Expand All @@ -56,16 +55,15 @@ class AirbyteValueIdentityMapperTest {
nameOverride = "bad"
)
.build()
val meta = DestinationRecord.Meta()
val values = AirbyteValueIdentityMapper(meta).map(inputValues, inputSchema) as ObjectValue
Assertions.assertTrue(meta.changes.isNotEmpty())
val mapper = AirbyteValueIdentityMapper()
val values = mapper.map(inputValues, inputSchema) as ObjectValue
val changes = mapper.collectedChanges
Assertions.assertTrue(changes.isNotEmpty())
Assertions.assertTrue(values.values["bad"] is NullValue)
Assertions.assertTrue(meta.changes[0].field == "bad")
Assertions.assertTrue(changes[0].field == "bad")
Assertions.assertTrue(changes[0].change == AirbyteRecordMessageMetaChange.Change.NULLED)
Assertions.assertTrue(
meta.changes[0].change == AirbyteRecordMessageMetaChange.Change.NULLED
)
Assertions.assertTrue(
meta.changes[0].reason ==
changes[0].reason ==
AirbyteRecordMessageMetaChange.Reason.DESTINATION_SERIALIZATION_ERROR
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import io.airbyte.cdk.load.test.util.ValueTestBuilder
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

class MapperPipelineTest {
class TurnSchemalessObjectTypesIntoIntegers : AirbyteSchemaIdentityMapper {
override fun mapObjectWithoutSchema(schema: ObjectTypeWithoutSchema): AirbyteType =
IntegerType
}

class TurnSchemalessObjectsIntoIntegers : AirbyteValueIdentityMapper() {
override fun mapObjectWithoutSchema(
value: ObjectValue,
schema: ObjectTypeWithoutSchema,
path: List<String>
): AirbyteValue {
if (value.values.size == 1) {
throw IllegalStateException("Arbitrarily reject 1")
}
return IntegerValue(value.values.size.toLong())
}
}

class TurnIntegerTypesIntoStrings : AirbyteSchemaIdentityMapper {
override fun mapInteger(schema: IntegerType): AirbyteType = StringType
}

class TurnIntegersIntoStrings : AirbyteValueIdentityMapper() {
override fun mapInteger(value: IntegerValue, path: List<String>): AirbyteValue {
if (value.value == 2L) {
throw IllegalStateException("Arbitrarily reject 2")
}
return StringValue(value.value.toString())
}
}

private fun makePipeline(schema: AirbyteType) =
MapperPipeline(
schema,
listOf(
TurnIntegerTypesIntoStrings() to TurnIntegersIntoStrings(),
TurnSchemalessObjectTypesIntoIntegers() to TurnSchemalessObjectsIntoIntegers(),
)
)

@Test
fun testSuccessfulPipeline() {
val (inputSchema, expectedSchema) =
SchemaRecordBuilder<Root>()
.with(ObjectTypeWithoutSchema, IntegerType)
.with(IntegerType, StringType)
.withRecord()
.with(IntegerType, StringType)
.with(BooleanType, BooleanType) // expect unchanged
.endRecord()
.build()

val pipeline = makePipeline(inputSchema)
Assertions.assertEquals(
expectedSchema,
pipeline.finalSchema,
"final schema matches expected transformed schema"
)
}

@Test
fun testRecordMapping() {
val (inputValue, inputSchema, expectedOutput) =
ValueTestBuilder<Root>()
.with(
ObjectValue(linkedMapOf("a" to IntegerValue(1), "b" to IntegerValue(2))),
ObjectTypeWithoutSchema,
IntegerValue(2)
)
.with(IntegerValue(1), IntegerType, StringValue("1"))
.withRecord()
.with(IntegerValue(3), IntegerType, StringValue("3"))
.with(BooleanValue(true), BooleanType, BooleanValue(true)) // expect unchanged
.endRecord()
.build()
val pipeline = makePipeline(inputSchema)
val (result, changes) = pipeline.map(inputValue)

Assertions.assertEquals(0, changes.size, "no changes were captured")
Assertions.assertEquals(expectedOutput, result, "data was transformed as expected")
}

@Test
fun testFailedMapping() {
val (inputValue, inputSchema, _) =
ValueTestBuilder<Root>()
.with(
ObjectValue(linkedMapOf("a" to IntegerValue(1))),
ObjectTypeWithoutSchema,
NullValue,
nullable = true
) // fail: reject size==1
.with(IntegerValue(1), IntegerType, StringValue("1"))
.withRecord()
.with(IntegerValue(2), IntegerType, NullValue, nullable = true) // fail: reject 2
.with(BooleanValue(true), BooleanType, BooleanValue(true)) // expect unchanged
.endRecord()
.build()
val pipeline = makePipeline(inputSchema)
val (_, changes) = pipeline.map(inputValue)

Assertions.assertEquals(2, changes.size, "two failures were captured")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import io.airbyte.cdk.load.test.util.ValueTestBuilder
Expand Down Expand Up @@ -85,7 +84,7 @@ class SchemalessTypesToJsonStringTest {
ArrayType(FieldType(StringType, nullable = false))
)
.build()
val mapper = SchemalessValuesToJsonString(DestinationRecord.Meta())
val mapper = SchemalessValuesToJsonString()
val output = mapper.map(inputValues, inputSchema)
Assertions.assertEquals(expectedOutput, output)
}
Expand Down Expand Up @@ -120,7 +119,7 @@ class SchemalessTypesToJsonStringTest {
ArrayType(FieldType(StringType, nullable = false))
)
.build()
val mapper = SchemalessValuesToJsonString(DestinationRecord.Meta())
val mapper = SchemalessValuesToJsonString()
val output = mapper.map(inputValues, inputSchema)
Assertions.assertEquals(expectedOutput, output)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.data.json.toAirbyteValue
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import io.airbyte.cdk.load.test.util.ValueTestBuilder
Expand Down Expand Up @@ -85,7 +84,7 @@ class SchemalessTypesToJsonTest {
ArrayType(FieldType(StringType, nullable = false))
)
.build()
val mapper = SchemalessValuesToJson(DestinationRecord.Meta())
val mapper = SchemalessValuesToJson()
val output = mapper.map(inputValues, inputSchema)
Assertions.assertEquals(expectedOutput, output)
}
Expand Down Expand Up @@ -120,7 +119,7 @@ class SchemalessTypesToJsonTest {
ArrayType(FieldType(StringType, nullable = false))
)
.build()
val mapper = SchemalessValuesToJson(DestinationRecord.Meta())
val mapper = SchemalessValuesToJson()
val output = mapper.map(inputValues, inputSchema)
Assertions.assertEquals(expectedOutput, output)
}
Expand Down
Loading

0 comments on commit 2e3023a

Please sign in to comment.