[SPARK-48805][SQL][ML][SS][AVRO][EXAMPLES] Replace calls to bridged APIs based on `SparkSession#sqlContext` with `SparkSession` API

### What changes were proposed in this pull request?
In Spark's internal code, there are places where the bridged APIs based on `SparkSession#sqlContext` are still used even though a `SparkSession` instance is already available. This PR therefore simplifies those call sites as follows:

1. `SparkSession#sqlContext#read` -> `SparkSession#read`

```scala
  /**
   * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
   * `DataFrame`.
   * {{{
   *   sqlContext.read.parquet("/path/to/file.parquet")
   *   sqlContext.read.schema(schema).json("/path/to/file.json")
   * }}}
   *
   * @group genericdata
   * @since 1.4.0
   */
  def read: DataFrameReader = sparkSession.read
```
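For illustration, a call site changes as follows (a sketch assuming an existing `SparkSession` named `spark`; the Parquet path is hypothetical):

```scala
// Before: goes through the bridged SQLContext
val dfOld = spark.sqlContext.read.parquet("/path/to/file.parquet")
// After: reads directly through the SparkSession
val dfNew = spark.read.parquet("/path/to/file.parquet")
```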

2. `SparkSession#sqlContext#setConf` -> `SparkSession#conf#set`

```scala
  /**
   * Set the given Spark SQL configuration property.
   *
   * @group config
   * @since 1.0.0
   */
  def setConf(key: String, value: String): Unit = {
    sparkSession.conf.set(key, value)
  }
```
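For illustration (a sketch assuming a `SparkSession` named `spark`; `spark.sql.shuffle.partitions` is just an example key):

```scala
// Before: goes through the bridged SQLContext
spark.sqlContext.setConf("spark.sql.shuffle.partitions", "10")
// After: uses the session's RuntimeConfig directly
spark.conf.set("spark.sql.shuffle.partitions", "10")
```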

3. `SparkSession#sqlContext#getConf` -> `SparkSession#conf#get`

```scala
  /**
   * Return the value of Spark SQL configuration property for the given key.
   *
   * @group config
   * @since 1.0.0
   */
  def getConf(key: String): String = {
    sparkSession.conf.get(key)
  }
```
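For illustration (same assumptions as above, with an example key):

```scala
// Before: goes through the bridged SQLContext
val oldValue = spark.sqlContext.getConf("spark.sql.shuffle.partitions")
// After: uses the session's RuntimeConfig directly
val newValue = spark.conf.get("spark.sql.shuffle.partitions")
```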

4. `SparkSession#sqlContext#createDataFrame` -> `SparkSession#createDataFrame`

```scala
  /**
   * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples).
   *
   * @group dataframes
   * @since 1.3.0
   */
  def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
    sparkSession.createDataFrame(rdd)
  }
```
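For illustration (a sketch assuming a `SparkSession` named `spark`; the tuple data is hypothetical):

```scala
val rdd = spark.sparkContext.parallelize(Seq((1, "a"), (2, "b")))
// Before: goes through the bridged SQLContext
val dfOld = spark.sqlContext.createDataFrame(rdd)
// After: creates the DataFrame directly on the SparkSession
val dfNew = spark.createDataFrame(rdd)
```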

5. `SparkSession#sqlContext#sessionState` -> `SparkSession#sessionState`

```scala
private[sql] def sessionState: SessionState = sparkSession.sessionState
```

6. `SparkSession#sqlContext#sharedState` -> `SparkSession#sharedState`

```scala
private[sql] def sharedState: SharedState = sparkSession.sharedState
```
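Both accessors are `private[sql]`, so only Spark-internal call sites are affected. A sketch of the simplification (assuming a `SparkSession` named `session`; this only compiles inside the `org.apache.spark.sql` package, mirroring the diffs below):

```scala
// Before: goes through the bridged SQLContext
val sqlConfOld  = session.sqlContext.sessionState.conf
val cacheMgrOld = session.sqlContext.sharedState.cacheManager
// After: accesses the session state and shared state directly
val sqlConfNew  = session.sessionState.conf
val cacheMgrNew = session.sharedState.cacheManager
```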

7. `SparkSession#sqlContext#streams` -> `SparkSession#streams`

```scala
  /**
   * Returns a `StreamingQueryManager` that allows managing all the
   * [[org.apache.spark.sql.streaming.StreamingQuery StreamingQueries]] active on `this` context.
   *
   * @since 2.0.0
   */
  def streams: StreamingQueryManager = sparkSession.streams
```
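For illustration (a sketch assuming a `SparkSession` named `spark` with some active streaming queries):

```scala
// Before: goes through the bridged SQLContext
spark.sqlContext.streams.active.foreach(q => println(q.name))
// After: uses the session's StreamingQueryManager directly
spark.streams.active.foreach(q => println(q.name))
```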

8. `SparkSession#sqlContext#uncacheTable` -> `SparkSession#catalog#uncacheTable`

```scala
  /**
   * Removes the specified table from the in-memory cache.
   * @group cachemgmt
   * @since 1.3.0
   */
  def uncacheTable(tableName: String): Unit = {
    sparkSession.catalog.uncacheTable(tableName)
  }
```
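For illustration (a sketch assuming a `SparkSession` named `spark`; the temp view name is hypothetical):

```scala
spark.range(10).createOrReplaceTempView("tmp_view")
spark.catalog.cacheTable("tmp_view")
// Before: goes through the bridged SQLContext
spark.sqlContext.uncacheTable("tmp_view")
// After: uses the session's Catalog directly
spark.catalog.uncacheTable("tmp_view")
```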

### Why are the changes needed?
Reduce the nesting level of these API calls.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Pass GitHub Actions
- Manually checked `SparkHiveExample`

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47210 from LuciferYang/session.sqlContext.

Authored-by: yangjie01 <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
LuciferYang committed Jul 4, 2024
1 parent d5dc223 commit 54b7558
Showing 10 changed files with 23 additions and 23 deletions.
@@ -566,7 +566,7 @@ abstract class AvroSuite
dataFileWriter.flush()
dataFileWriter.close()

- val df = spark.sqlContext.read.format("avro").load(nativeWriterPath)
+ val df = spark.read.format("avro").load(nativeWriterPath)
assertResult(Row(field1, null, null, null))(df.selectExpr("field1.*").first())
assertResult(Row(null, field2, null, null))(df.selectExpr("field2.*").first())
assertResult(Row(null, null, field3, null))(df.selectExpr("field3.*").first())
@@ -575,7 +575,7 @@ abstract class AvroSuite

df.write.format("avro").option("avroSchema", schema.toString).save(sparkWriterPath)

- val df2 = spark.sqlContext.read.format("avro").load(nativeWriterPath)
+ val df2 = spark.read.format("avro").load(nativeWriterPath)
assertResult(Row(field1, null, null, null))(df2.selectExpr("field1.*").first())
assertResult(Row(null, field2, null, null))(df2.selectExpr("field2.*").first())
assertResult(Row(null, null, field3, null))(df2.selectExpr("field3.*").first())
@@ -134,8 +134,8 @@ object SparkHiveExample
// ... Order may vary, as spark processes the partitions in parallel.

// Turn on flag for Hive Dynamic Partitioning
spark.sqlContext.setConf("hive.exec.dynamic.partition", "true")
spark.sqlContext.setConf("hive.exec.dynamic.partition.mode", "nonstrict")
spark.conf.set("hive.exec.dynamic.partition", "true")
spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
// Create a Hive partitioned table using DataFrame API
df.write.partitionBy("key").format("hive").saveAsTable("hive_part_tbl")
// Partitioned column `key` will be moved to the end of the schema.
@@ -158,7 +158,7 @@ class LibSVMRelationSuite
StructField("labelFoo", DoubleType, false),
StructField("featuresBar", VectorType, false))
)
- val df = spark.sqlContext.createDataFrame(rawData, struct)
+ val df = spark.createDataFrame(rawData, struct)

val writePath = Utils.createTempDir().getPath

@@ -191,7 +191,7 @@ trait FlatMapGroupsWithStateExecBase

override def validateAndMaybeEvolveStateSchema(hadoopConf: Configuration): Unit = {
StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
- groupingAttributes.toStructType, stateManager.stateSchema, session.sqlContext.sessionState)
+ groupingAttributes.toStructType, stateManager.stateSchema, session.sessionState)
}

override protected def doExecute(): RDD[InternalRow] = {
@@ -215,14 +215,14 @@ trait FlatMapGroupsWithStateExecBase
if (hasInitialState) {
// If the user provided initial state we need to have the initial state and the
// data in the same partition so that we can still have just one commit at the end.
- val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+ val storeConf = new StateStoreConf(session.sessionState.conf)
val hadoopConfBroadcast = sparkContext.broadcast(
- new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ new SerializableConfiguration(session.sessionState.newHadoopConf()))
child.execute().stateStoreAwareZipPartitions(
initialState.execute(),
getStateInfo,
storeNames = Seq(),
- session.sqlContext.streams.stateStoreCoordinator) {
+ session.streams.stateStoreCoordinator) {
// The state store aware zip partitions will provide us with two iterators,
// child data iterator and the initial state iterator per partition.
case (partitionId, childDataIterator, initStateIterator) =>
@@ -246,8 +246,8 @@ trait FlatMapGroupsWithStateExecBase
groupingAttributes.toStructType,
stateManager.stateSchema,
NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType),
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator)
+ session.sessionState,
+ Some(session.streams.stateStoreCoordinator)
) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
val processor = createInputProcessor(store)
processDataWithPartition(singleIterator, store, processor)
@@ -353,14 +353,14 @@ case class TransformWithStateExec(
validateTimeMode()

if (hasInitialState) {
- val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+ val storeConf = new StateStoreConf(session.sessionState.conf)
val hadoopConfBroadcast = sparkContext.broadcast(
- new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ new SerializableConfiguration(session.sessionState.newHadoopConf()))
child.execute().stateStoreAwareZipPartitions(
initialState.execute(),
getStateInfo,
storeNames = Seq(),
- session.sqlContext.streams.stateStoreCoordinator) {
+ session.streams.stateStoreCoordinator) {
// The state store aware zip partitions will provide us with two iterators,
// child data iterator and the initial state iterator per partition.
case (partitionId, childDataIterator, initStateIterator) =>
@@ -393,8 +393,8 @@ case class TransformWithStateExec(
KEY_ROW_SCHEMA,
VALUE_ROW_SCHEMA,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA),
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator),
+ session.sessionState,
+ Some(session.streams.stateStoreCoordinator),
useColumnFamilies = true
) {
case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
@@ -404,7 +404,7 @@
// If the query is running in batch mode, we need to create a new StateStore and instantiate
// a temp directory on the executors in mapPartitionsWithIndex.
val hadoopConfBroadcast = sparkContext.broadcast(
- new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ new SerializableConfiguration(session.sessionState.newHadoopConf()))
child.execute().mapPartitionsWithIndex[InternalRow](
(i: Int, iter: Iterator[InternalRow]) => {
initNewStateStoreAndProcessData(i, hadoopConfBroadcast) { store =>
@@ -1473,7 +1473,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
assert(spark.catalog.isCached("view2"))

val oldView = spark.table("view2")
spark.sqlContext.uncacheTable("view1")
spark.catalog.uncacheTable("view1")
assert(storeAnalyzed ==
spark.sharedState.cacheManager.lookupCachedData(oldView).isDefined,
s"when storeAnalyzed = $storeAnalyzed")
@@ -1493,7 +1493,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
assert(spark.catalog.isCached("view2"))

val oldView = spark.table("view2")
- spark.sqlContext.uncacheTable(s"$db.view1")
+ spark.catalog.uncacheTable(s"$db.view1")
assert(storeAnalyzed ==
spark.sharedState.cacheManager.lookupCachedData(oldView).isDefined,
s"when storeAnalyzed = $storeAnalyzed")
@@ -777,7 +777,7 @@ class QueryExecutionErrorsSuite
checkError(
exception = intercept[SparkIllegalArgumentException] {
val row = spark.sparkContext.parallelize(Seq(1, 2)).map(Row(_))
- spark.sqlContext.createDataFrame(row, StructType.fromString("StructType()"))
+ spark.createDataFrame(row, StructType.fromString("StructType()"))
},
errorClass = "UNSUPPORTED_DATATYPE",
parameters = Map(
@@ -155,7 +155,7 @@ class HiveParquetMetastoreSuite extends ParquetPartitioningTest {
(1 to 10).map(i => Tuple1(Seq(Integer.valueOf(i), null))).toDF("a")
.createOrReplaceTempView("jt_array")

- assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true")
+ assert(spark.conf.get(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true")
}

override def afterAll(): Unit = {
@@ -166,7 +166,7 @@ class HiveUDFDynamicLoadSuite extends QueryTest with SQLTestUtils with TestHiveS

assert(Thread.currentThread().getContextClassLoader ne sparkClassLoader)
assert(Thread.currentThread().getContextClassLoader eq
- spark.sqlContext.sharedState.jarClassLoader)
+ spark.sharedState.jarClassLoader)

val udfExpr = udfInfo.fnCreateHiveUDFExpression()
// force initializing - this is what we do in HiveSessionCatalog
@@ -100,7 +100,7 @@ class PartitionedTablePerfStatsSuite
}

genericTest("partitioned pruned table reports only selected files") { spec =>
- assert(spark.sqlContext.getConf(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true")
+ assert(spark.conf.get(HiveUtils.CONVERT_METASTORE_PARQUET.key) == "true")
withTable("test") {
withTempDir { dir =>
spec.setupTable("test", dir)
