From e307341a377674f3c4aadbd9ad777e1478b5681a Mon Sep 17 00:00:00 2001 From: minmingzhu <45281494+minmingzhu@users.noreply.github.com> Date: Tue, 10 Sep 2024 10:58:34 +0800 Subject: [PATCH] [ML-381] Fix LR bugs (#385) * update spark to 3.3.3 Signed-off-by: minmingzhu * fix lr bugs * Update dev_cron.yml * update --------- Signed-off-by: minmingzhu --- .github/workflows/dev_cron.yml | 2 +- .../ml/regression/spark313/LinearRegression.scala | 13 ++++++++++++- .../ml/regression/spark333/LinearRegression.scala | 12 +++++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/.github/workflows/dev_cron.yml b/.github/workflows/dev_cron.yml index f96e6ed11..73fe1bee8 100644 --- a/.github/workflows/dev_cron.yml +++ b/.github/workflows/dev_cron.yml @@ -28,7 +28,7 @@ on: permissions: issues: write contents: read - + jobs: process: name: Process diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala index 6a8d1051b..043b59f77 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala @@ -455,7 +455,15 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String) private def trainWithNormal( dataset: Dataset[_], instr: Instrumentation): LinearRegressionModel = { - val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty) + val handlePersistence = (dataset.storageLevel == StorageLevel.NONE) + + if (handlePersistence) { + dataset.persist(StorageLevel.MEMORY_AND_DISK) + dataset.count() + } + + val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0) + && (!isDefined(weightCol) || getWeightCol.isEmpty)) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility( @@ -486,6 +494,9 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String) model.diagInvAtWA.toArray, model.objectiveHistory) + if (handlePersistence) { + dataset.unpersist() + } return lrModel.setSummary(Some(trainingSummary)) } else { // For low dimensional data, WeightedLeastSquares is more efficient since the diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala index a921dfbfc..8f2078d2f 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala @@ -454,7 +454,14 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String private def trainWithNormal( dataset: Dataset[_], instr: Instrumentation): LinearRegressionModel = { - val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty) + val handlePersistence = (dataset.storageLevel == StorageLevel.NONE) + + if (handlePersistence) { + dataset.persist(StorageLevel.MEMORY_AND_DISK) + dataset.count() + } + val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0) + && (!isDefined(weightCol) || getWeightCol.isEmpty)) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility( @@ -485,6 +492,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) + if (handlePersistence) { + dataset.unpersist() + } return lrModel.setSummary(Some(trainingSummary)) } else { // For low dimensional data, WeightedLeastSquares is more efficient since the