From 45b744b3db9dc103808fe52a2caad16a27def4f2 Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Wed, 14 Aug 2024 11:25:50 +0200 Subject: [PATCH 1/3] test reworked weights --- .github/workflows/r-cmd-check.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/r-cmd-check.yml b/.github/workflows/r-cmd-check.yml index d6135b1d1..4430efbde 100644 --- a/.github/workflows/r-cmd-check.yml +++ b/.github/workflows/r-cmd-check.yml @@ -26,6 +26,7 @@ jobs: config: - {os: ubuntu-latest, r: 'devel'} - {os: ubuntu-latest, r: 'release'} + - {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3@weights_reworked'} steps: - uses: actions/checkout@v4 From 2b8120ada977fb27923fd5a270971e7b44279482 Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Wed, 14 Aug 2024 11:55:35 +0200 Subject: [PATCH 2/3] fix actions --- .github/workflows/dev-cmd-check.yml | 1 + .github/workflows/r-cmd-check.yml | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dev-cmd-check.yml b/.github/workflows/dev-cmd-check.yml index 0e17ba4ed..c209d4982 100644 --- a/.github/workflows/dev-cmd-check.yml +++ b/.github/workflows/dev-cmd-check.yml @@ -25,6 +25,7 @@ jobs: matrix: config: - {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3'} + - {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3@weights_reworked'} steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/r-cmd-check.yml b/.github/workflows/r-cmd-check.yml index 4430efbde..d6135b1d1 100644 --- a/.github/workflows/r-cmd-check.yml +++ b/.github/workflows/r-cmd-check.yml @@ -26,7 +26,6 @@ jobs: config: - {os: ubuntu-latest, r: 'devel'} - {os: ubuntu-latest, r: 'release'} - - {os: ubuntu-latest, r: 'release', dev-package: 'mlr-org/mlr3@weights_reworked'} steps: - uses: actions/checkout@v4 From 39a62095dac4238028cd9e606c98dc62d3eba24b Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Fri, 16 Aug 2024 10:59:28 +0200 Subject: [PATCH 3/3] fix learners for new weights --- .lintr | 2 +- R/LearnerSurvCoxPH.R | 11 ++++++----- R/LearnerSurvRpart.R | 12 +++++++----- R/PipeOpTaskSurvRegr.R | 2 +- R/aaa.R | 2 +- inst/testthat/helper_autotest.R | 7 ++++++- tests/testthat/test_mlr_learners_surv_coxph.R | 6 +++--- tests/testthat/test_mlr_learners_surv_rpart.R | 1 - 8 files changed, 25 insertions(+), 18 deletions(-) diff --git a/.lintr b/.lintr index 5dbbc1451..8391db79f 100644 --- a/.lintr +++ b/.lintr @@ -5,5 +5,5 @@ linters: linters_with_defaults( object_name_linter = object_name_linter(c("snake_case", "CamelCase")), cyclocomp_linter = NULL, # do not check function complexity commented_code_linter = NULL, # allow code in comments - line_length_linter = line_length_linter(100L) + line_length_linter = line_length_linter(180L) ) diff --git a/R/LearnerSurvCoxPH.R b/R/LearnerSurvCoxPH.R index 95d83c88f..7bce45a07 100644 --- a/R/LearnerSurvCoxPH.R +++ b/R/LearnerSurvCoxPH.R @@ -22,7 +22,8 @@ LearnerSurvCoxPH = R6Class("LearnerSurvCoxPH", ties = p_fct(default = "efron", levels = c("efron", "breslow", "exact"), tags = "train"), singular.ok = p_lgl(default = TRUE, tags = "train"), type = p_fct(default = "efron", levels = c("efron", "aalen", "kalbfleisch-prentice"), tags = "predict"), - stype = p_int(1L, 2L, default = 2L, tags = "predict") + stype = p_int(1L, 2L, default = 2L, tags = "predict"), + use_weights = p_lgl(default = FALSE, tags = "train") ), predict_types = c("crank", "distr", "lp"), feature_types = c("logical", "integer", "numeric", "factor"), @@ -38,12 +39,12 @@ LearnerSurvCoxPH = R6Class("LearnerSurvCoxPH", .train = function(task) { pv = self$param_set$get_values(tags = "train") - if ("weights" %in% task$properties) { - pv$weights = task$weights$weight + if (isTRUE(pv$use_weights)) { + pv$weights = task$weights_learner$weight } + pv$use_weights = NULL - invoke(survival::coxph, formula = task$formula(), data = task$data(), - .args = pv, x = TRUE) + invoke(survival::coxph, formula = task$formula(), data = task$data(), .args = pv, x = TRUE) }, .predict = function(task) { diff --git a/R/LearnerSurvRpart.R b/R/LearnerSurvRpart.R index a9fff2ed8..1304b89fe 100644 --- a/R/LearnerSurvRpart.R +++ b/R/LearnerSurvRpart.R @@ -32,7 +32,8 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart", surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"), xval = p_int(0L, default = 10L, tags = "train"), cost = p_uty(tags = "train"), - keep_model = p_lgl(default = FALSE, tags = "train") + keep_model = p_lgl(default = FALSE, tags = "train"), + use_weights = p_lgl(default = FALSE, tags = "train") ) ps$set_values(xval = 0L) @@ -75,12 +76,13 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart", .train = function(task) { pv = self$param_set$get_values(tags = "train") names(pv) = replace(names(pv), names(pv) == "keep_model", "model") - if ("weights" %in% task$properties) { - pv = insert_named(pv, list(weights = task$weights$weight)) + + if (isTRUE(pv$use_weights)) { + pv$weights = task$weights_learner$weight } + pv$use_weights = NULL - invoke(rpart::rpart, formula = task$formula(), data = task$data(), - method = "exp", .args = pv) + invoke(rpart::rpart, formula = task$formula(), data = task$data(), method = "exp", .args = pv) }, .predict = function(task) { diff --git a/R/PipeOpTaskSurvRegr.R b/R/PipeOpTaskSurvRegr.R index 1ae86b66f..d15313376 100644 --- a/R/PipeOpTaskSurvRegr.R +++ b/R/PipeOpTaskSurvRegr.R @@ -195,7 +195,7 @@ PipeOpTaskSurvRegr = R6Class("PipeOpTaskSurvRegr", new_task = TaskRegr$new(id = input$id, backend = backend, target = target) if (method == "ipcw") { - new_task$col_roles$weight = "ipc_weights" + new_task$col_roles$weights_learner = "ipc_weights" } return(new_task) diff --git a/R/aaa.R b/R/aaa.R index 20925bec5..dce2ea8b7 100644 --- a/R/aaa.R +++ b/R/aaa.R @@ -50,7 +50,7 @@ register_reflections = function() { )), "type") x$task_col_roles$surv = x$task_col_roles$regr - x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weight", "stratum") + x$task_col_roles$dens = c("feature", "target", "label", "order", "group", "weights_learner", "stratum") x$task_col_roles$classif = unique(c(x$task_col_roles$classif, "original_ids")) # for discrete time x$task_properties$surv = x$task_properties$regr x$task_properties$dens = x$task_properties$regr diff --git a/inst/testthat/helper_autotest.R b/inst/testthat/helper_autotest.R index c96f215c6..f09fc876e 100644 --- a/inst/testthat/helper_autotest.R +++ b/inst/testthat/helper_autotest.R @@ -26,7 +26,6 @@ sanity_check.PredictionDens = function(prediction, ...) { # nolint registerS3method("sanity_check", "PredictionDens", sanity_check.PredictionDens) generate_tasks.LearnerSurv = function(learner, N = 20L, ...) { # nolint - real_time = round(1 + rexp(N, rate = 2) * 20, 1) cens_time = round(1 + rexp(N, rate = 3) * 20, 1) status = ifelse(real_time <= cens_time, 1L, 0L) @@ -49,6 +48,12 @@ generate_tasks.LearnerSurv = function(learner, N = 20L, ...) { # nolint tasks$sanity_reordered = tasks$sanity$clone(deep = TRUE) tasks$sanity_reordered$id = "sanity_reordered" + if ("weights" %in% learner$properties) { + tmp = mlr3proba::TaskSurv$new("weights", mlr3::as_data_backend(cbind(data, weight = runif(N)), time = "time", event = "event")) + tmp$set_col_roles("weight", "weights_learner") + tasks$weights = tmp + } + tasks } registerS3method("generate_tasks", "LearnerSurv", generate_tasks.LearnerSurv) diff --git a/tests/testthat/test_mlr_learners_surv_coxph.R b/tests/testthat/test_mlr_learners_surv_coxph.R index 0c2d9bcd0..f3aa944b9 100644 --- a/tests/testthat/test_mlr_learners_surv_coxph.R +++ b/tests/testthat/test_mlr_learners_surv_coxph.R @@ -10,10 +10,10 @@ test_that("autotest", { }) test_that("weights", { - learner = lrn("surv.coxph") + learner = lrn("surv.coxph", use_weights = TRUE) task = generate_tasks.LearnerSurv(learner)$weights - learner$train(task) - expect_equal(learner$model$weights, task$weights$weight) + suppressWarnings({learner$train(task)}) + expect_equal(learner$model$weights, task$weights_learner$weight) }) test_that("missing", { diff --git a/tests/testthat/test_mlr_learners_surv_rpart.R b/tests/testthat/test_mlr_learners_surv_rpart.R index 3b73a1627..b7d666993 100644 --- a/tests/testthat/test_mlr_learners_surv_rpart.R +++ b/tests/testthat/test_mlr_learners_surv_rpart.R @@ -9,7 +9,6 @@ test_that("autotest", { expect_true(result, info = result$error) }) - test_that("importance/selected", { learner = lrn("surv.rpart") expect_error(learner$importance(), "No model stored")