diff --git a/.github/workflows/dev_cron.yml b/.github/workflows/dev_cron.yml index f96e6ed11..73fe1bee8 100644 --- a/.github/workflows/dev_cron.yml +++ b/.github/workflows/dev_cron.yml @@ -28,7 +28,7 @@ on: permissions: issues: write contents: read - + jobs: process: name: Process diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala index 6a8d1051b..043b59f77 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark313/LinearRegression.scala @@ -455,7 +455,15 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String) private def trainWithNormal( dataset: Dataset[_], instr: Instrumentation): LinearRegressionModel = { - val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty) + val handlePersistence = (dataset.storageLevel == StorageLevel.NONE) + + if (handlePersistence) { + dataset.persist(StorageLevel.MEMORY_AND_DISK) + dataset.count() + } + + val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0) + && (!isDefined(weightCol) || getWeightCol.isEmpty)) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility( @@ -486,6 +494,9 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String) model.diagInvAtWA.toArray, model.objectiveHistory) + if (handlePersistence) { + dataset.unpersist() + } return lrModel.setSummary(Some(trainingSummary)) } else { // For low dimensional data, WeightedLeastSquares is more efficient since the diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala index a921dfbfc..8f2078d2f 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark333/LinearRegression.scala @@ -454,7 +454,14 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String private def trainWithNormal( dataset: Dataset[_], instr: Instrumentation): LinearRegressionModel = { - val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty) + val handlePersistence = (dataset.storageLevel == StorageLevel.NONE) + + if (handlePersistence) { + dataset.persist(StorageLevel.MEMORY_AND_DISK) + dataset.count() + } + val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0) + && (!isDefined(weightCol) || getWeightCol.isEmpty)) val sparkContext = dataset.sparkSession.sparkContext val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice) val isPlatformSupported = Utils.checkClusterPlatformCompatibility( @@ -485,6 +492,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) + if (handlePersistence) { + dataset.unpersist() + } return lrModel.setSummary(Some(trainingSummary)) } else { // For low dimensional data, WeightedLeastSquares is more efficient since the