Spatial resampling in the function 'PipeOpLearnerCV' #774

Open
marineReg opened this issue Jun 8, 2024 · 0 comments

@marineReg

Hello, I'm trying to modify the function 'PipeOpLearnerCV' to integrate a spatial resampling of type 'sptcv_cstf'. However, I'm getting an error message when I execute:

PipeOpLearnerCV_mod$new(learner = mlr3::lrn("classif.glmnet", predict_type = "prob"), param_vals = list(resampling.method = "sptcv_cstf", resampling.folds = 5))

Error in if (stratify) task$target_names else NULL : argument is of length zero
This happened PipeOp classif.glmnet's $train()

Here are my edits in the function:

########################################
## My edits
private$.crossval_param_set = ps(
  method = p_fct(levels = c("cv", "insample", "sptcv_cstf", "repeated_sptcv_cstf"), tags = c("train", "required")),
  folds = p_int(lower = 2L, upper = Inf, tags = c("train", "required")),
  repeats = p_int(lower = 1L, upper = Inf),
  keep_response = p_lgl(tags = c("train", "required"))
)
########################################

...
if (pv$method == "cv") rdesc$param_set$values = list(folds = pv$folds)
########################################
## My edits
if (pv$method == "sptcv_cstf") rdesc$param_set$values = list(folds = pv$folds)
if (pv$method == "repeated_sptcv_cstf") rdesc$param_set$values = list(folds = pv$folds, repeats = pv$repeats)
########################################

Thank you very much for your assistance.
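
A possible reading of the error, offered as a guess rather than a confirmed diagnosis: assigning a whole list to `param_set$values` replaces every stored value, so any value not named in the new list disappears. If the `sptcv_cstf` resampling stores a default `stratify` value (an assumption, not verified here), then `list(folds = pv$folds)` would drop it, and `if (stratify)` would see `NULL`. The sketch below reproduces that mechanism with `paradox` alone; the parameter set is a hypothetical stand-in, not the real resampling object.

```r
library(paradox)

# Hypothetical stand-in for a resampling's parameter set (names assumed).
param_set = ps(
  folds = p_int(lower = 1L),
  stratify = p_lgl()
)
param_set$values = list(folds = 10L, stratify = FALSE)

# Whole-list assignment: `stratify` is no longer stored afterwards.
param_set$values = list(folds = 5L)
stratify_dropped = is.null(param_set$values$stratify)

# `if (NULL)` is what raises "argument is of length zero".
err = tryCatch(
  if (param_set$values$stratify) "stratified" else NULL,
  error = function(e) conditionMessage(e)
)

# Updating a single entry leaves the other stored values untouched.
param_set$values = list(folds = 10L, stratify = FALSE)
param_set$values$folds = 5L
stratify_kept = identical(param_set$values$stratify, FALSE)
```

If this is indeed the cause, writing `rdesc$param_set$values$folds = pv$folds` in the modified branches, instead of assigning a new list, would be one way to keep the other stored values in place.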
