add pipeop IPCW draft #407

Merged Oct 8, 2024 (86 commits)
Changes from 20 commits
Commits
1b8aaae
add pipeop draft
studener Aug 2, 2024
720eb58
update tests
studener Aug 5, 2024
a512695
add basic docs
studener Aug 5, 2024
34e4798
add pipeline draft
studener Aug 6, 2024
c658b66
draft PipeOpPredClassifSurvIPCW
studener Aug 6, 2024
14af33a
update pipeline
studener Aug 6, 2024
15b8d16
updocs
studener Aug 6, 2024
29d7b9d
update tests
studener Aug 6, 2024
c21b8dc
update tests
studener Aug 6, 2024
daec59c
fix binding
studener Aug 6, 2024
1d5d56b
updocs
studener Aug 6, 2024
9dcfb4a
update PipeOpPredClassifSurv
studener Aug 6, 2024
43e3222
fix typo
studener Aug 8, 2024
dde3f72
refactor / add eps param to IPCW pipeop
studener Aug 10, 2024
8680143
remove time_var from features
studener Aug 12, 2024
c6d3a84
add correct time to surv prediction
studener Aug 12, 2024
c485290
updocs
studener Aug 13, 2024
bbb382c
fix typo
studener Aug 13, 2024
dd3f74e
update tests
studener Aug 13, 2024
902b7f0
correct row ids for surv prediction
studener Aug 13, 2024
908aed4
remove classif output option / updocs
studener Aug 26, 2024
bae712e
Merge branch 'main' into pipeop_ipcw
studener Sep 3, 2024
88b2065
updocs
studener Sep 3, 2024
2c980fb
update unloading test
bblodfon Sep 4, 2024
3e7f2b6
refactor + fixes
bblodfon Sep 6, 2024
860aef7
rename test file
bblodfon Sep 6, 2024
3b37b29
add IPCW pipeop test + update pipeline test
bblodfon Sep 6, 2024
4bd50fc
update docs
bblodfon Sep 6, 2024
2de8bcf
Merge branch 'main' into pipeop_ipcw
bblodfon Sep 6, 2024
e29329e
add examples
studener Sep 9, 2024
2ea204c
update pipeops
studener Sep 9, 2024
89f831f
add tests
studener Sep 12, 2024
4bbe7c7
Merge branch 'main' into pipeop_ipcw
bblodfon Sep 17, 2024
ce1b347
update doc (encapsulate method from mlr3 dev)
bblodfon Sep 17, 2024
24f38fe
improve doc and example
bblodfon Sep 17, 2024
792bd6e
refinements (doc and eps param)
bblodfon Sep 17, 2024
c476aed
refine IPCW test
bblodfon Sep 17, 2024
c9bc828
add comment
bblodfon Sep 18, 2024
d972623
fix rare bug in graf score when evulating a survival matrix with only…
bblodfon Sep 18, 2024
5ebda76
empty state => empty list, not NULL
bblodfon Sep 18, 2024
e23a186
update example
bblodfon Sep 18, 2024
854e4eb
update IPCW pipeline test
bblodfon Sep 18, 2024
919378e
update docs
bblodfon Sep 18, 2024
0108d2d
update to v0.6.9
bblodfon Sep 18, 2024
28790e2
ignore docs/
bblodfon Sep 20, 2024
bac562d
add fancy icon
bblodfon Sep 20, 2024
2816a62
temp fix of math rending issue
bblodfon Sep 20, 2024
fe7df89
update docs back to CRAN mlr3 version
bblodfon Sep 20, 2024
e9856a3
update example
bblodfon Sep 20, 2024
2ad288f
refactor: cutoff_time => tau
bblodfon Sep 20, 2024
4e44927
doc: IPCW surv predictions should be evaluated at tau only
bblodfon Sep 20, 2024
c41e192
fix test
bblodfon Sep 20, 2024
e12eb76
updocs
bblodfon Sep 20, 2024
5146c2b
refine doc (mlr3 style)
bblodfon Sep 21, 2024
30c80b4
correct doc about t_max in integrated scores
bblodfon Sep 21, 2024
3e1d8b1
add description for Rcpp function
bblodfon Sep 22, 2024
a6e073c
fix spelling
bblodfon Sep 23, 2024
d0c29ff
fix type
studener Sep 24, 2024
0694141
fix typo
studener Sep 26, 2024
f1b1f0e
move code comment
bblodfon Sep 27, 2024
6a049ac
fix: keep the response to the output prediction object in distrcompose
bblodfon Sep 27, 2024
39439f8
add checks on the `times` arg for the intergrated survival losses
bblodfon Sep 27, 2024
cf3bf30
refactor + add code comments in 'integrated_scores()'
bblodfon Sep 27, 2024
424aa9f
use distr6 C++ function to constantly interpolate S(t)
bblodfon Sep 27, 2024
5031b8b
small fix + refactoring: change the way `times` and `t_max` args are …
bblodfon Sep 27, 2024
d24c37b
compatibility with mlr3 0.21.0
bblodfon Sep 27, 2024
4940d73
test IBS with `times` argument more thoroughly
bblodfon Sep 27, 2024
4d53c73
update docs
bblodfon Sep 27, 2024
c78d7e6
Merge branch 'pipeop_ipcw' of https://github.com/mlr-org/mlr3proba in…
bblodfon Sep 27, 2024
4a55348
complete merging
bblodfon Sep 27, 2024
239784a
change comment doc to not produce a man entry
bblodfon Sep 27, 2024
d6b8670
change outside times range to a warning (more reasonable)
bblodfon Oct 7, 2024
46138b1
update IBS + times tests
bblodfon Oct 7, 2024
e557ae4
update doc files for IBS
bblodfon Oct 7, 2024
b8e8b6c
update doc for RNLL
bblodfon Oct 7, 2024
2697304
update doc for ISBS
bblodfon Oct 7, 2024
73a6748
update NEWs
bblodfon Oct 7, 2024
b4b2b2b
added new doc templates and refined some others
bblodfon Oct 7, 2024
8ad9eb1
refine docs for 3 integrated survival scores (ISBS, ISS, ISSL)
bblodfon Oct 7, 2024
52b400f
refine examples and test for IPCW (integrated = FALSE for a single ti…
bblodfon Oct 8, 2024
8413b0b
add IPCW pipeline alias
bblodfon Oct 8, 2024
c8d4b9f
add doc template for pipeline construction via ppl
bblodfon Oct 8, 2024
4f49f82
pipeline doc refactoring
bblodfon Oct 8, 2024
83349c9
assert classif learner in IPCW and disctime
bblodfon Oct 8, 2024
cf41f1c
add experimental badge for 3 pipelines (survtoregr, distrcompositor, …
bblodfon Oct 8, 2024
69e13df
update NEWs
bblodfon Oct 8, 2024
2 changes: 2 additions & 0 deletions DESCRIPTION
@@ -140,6 +140,7 @@ Collate:
'PipeOpCrankCompositor.R'
'PipeOpDistrCompositor.R'
'PipeOpPredClassifSurvDiscTime.R'
'PipeOpPredClassifSurvIPCW.R'
'PipeOpTransformer.R'
'PipeOpPredTransformer.R'
'PipeOpPredRegrSurv.R'
@@ -148,6 +149,7 @@ Collate:
'PipeOpSurvAvg.R'
'PipeOpTaskRegrSurv.R'
'PipeOpTaskSurvClassifDiscTime.R'
'PipeOpTaskSurvClassifIPCW.R'
'PipeOpTaskSurvRegr.R'
'PipeOpTaskTransformer.R'
'PredictionDataDens.R'
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -73,13 +73,15 @@ export(PipeOpBreslow)
export(PipeOpCrankCompositor)
export(PipeOpDistrCompositor)
export(PipeOpPredClassifSurvDiscTime)
export(PipeOpPredClassifSurvIPCW)
export(PipeOpPredRegrSurv)
export(PipeOpPredSurvRegr)
export(PipeOpPredTransformer)
export(PipeOpProbregr)
export(PipeOpSurvAvg)
export(PipeOpTaskRegrSurv)
export(PipeOpTaskSurvClassifDiscTime)
export(PipeOpTaskSurvClassifIPCW)
export(PipeOpTaskSurvRegr)
export(PipeOpTaskTransformer)
export(PipeOpTransformer)
@@ -97,6 +99,7 @@ export(as_task_surv)
export(assert_surv)
export(breslow)
export(pecs)
export(pipeline_survtoclassif_IPCW)
export(pipeline_survtoclassif_disctime)
export(pipeline_survtoregr)
export(plot_probregr)
71 changes: 71 additions & 0 deletions R/PipeOpPredClassifSurvIPCW.R
@@ -0,0 +1,71 @@
#' @title PipeOpPredClassifSurvIPCW
#' @name mlr_pipeops_trafopred_classifsurv_IPCW
#'
#' @description
#' Transform [PredictionClassif] to [PredictionSurv].
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpPredClassifSurvIPCW$new()
#' mlr_pipeops$get("trafopred_classifsurv_IPCW")
#' po("trafopred_classifsurv_IPCW")
#' ```
#'
#' @section Input and Output Channels:
#' The input is a [PredictionClassif] and a [data.table] containing observed times
#' and row ids both generated by [PipeOpTaskSurvClassifIPCW].
#' The output is the input [PredictionClassif] transformed to a [PredictionSurv].
#' Only works during prediction phase.
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredClassifSurvIPCW = R6Class(
"PipeOpPredClassifSurvIPCW",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param id (character(1))\cr
#' Identifier of the resulting object.
initialize = function(id = "trafopred_classifsurv_IPCW") {
super$initialize(
id = id,
input = data.table(
name = c("input", "data"),
train = c("NULL", "NULL"),
predict = c("PredictionClassif", "data.table")
),
output = data.table(
name = "output",
train = "NULL",
predict = "PredictionSurv"
)
)
}
),

private = list(
.predict = function(input) {
pred = input[[1]]
data = input[[2]]

p = PredictionSurv$new(row_ids = data$ids,
truth = Surv(time = data$times,
# convert the factor labels "0"/"1" to integers 0/1
# (as.integer() alone would return the level codes)
event = as.integer(as.character(pred$truth))),
# select the positive-class column by name rather than by position
crank = pred$prob[, "1"])
list(p)
},

.train = function(input) {
self$state = list()
list(input)
}
)
)

register_pipeop("trafopred_classifsurv_IPCW", PipeOpPredClassifSurvIPCW)
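
Note on the label and probability handling in `.predict` above: the following is a minimal base R sketch, not code from this PR. It assumes the binary target is stored as a factor with the positive class `"1"` listed first and that the probability matrix has its columns named after the class labels; both the data and that layout are hypothetical and only serve to illustrate why name-based access is the safer choice.

```r
# Hypothetical labels and probabilities, laid out as assumed above:
# positive class "1" first, negative class "0" second.
truth = factor(c("0", "1", "1", "0"), levels = c("1", "0"))
prob = matrix(c(0.2, 0.8,
                0.7, 0.3,
                0.6, 0.4,
                0.1, 0.9),
  ncol = 2, byrow = TRUE, dimnames = list(NULL, c("1", "0")))

# as.integer() on a factor returns level codes, not the labels
codes = as.integer(truth)

# going through character recovers the 0/1 event indicator
event = as.integer(as.character(truth))

# name-based selection does not depend on the column order
crank = prob[, "1"]
```

With this layout `codes` and `event` disagree, and `prob[, 2]` would pick the negative class, so converting via `as.character()` and indexing by `"1"` keeps the result independent of how the levels happen to be ordered.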
133 changes: 133 additions & 0 deletions R/PipeOpTaskSurvClassifIPCW.R
@@ -0,0 +1,133 @@
#' @title PipeOpTaskSurvClassifIPCW
#' @name mlr_pipeops_trafotask_survclassif_IPCW
#' @template param_pipelines
#'
#' @description
#' Transform [TaskSurv] to [TaskClassif][mlr3::TaskClassif] using IPCW (Vock et al., 2016).
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [dictionary][mlr3misc::Dictionary] [mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [mlr3pipelines::po()]:
#' ```
#' PipeOpTaskSurvClassifIPCW$new()
#' mlr_pipeops$get("trafotask_survclassif_IPCW")
#' po("trafotask_survclassif_IPCW")
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpTaskSurvClassifIPCW] has one input channel named "input", and two
#' output channels, one named "output" and the other "data".
#'
#' During training, the "output" is the "input" [TaskSurv] transformed to a
#' [TaskClassif][mlr3::TaskClassif].
#' The target column is named `"status"` and indicates whether an event occurred
#' before the cutoff time.
#' The transformed task now has the property "weights".
#' The "data" is NULL.
#'
#' During prediction, the "input" [TaskSurv] is transformed to the "output"
#' [TaskClassif][mlr3::TaskClassif] with `"status"` as target.
#' The "data" is a [data.table] containing the "time" of each subject as well
#' as corresponding "row_ids".
#' This "data" is only meant to be used with the [PipeOpPredClassifSurvIPCW].
#'
#' @section Parameters:
#' The parameters are
#'
#' * `cutoff_time :: numeric()`\cr
#' Cutoff time for IPCW. Observations with time larger than `cutoff_time` are censored.
#' * `eps :: numeric()`\cr
#' Small value to replace `0` survival probabilities with to prevent infinite weights.
#'
#' @references
#' `r format_bib("vock_2016")`
#'
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpTaskSurvClassifIPCW = R6Class(
"PipeOpTaskSurvClassifIPCW",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "trafotask_survclassif_IPCW") {
param_set = ps(
cutoff_time = p_dbl(lower = 0, special_vals = list()),
eps = p_dbl(lower = 0, default = 1e-6)
)
super$initialize(
id = id,
param_set = param_set,
input = data.table(
name = "input",
train = "TaskSurv",
predict = "TaskSurv"
),
output = data.table(
name = c("output", "data"),
train = c("TaskClassif", "NULL"),
predict = c("TaskClassif", "data.table")
)
)
}
),

private = list(
.predict = function(input) {
data = input[[1]]$data()
data$status = factor(data$status, levels = c("0", "1"))
task = TaskClassif$new(id = input[[1]]$id, backend = data,
target = "status", positive = "1")

time = data[[input[[1]]$target_names[1]]]
data = data.table(ids = input[[1]]$row_ids, times = time)
list(task, data)
},

.train = function(input) {
data = input[[1]]$data()
time_var = input[[1]]$target_names[1]
status_var = input[[1]]$target_names[2]

cutoff_time = self$param_set$values$cutoff_time
eps = self$param_set$values$eps

if (cutoff_time >= max(data[[time_var]])) {
stop("Cutoff time must be smaller than the maximum event time.")
}

# transform data and calculate weights
times = data[[time_var]]
times[times > cutoff_time] = cutoff_time

status = data[[status_var]]
status[times == cutoff_time] = 0

cens = survival::survfit(Surv(times, 1 - status) ~ 1)
cens$surv[length(cens$surv)] = cens$surv[length(cens$surv)-1]
cens$surv[cens$surv == 0] = eps

# look up each row's weight by its (capped) time, so rows need not be sorted by time
weights = 1 / cens$surv[match(times, cens$time)]

# add weights to original data
data[["ipc_weights"]] = weights
# status_var and time_var hold column names, so resolve them with get()
data[get(status_var) == 0 & get(time_var) < cutoff_time, "ipc_weights" := 0]
data[[status_var]] = factor(data[[status_var]], levels = c("0", "1"))
data[[time_var]] = NULL

# create new task
task = TaskClassif$new(id = paste0(input[[1]]$id, "_IPCW"), backend = data,
target = status_var, positive = "1")

task$set_col_roles("ipc_weights", roles = "weight")

self$state = list()
list(task, NULL)
}
)
)

register_pipeop("trafotask_survclassif_IPCW", PipeOpTaskSurvClassifIPCW)
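
Note on the weighting scheme used in `.train` above: the following is a minimal sketch of IPCW weights on toy data, using only the `survival` package. It is not this PR's implementation: the data, the variable names (`tau`, `G`, `w`) and the choice to evaluate the censoring curve at `min(time, tau)` are assumptions made for illustration.

```r
library(survival)

# Toy data (hypothetical): observed times and event indicators
# (1 = event, 0 = censored)
time = c(2, 3, 5, 7, 8, 10, 12)
status = c(1, 0, 1, 0, 1, 1, 0)
tau = 9      # cutoff time
eps = 1e-6   # floor for the censoring survival estimate

# Binary target: did the event occur by tau?
y = as.integer(status == 1 & time <= tau)

# Kaplan-Meier estimate of the censoring distribution G(t),
# obtained by treating censoring as the "event"
fit = survfit(Surv(time, 1 - status) ~ 1)
G = stepfun(fit$time, c(1, fit$surv))

# Evaluate G at min(time, tau) for each row, flooring at eps
# so that a zero estimate cannot produce an infinite weight
g = pmax(G(pmin(time, tau)), eps)
w = 1 / g

# Rows censored before tau have unknown status at tau: weight 0
w[status == 0 & time < tau] = 0

data.frame(time, status, y, w)
```

Because each row's weight is looked up from its own time, the result does not depend on the row order. Rows still under observation at `tau` keep a positive weight and count as non-events, whereas rows censored earlier drop out of the weighted classification problem.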
2 changes: 1 addition & 1 deletion R/PipeOpTaskSurvRegr.R
@@ -21,7 +21,7 @@
#'
#' * `method::character(1))`\cr
#' Method to use for dealing with censoring. Options are `"ipcw"` (Vock et al., 2016): censoring
- #' is column is removed and a `weights` column is added, weights are inverse estimated survival
+ #' column is removed and a `weights` column is added, weights are inverse estimated survival
#' probability of the censoring distribution evaluated at survival time;
#' `"mrl"` (Klein and Moeschberger, 2003): survival time of censored
#' observations is transformed to the observed time plus the mean residual life-time at the moment
56 changes: 56 additions & 0 deletions R/pipelines.R
@@ -610,10 +610,66 @@ pipeline_survtoclassif_disctime = function(learner, cut = NULL, max_time = NULL,
gr
}

#' @name mlr_graphs_survtoclassif_IPCW
#' @title Survival to Classification Reduction Pipeline using IPCW
#' @description Wrapper around multiple [PipeOp][mlr3pipelines::PipeOp]s to help in creation
#' of complex survival reduction methods.
#'
#' @param learner [LearnerClassif][mlr3::LearnerClassif]\cr
#' Classification learner to fit the transformed [TaskClassif][mlr3::TaskClassif].
#' @param cutoff_time `numeric()`\cr
#' Cutoff time for IPCW. Observations with time larger than `cutoff_time` are censored.
#' @param eps `numeric()`\cr
#' Small value to replace `0` survival probabilities with to prevent infinite weights.
#' @param output `character(1)`\cr
#' If `"classif"` (default), the pipeline returns the classification prediction;
#' any other value transforms the prediction to a [PredictionSurv] with a `crank`.
#' @param graph_learner `logical(1)`\cr
#' If `TRUE`, wraps the [Graph][mlr3pipelines::Graph] as a
#' [GraphLearner][mlr3pipelines::GraphLearner]; otherwise (default) returns a `Graph`.
#'
#' @details
#' The pipeline consists of the following steps:
#' \enumerate{
#' \item [PipeOpTaskSurvClassifIPCW] Converts [TaskSurv] to a [TaskClassif][mlr3::TaskClassif].
#' \item A [LearnerClassif] is fit and predicted on the new `TaskClassif`.
#' \item Optionally: [PipeOpPredClassifSurvIPCW] transforms the resulting [PredictionClassif][mlr3::PredictionClassif]
#' to [PredictionSurv].
#' }
#'
#' @return [mlr3pipelines::Graph] or [mlr3pipelines::GraphLearner]
#' @family pipelines
#'
#' @export
pipeline_survtoclassif_IPCW = function(learner, cutoff_time = NULL, eps = 1e-6, output = "classif", graph_learner = FALSE) {
assert_true("prob" %in% learner$predict_types)

gr = mlr3pipelines::Graph$new()
# pass eps through as well, otherwise the function argument is silently ignored
gr$add_pipeop(mlr3pipelines::po("trafotask_survclassif_IPCW", cutoff_time = cutoff_time, eps = eps))
gr$add_pipeop(mlr3pipelines::po("learner", learner, predict_type = "prob"))

gr$add_edge(src_id = "trafotask_survclassif_IPCW", dst_id = learner$id, src_channel = "output", dst_channel = "input")

if (output != "classif") {
gr$add_pipeop(mlr3pipelines::po("trafopred_classifsurv_IPCW"))
gr$add_pipeop(mlr3pipelines::po("nop"))

gr$add_edge(src_id = learner$id, dst_id = "trafopred_classifsurv_IPCW", src_channel = "output", dst_channel = "input")
gr$add_edge(src_id = "trafotask_survclassif_IPCW", dst_id = "nop", src_channel = "data", dst_channel = "input")
gr$add_edge(src_id = "nop", dst_id = "trafopred_classifsurv_IPCW", src_channel = "output", dst_channel = "data")
}

if (graph_learner) {
gr = mlr3pipelines::GraphLearner$new(gr)
}

gr
}

register_graph("survaverager", pipeline_survaverager)
register_graph("survbagging", pipeline_survbagging)
register_graph("crankcompositor", pipeline_crankcompositor)
register_graph("distrcompositor", pipeline_distrcompositor)
register_graph("probregr", pipeline_probregr)
register_graph("survtoregr", pipeline_survtoregr)
register_graph("survtoclassif_disctime", pipeline_survtoclassif_disctime)
register_graph("survtoclassif_IPCW", pipeline_survtoclassif_IPCW)
2 changes: 2 additions & 0 deletions man/PipeOpPredTransformer.Rd
2 changes: 2 additions & 0 deletions man/PipeOpTaskTransformer.Rd
2 changes: 2 additions & 0 deletions man/PipeOpTransformer.Rd
1 change: 1 addition & 0 deletions man/mlr_graphs_crankcompositor.Rd
1 change: 1 addition & 0 deletions man/mlr_graphs_distrcompositor.Rd
1 change: 1 addition & 0 deletions man/mlr_graphs_probregr.Rd
1 change: 1 addition & 0 deletions man/mlr_graphs_survaverager.Rd