Skip to content

Commit

Permalink
[ML-381] Fix LR bugs (#385)
Browse files Browse the repository at this point in the history
* update spark to 3.3.3

Signed-off-by: minmingzhu <[email protected]>

* fix lr bugs

* Update dev_cron.yml

* update

---------

Signed-off-by: minmingzhu <[email protected]>
  • Loading branch information
minmingzhu authored Sep 10, 2024
1 parent 4dc4f0a commit e307341
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/dev_cron.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ on:
permissions:
issues: write
contents: read

jobs:
process:
name: Process
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,15 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String)
private def trainWithNormal(
dataset: Dataset[_],
instr: Instrumentation): LinearRegressionModel = {
val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty)
val handlePersistence = (dataset.storageLevel == StorageLevel.NONE)

if (handlePersistence) {
dataset.persist(StorageLevel.MEMORY_AND_DISK)
dataset.count()
}

val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0)
&& (!isDefined(weightCol) || getWeightCol.isEmpty))
val sparkContext = dataset.sparkSession.sparkContext
val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice)
val isPlatformSupported = Utils.checkClusterPlatformCompatibility(
Expand Down Expand Up @@ -486,6 +494,9 @@ class LinearRegression @Since("1.3") (@Since("1.3.0") override val uid: String)
model.diagInvAtWA.toArray,
model.objectiveHistory)

if (handlePersistence) {
dataset.unpersist()
}
return lrModel.setSummary(Some(trainingSummary))
} else {
// For low dimensional data, WeightedLeastSquares is more efficient since the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,14 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
private def trainWithNormal(
dataset: Dataset[_],
instr: Instrumentation): LinearRegressionModel = {
val paramSupported = ($(regParam) == 0) && (!isDefined(weightCol) || getWeightCol.isEmpty)
val handlePersistence = (dataset.storageLevel == StorageLevel.NONE)

if (handlePersistence) {
dataset.persist(StorageLevel.MEMORY_AND_DISK)
dataset.count()
}
val paramSupported = ($(regParam) == 0 || ($(regParam) != 0 && $(elasticNetParam) == 0)
&& (!isDefined(weightCol) || getWeightCol.isEmpty))
val sparkContext = dataset.sparkSession.sparkContext
val useDevice = sparkContext.getConf.get("spark.oap.mllib.device", Utils.DefaultComputeDevice)
val isPlatformSupported = Utils.checkClusterPlatformCompatibility(
Expand Down Expand Up @@ -485,6 +492,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model.diagInvAtWA.toArray,
model.objectiveHistory)

if (handlePersistence) {
dataset.unpersist()
}
return lrModel.setSummary(Some(trainingSummary))
} else {
// For low dimensional data, WeightedLeastSquares is more efficient since the
Expand Down

0 comments on commit e307341

Please sign in to comment.