Skip to content

Commit

Permalink
[SPARK-47682][SQL] Support cast from variant
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR allows casting from variant to another type. It has the same semantics as `variant_get` with an empty path.

### Why are the changes needed?

It can bring much convenience if the user can directly cast variant into another type rather than having to use `variant_get` with an empty path.

### Does this PR introduce _any_ user-facing change?

Yes. Casting from variant was previously not allowed but is now allowed.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45807 from chenhao-db/cast_from_variant.

Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
chenhao-db authored and cloud-fan committed Apr 9, 2024
1 parent 08c4963 commit 319edfd
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ private[sql] object UpCastRule {
*/
def canUpCast(from: DataType, to: DataType): Boolean = (from, to) match {
case _ if from == to => true
case (VariantType, _) => false
case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true
case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true
case (f, t) if legalNumericPrecedence(f, t) => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte,
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
import org.apache.spark.util.ArrayImplicits._

Expand Down Expand Up @@ -127,6 +127,8 @@ object Cast extends QueryErrorsBase {
case (BooleanType, _: NumericType) => true
case (TimestampType, _: NumericType) => true

case (VariantType, _) => variant.VariantGet.checkDataType(to)

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canAnsiCast(fromType, toType) && resolvableNullability(fn, tn)

Expand Down Expand Up @@ -233,6 +235,8 @@ object Cast extends QueryErrorsBase {
case (TimestampType, _: NumericType) => true
case (_: NumericType, _: NumericType) => true

case (VariantType, _) => variant.VariantGet.checkDataType(to)

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canCast(fromType, toType) &&
resolvableNullability(fn || forceNullable(fromType, toType), tn)
Expand Down Expand Up @@ -267,6 +271,7 @@ object Cast extends QueryErrorsBase {
* * Cast.castToTimestamp
*/
def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match {
case (VariantType, _) => true
case (_: StringType, TimestampType) => true
case (TimestampType, StringType) => true
case (DateType, TimestampType) => true
Expand Down Expand Up @@ -340,6 +345,7 @@ object Cast extends QueryErrorsBase {
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
case (NullType, _) => false // empty array or map case
case (_, _) if from == to => false
case (VariantType, _) => true

case (_: StringType, BinaryType | _: StringType) => false
case (_: StringType, _) => true
Expand Down Expand Up @@ -1106,6 +1112,10 @@ case class Cast(
// But for nested types like struct, we might reach here for nested null type field.
// We won't call the returned function actually, but returns a placeholder.
_ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
} else if (from.isInstanceOf[VariantType]) {
buildCast[VariantVal](_, v => {
variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId)
})
} else {
to match {
case dt if dt == from => identity[Any]
Expand Down Expand Up @@ -1198,6 +1208,20 @@ case class Cast(

case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;"
case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;"
case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
val tmp = ctx.freshVariable("tmp", classOf[Object])
val dataTypeArg = ctx.addReferenceObj("dataType", to)
val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
val failOnError = evalMode != EvalMode.TRY
val cls = classOf[variant.VariantGet].getName
code"""
Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneIdArg);
if ($tmp == null) {
$evNull = true;
} else {
$evPrim = (${CodeGenerator.boxedType(to)})$tmp;
}
"""
case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim)
case BinaryType => castToBinaryCode(from)
case DateType => castToDateCode(from, ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,19 @@ case object VariantGet {
VariantGet.cast(v, dataType, failOnError, zoneId)
}

/**
* A simple wrapper of the `cast` function that takes `VariantVal` rather than `Variant`. The
* `Cast` expression uses it and makes the implementation simpler.
*
* It builds a `Variant` from the value and metadata bytes of `input` and delegates to the
* `Variant`-based `cast` overload with the remaining arguments passed through unchanged.
*
* @param input the variant value to cast
* @param dataType the target data type of the cast
* @param failOnError forwarded as-is to the underlying `cast`; the `Cast` expression passes
*                    `false` only in TRY evaluation mode
* @param zoneId forwarded as-is to the underlying `cast`; the `Cast` expression passes its
*               `timeZoneId`
* @return whatever the underlying `cast` returns for the reconstructed variant
*/
def cast(
input: VariantVal,
dataType: DataType,
failOnError: Boolean,
zoneId: Option[String]): Any = {
// Re-wrap the raw value/metadata so the Variant-based overload can be reused.
val v = new Variant(input.getValue, input.getMetadata)
VariantGet.cast(v, dataType, failOnError, zoneId)
}

/**
* Cast a variant `v` into a target data type `dataType`. If the variant represents a variant
* null, the result is always a SQL NULL. The cast may fail due to an illegal type combination
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,14 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
check(expectedResult4, smallObject, smallMetadata)
}

private def variantGet(input: String, path: String, dataType: DataType): VariantGet = {
val inputVariant = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))
VariantGet(Literal(inputVariant), Literal(path), dataType, failOnError = true)
}
private def parseJson(input: String): VariantVal =
VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))

private def tryVariantGet(input: String, path: String, dataType: DataType): VariantGet = {
val inputVariant = VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))
VariantGet(Literal(inputVariant), Literal(path), dataType, failOnError = false)
}
private def variantGet(input: String, path: String, dataType: DataType): VariantGet =
VariantGet(Literal(parseJson(input)), Literal(path), dataType, failOnError = true)

private def tryVariantGet(input: String, path: String, dataType: DataType): VariantGet =
VariantGet(Literal(parseJson(input)), Literal(path), dataType, failOnError = false)

private def testVariantGet(input: String, path: String, dataType: DataType, output: Any): Unit = {
checkEvaluation(variantGet(input, path, dataType), output)
Expand Down Expand Up @@ -642,4 +641,48 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkInvalidPath("$[-1]")
checkInvalidPath("""$['"]""")
}

test("cast from variant") {
  // Only a handful of type combinations are exercised here, because the cast implementation is
  // mostly the same as variant_get.

  // Expects the cast to evaluate to `output` under every evaluation mode.
  def checkCast(input: Any, dataType: DataType, output: Any): Unit =
    Seq(EvalMode.LEGACY, EvalMode.ANSI, EvalMode.TRY).foreach { evalMode =>
      checkEvaluation(Cast(Literal(input), dataType, evalMode = evalMode), output)
    }

  // Expects the cast to fail with INVALID_VARIANT_CAST in LEGACY and ANSI modes, and to evaluate
  // to `tryOutput` in TRY mode.
  def checkInvalidCast(input: Any, dataType: DataType, tryOutput: Any): Unit = {
    // The ANSI flag has no effect on casting from variant, so both modes must throw.
    Seq(EvalMode.LEGACY, EvalMode.ANSI).foreach { evalMode =>
      checkExceptionInExpression[SparkRuntimeException](
        Cast(Literal(input), dataType, evalMode = evalMode),
        "INVALID_VARIANT_CAST"
      )
    }
    val tryCast = Cast(Literal(input), dataType, evalMode = EvalMode.TRY)
    checkEvaluation(tryCast, tryOutput)
  }

  // Scalar variants.
  checkCast(parseJson("1"), StringType, "1")
  // Unlike other to-string casts, which never yield NULL for a non-NULL input, a
  // variant-to-string cast yields NULL when the input is a variant null (not a SQL NULL).
  checkCast(parseJson("null"), StringType, null)
  checkCast(parseJson("\"1\""), IntegerType, 1)

  // Values that do not fit the target type.
  checkInvalidCast(parseJson("2147483648"), IntegerType, null)
  checkInvalidCast(parseJson("[2147483648, 1]"), ArrayType(IntegerType), Array(null, 1))

  // Arrays whose elements are variants.
  checkCast(Array(null, parseJson("true")), ArrayType(BooleanType), Array(null, true))
  checkCast(
    Array(null, parseJson("false"), parseJson("null")),
    ArrayType(StringType),
    Array(null, "false", null)
  )
  checkCast(Array(parseJson("[1]")), ArrayType(ArrayType(IntegerType)), Array(Array(1)))
  checkInvalidCast(
    Array(parseJson("\"hello\""), null, parseJson("\"1\"")),
    ArrayType(IntegerType),
    Array(null, null, 1)
  )
}
}

0 comments on commit 319edfd

Please sign in to comment.