From 2df494fd4e4e64b9357307fb0c5e8fc1b7491ac3 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 10 May 2024 14:03:08 +0800 Subject: [PATCH] [SPARK-48158][SQL] Add collation support for XML expressions ### What changes were proposed in this pull request? Introduce collation awareness for XML expressions: from_xml, schema_of_xml, to_xml. ### Why are the changes needed? Add collation support for XML expressions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for XML functions: from_xml, schema_of_xml, to_xml. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46507 from uros-db/xml-expressions. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../catalyst/expressions/xmlExpressions.scala | 9 +- .../sql/CollationSQLExpressionsSuite.scala | 124 ++++++++++++++++++ 2 files changed, 129 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 415d55d19ded2..237d740e04362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -140,7 
+141,7 @@ case class XmlToStructs( converter(parser.parse(str)) } - override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -178,7 +179,7 @@ case class SchemaOfXml( child = child, options = ExprUtils.convertToMapData(options)) - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false @@ -226,7 +227,7 @@ case class SchemaOfXml( .map(ArrayType(_, containsNull = at.containsNull)) .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) case other: DataType => - xmlInferSchema.canonicalizeType(other).getOrElse(StringType) + xmlInferSchema.canonicalizeType(other).getOrElse(SQLConf.get.defaultStringType) } UTF8String.fromString(dataType.sql) @@ -320,7 +321,7 @@ case class StructsToXml( getAndReset() } - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 2b6390151bb9b..dd5703d1284a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.text.SimpleDateFormat + import scala.collection.immutable.Seq import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException} @@ -860,6 +862,128 @@ class CollationSQLExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("Support XmlToStructs xml expression with 
collation") { + case class XmlToStructsTestCase( + input: String, + collationName: String, + schema: String, + options: String, + result: Row, + structFields: Seq[StructField] + ) + + val testCases = Seq( + XmlToStructsTestCase("<p><a>1</a></p>", "UTF8_BINARY", "'a INT'", "", + Row(1), Seq( + StructField("a", IntegerType, nullable = true) + )), + XmlToStructsTestCase("<p><A>true</A><B>0.8</B></p>", "UTF8_BINARY_LCASE", + "'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq( + StructField("A", BooleanType, nullable = true), + StructField("B", DoubleType, nullable = true) + )), + XmlToStructsTestCase("<p><s>Spark</s></p>", "UNICODE", "'s STRING'", "", + Row("Spark"), Seq( + StructField("s", StringType("UNICODE"), nullable = true) + )), + XmlToStructsTestCase("<p><time>26/08/2015</time></p>", "UNICODE_CI", "'time Timestamp'", + ", map('timestampFormat', 'dd/MM/yyyy')", Row( + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0") + ), Seq( + StructField("time", TimestampType, nullable = true) + )) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select from_xml('${t.input}', ${t.schema} ${t.options}) + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StructType(t.structFields) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support SchemaOfXml xml expression with collation") { + case class SchemaOfXmlTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + SchemaOfXmlTestCase("<p><a>1</a></p>", "UTF8_BINARY", "STRUCT<a: BIGINT>"), + SchemaOfXmlTestCase("<p><A>true</A><B>0.8</B></p>", "UTF8_BINARY_LCASE", + "STRUCT<A: BOOLEAN, B: DOUBLE>"), + SchemaOfXmlTestCase("<p></p>", "UNICODE", "STRUCT<>"), + SchemaOfXmlTestCase("<p><A>1</A><A>2</A><A>3</A></p>", "UNICODE_CI", + "STRUCT<A: ARRAY<BIGINT>>") + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select schema_of_xml('${t.input}') + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support StructsToXml xml expression with collation") { + case class StructsToXmlTestCase( + input: String, + collationName: String, + result: String + ) + + val testCases = Seq( + StructsToXmlTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY", + s"""<ROW> + | <a>1</a> + | <b>2</b> + |</ROW>""".stripMargin), + StructsToXmlTestCase("named_struct('A', true, 'B', 2.0)", "UTF8_BINARY_LCASE", + s"""<ROW> + | <A>true</A> + | <B>2.0</B> + |</ROW>""".stripMargin), + StructsToXmlTestCase("named_struct()", "UNICODE", + "<ROW/>"), + StructsToXmlTestCase("named_struct('time', to_timestamp('2015-08-26'))", "UNICODE_CI", + s"""<ROW> + | <time>2015-08-26T00:00:00.000-07:00</time> + |</ROW>""".stripMargin) + ) + + // Supported collations + testCases.foreach(t => { + val query = + s""" + |select to_xml(${t.input}) + |""".stripMargin + // Result + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + test("Support ParseJson & TryParseJson variant expressions with collation") { case class ParseJsonTestCase( input: String,