Skip to content

Commit

Permalink
[SPARK-48158][SQL] Add collation support for XML expressions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Introduce collation awareness for XML expressions: from_xml, schema_of_xml, to_xml.

### Why are the changes needed?
Add collation support for XML expressions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for XML functions: from_xml, schema_of_xml, to_xml.

### How was this patch tested?
E2e sql tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46507 from uros-db/xml-expressions.

Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
uros-db authored and cloud-fan committed May 10, 2024
1 parent 33cac44 commit 2df494f
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils._
import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -140,7 +141,7 @@ case class XmlToStructs(
converter(parser.parse(str))
}

override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil

override def sql: String = schema match {
case _: MapType => "entries"
Expand Down Expand Up @@ -178,7 +179,7 @@ case class SchemaOfXml(
child = child,
options = ExprUtils.convertToMapData(options))

override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType

override def nullable: Boolean = false

Expand Down Expand Up @@ -226,7 +227,7 @@ case class SchemaOfXml(
.map(ArrayType(_, containsNull = at.containsNull))
.getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull))
case other: DataType =>
xmlInferSchema.canonicalizeType(other).getOrElse(StringType)
xmlInferSchema.canonicalizeType(other).getOrElse(SQLConf.get.defaultStringType)
}

UTF8String.fromString(dataType.sql)
Expand Down Expand Up @@ -320,7 +321,7 @@ case class StructsToXml(
getAndReset()
}

override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql

import java.text.SimpleDateFormat

import scala.collection.immutable.Seq

import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkRuntimeException}
Expand Down Expand Up @@ -860,6 +862,128 @@ class CollationSQLExpressionsSuite
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support XmlToStructs xml expression with collation") {
case class XmlToStructsTestCase(
input: String,
collationName: String,
schema: String,
options: String,
result: Row,
structFields: Seq[StructField]
)

val testCases = Seq(
XmlToStructsTestCase("<p><a>1</a></p>", "UTF8_BINARY", "'a INT'", "",
Row(1), Seq(
StructField("a", IntegerType, nullable = true)
)),
XmlToStructsTestCase("<p><A>true</A><B>0.8</B></p>", "UTF8_BINARY_LCASE",
"'A BOOLEAN, B DOUBLE'", "", Row(true, 0.8), Seq(
StructField("A", BooleanType, nullable = true),
StructField("B", DoubleType, nullable = true)
)),
XmlToStructsTestCase("<p><s>Spark</s></p>", "UNICODE", "'s STRING'", "",
Row("Spark"), Seq(
StructField("s", StringType("UNICODE"), nullable = true)
)),
XmlToStructsTestCase("<p><time>26/08/2015</time></p>", "UNICODE_CI", "'time Timestamp'",
", map('timestampFormat', 'dd/MM/yyyy')", Row(
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.S").parse("2015-08-26 00:00:00.0")
), Seq(
StructField("time", TimestampType, nullable = true)
))
)

// Supported collations
testCases.foreach(t => {
val query =
s"""
|select from_xml('${t.input}', ${t.schema} ${t.options})
|""".stripMargin
// Result
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
val testQuery = sql(query)
checkAnswer(testQuery, Row(t.result))
val dataType = StructType(t.structFields)
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
}
})
}

test("Support SchemaOfXml xml expression with collation") {
case class SchemaOfXmlTestCase(
input: String,
collationName: String,
result: String
)

val testCases = Seq(
SchemaOfXmlTestCase("<p><a>1</a></p>", "UTF8_BINARY", "STRUCT<a: BIGINT>"),
SchemaOfXmlTestCase("<p><A>true</A><B>0.8</B></p>", "UTF8_BINARY_LCASE",
"STRUCT<A: BOOLEAN, B: DOUBLE>"),
SchemaOfXmlTestCase("<p></p>", "UNICODE", "STRUCT<>"),
SchemaOfXmlTestCase("<p><A>1</A><A>2</A><A>3</A></p>", "UNICODE_CI",
"STRUCT<A: ARRAY<BIGINT>>")
)

// Supported collations
testCases.foreach(t => {
val query =
s"""
|select schema_of_xml('${t.input}')
|""".stripMargin
// Result
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
val testQuery = sql(query)
checkAnswer(testQuery, Row(t.result))
val dataType = StringType(t.collationName)
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
}
})
}

test("Support StructsToXml xml expression with collation") {
case class StructsToXmlTestCase(
input: String,
collationName: String,
result: String
)

val testCases = Seq(
StructsToXmlTestCase("named_struct('a', 1, 'b', 2)", "UTF8_BINARY",
s"""<ROW>
| <a>1</a>
| <b>2</b>
|</ROW>""".stripMargin),
StructsToXmlTestCase("named_struct('A', true, 'B', 2.0)", "UTF8_BINARY_LCASE",
s"""<ROW>
| <A>true</A>
| <B>2.0</B>
|</ROW>""".stripMargin),
StructsToXmlTestCase("named_struct()", "UNICODE",
"<ROW/>"),
StructsToXmlTestCase("named_struct('time', to_timestamp('2015-08-26'))", "UNICODE_CI",
s"""<ROW>
| <time>2015-08-26T00:00:00.000-07:00</time>
|</ROW>""".stripMargin)
)

// Supported collations
testCases.foreach(t => {
val query =
s"""
|select to_xml(${t.input})
|""".stripMargin
// Result
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) {
val testQuery = sql(query)
checkAnswer(testQuery, Row(t.result))
val dataType = StringType(t.collationName)
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
}
})
}

test("Support ParseJson & TryParseJson variant expressions with collation") {
case class ParseJsonTestCase(
input: String,
Expand Down

0 comments on commit 2df494f

Please sign in to comment.