diff --git a/R/utils.R b/R/utils.R index 5231898d4..582de7d63 100644 --- a/R/utils.R +++ b/R/utils.R @@ -42,23 +42,19 @@ task_filter_ex = function(task, row_ids) { # Rbind duplicated rows to task if (length(dup_ids)) { + # First, get a data.table with all duplicated rows. new_data = task$data(rows = dup_ids, cols = cols) - # For column with role "group", create new groups for duplicates by adding a suffix. + # Second, if task has a column with role "group", create new groups for duplicate rows by adding a suffix to the group entry. if (!is.null(task$groups)) { group = NULL # for binding row_id = NULL # for binding - # We create a data.table with the corresponding group to each duplicated ID. - # We then change the group entry based on how often the ID occurs. - # Note that we make no assumptions on whether the whole group is sampled here. - # That has to be checked in the functions calling this. - # - # We assume that the rbinded rows are in the same positions as the original ids in dup_ids. - # This should generally be the case as long as the task does not have a col role group - # and task$data(..., ordered = FALSE) in task$rbind() above (default). - - grps = unique(task$groups) + # We create a data.table "new_groups" with the corresponding group to each duplicated ID. + # We then change the group entry based on how often the ID occurs. E.g. row_id = 1 occurs + # two times has the group entry "g". Then we rename the group entries to "g_1" and "g_2". + # If a group with a suffix (e.g. "_1") already exists, we add another suffix to it (i.e. "_1_1"). + grps = unique(task$groups$group) new_groups = task$groups[J(dup_ids), on = "row_id"][, group := { groups = character(0) i = 1 @@ -70,17 +66,18 @@ task_filter_ex = function(task, row_ids) { groups }, by = row_id] - # Generate data.table with rows for all newly added rows and updated group names + # Use "new_groups" to update the group entries. new_data[, (task$col_roles$group) := new_groups$group] } + # Lastly, new data is rbinded to the original task. task$rbind(new_data) - } - # Row ids can be anything, we just take what mlr3 happens to assign to filter the task. + # row_ids can be anything, we just take what mlr3 happens to assign to filter the task. row_ids[duplicated(row_ids)] = task$row_ids[newrows] + # Update row_ids, effectively filtering the task task$row_roles$use = row_ids task } @@ -107,7 +104,6 @@ curry = function(fn, ..., varname = "x") { } } - # 'and' operator for checkmate check_*-functions # example: # check_numeric(x) %check&&% check_true(all(x < 0)) @@ -121,7 +117,6 @@ curry = function(fn, ..., varname = "x") { TRUE } - # perform gsub on names of list # `...` are given to `gsub()` rename_list = function(x, ...) {