Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to impute missing values during training #336

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# development version

- New option to impute missing data after the train/test split rather than before (#301, @megancoden and @shah-priyal).
- Added `impute_after_split` option to `run_ml()`, which defaults to `FALSE`.
- Added `impute_in_preprocessing` option to `preprocess()`, which defaults to TRUE.

# mikropml 1.6.0

- New functions:
Expand All @@ -13,7 +17,6 @@
- Renamed the column `names` to `feat` to represent each feature or group of correlated features.
- New column `lower` and `upper` to report the bounds of the empirical 95% confidence interval from the permutation test.
See `vignette('parallel')` for an example of plotting feature importance with confidence intervals.
- Minor documentation improvements (#323, #332, @kelly-sovacool).

# mikropml 1.5.0

Expand Down
37 changes: 37 additions & 0 deletions R/impute.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#' Replace NA values with the median value of the column for continuous
#' variables in the dataset
#'
#' Only columns whose class is `integer` or `numeric` are imputed; all other
#' columns (e.g. a character outcome column or categorical features) are
#' returned unchanged, and the input's class (data.frame or tibble) is
#' preserved.
#'
#' @param transformed_cont Data frame that may include NA values in one or
#'   more columns
#'
#' @return Data frame with no NA values in continuous numeric columns
#' @noRd
#'
#' @examples
#' dat <- data.frame(a = c(1, NA, 3), b = c("x", "y", "z"))
#' impute(dat)
impute <- function(transformed_cont) {
  # identify continuous columns by their (first) class
  col_class <- vapply(
    transformed_cont,
    function(x) class(x)[1],
    character(1)
  )
  is_cont <- col_class %in% c("integer", "numeric")
  n_missing <- sum(is.na(transformed_cont[, is_cont, drop = FALSE]))
  if (n_missing > 0) {
    # Replace NAs with each column's median. Only the continuous columns are
    # modified, so character/factor columns are never coerced (sapply() over
    # the whole data frame would collapse mixed columns to a character
    # matrix), and the input's data.frame/tibble class is preserved.
    transformed_cont[is_cont] <- lapply(
      transformed_cont[is_cont],
      function(x) {
        x[is.na(x)] <- stats::median(x, na.rm = TRUE)
        x
      }
    )
    message(
      paste0(
        n_missing,
        " missing continuous value(s) were imputed using the median value of the feature."
      )
    )
  }
  return(transformed_cont)
}
68 changes: 18 additions & 50 deletions R/preprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ preprocess_data <- function(dataset, outcome_colname,
method = c("center", "scale"),
remove_var = "nzv", collapse_corr_feats = TRUE,
to_numeric = TRUE, group_neg_corr = TRUE,
prefilter_threshold = 1) {
prefilter_threshold = 1, impute_in_preprocessing = TRUE) {
progbar <- NULL
if (isTRUE(check_packages_installed("progressr"))) {
progbar <- progressr::progressor(steps = 20, message = "preprocessing")
Expand All @@ -70,10 +70,11 @@ preprocess_data <- function(dataset, outcome_colname,
check_outcome_column(dataset, outcome_colname, check_values = FALSE)
check_remove_var(remove_var)
pbtick(progbar)

dataset[[outcome_colname]] <- replace_spaces(dataset[[outcome_colname]])
dataset <- rm_missing_outcome(dataset, outcome_colname)
split_dat <- split_outcome_features(dataset, outcome_colname)

features <- split_dat$features
removed_feats <- character(0)
if (to_numeric) {
Expand All @@ -83,22 +84,22 @@ preprocess_data <- function(dataset, outcome_colname,
features <- feats$dat
}
pbtick(progbar)

nv_feats <- process_novar_feats(features, progbar = progbar)
pbtick(progbar)
split_feats <- process_cat_feats(nv_feats$var_feats, progbar = progbar)
pbtick(progbar)
cont_feats <- process_cont_feats(split_feats$cont_feats, method)
cont_feats <- process_cont_feats(split_feats$cont_feats, method, impute_in_preprocessing)
pbtick(progbar)

# combine all processed features
processed_feats <- dplyr::bind_cols(
cont_feats$transformed_cont,
split_feats$cat_feats,
nv_feats$novar_feats
)
pbtick(progbar)

# remove features with (near-)zero variance
feats <- get_caret_processed_df(processed_feats, remove_var)
processed_feats <- feats$processed
Expand Down Expand Up @@ -140,17 +141,15 @@ preprocess_data <- function(dataset, outcome_colname,
#' @inheritParams run_ml
#'
#' @return dataset with no missing outcomes
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' rm_missing_outcome(mikropml::otu_mini_bin, "dx")
#'
#' test_df <- mikropml::otu_mini_bin
#' test_df[1:100, "dx"] <- NA
#' rm_missing_outcome(test_df, "dx")
#' }
rm_missing_outcome <- function(dataset, outcome_colname) {
n_outcome_na <- sum(is.na(dataset %>% dplyr::pull(outcome_colname)))
total_outcomes <- nrow(dataset)
Expand All @@ -168,13 +167,11 @@ rm_missing_outcome <- function(dataset, outcome_colname) {
#' @param features dataframe of features for machine learning
#'
#' @return dataframe with numeric columns where possible
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' class(change_to_num(data.frame(val = c("1", "2", "3")))[[1]])
#' }
change_to_num <- function(features) {
lapply_fn <- select_apply(fun = "lapply")
check_features(features, check_missing = FALSE)
Expand Down Expand Up @@ -228,13 +225,11 @@ remove_singleton_columns <- function(dat, threshold = 1) {
#' @param progbar optional progress bar (default: `NULL`)
#'
#' @return list of two dataframes: features with variability (unprocessed) and without (processed)
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' process_novar_feats(mikropml::otu_small[, 2:ncol(otu_small)])
#' }
process_novar_feats <- function(features, progbar = NULL) {
novar_feats <- NULL
var_feats <- NULL
Expand Down Expand Up @@ -297,13 +292,11 @@ process_novar_feats <- function(features, progbar = NULL) {
#' @inheritParams process_novar_feats
#'
#' @return list of two dataframes: categorical (processed) and continuous features (unprocessed)
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' process_cat_feats(mikropml::otu_small[, 2:ncol(otu_small)])
#' }
process_cat_feats <- function(features, progbar = NULL) {
feature_design_cat_mat <- NULL
cont_feats <- NULL
Expand Down Expand Up @@ -367,14 +360,12 @@ process_cat_feats <- function(features, progbar = NULL) {
#' @inheritParams get_caret_processed_df
#'
#' @return dataframe of preprocessed features
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' process_cont_feats(mikropml::otu_small[, 2:ncol(otu_small)], c("center", "scale"))
#' }
process_cont_feats <- function(features, method) {
process_cont_feats <- function(features, method, impute_in_preprocessing) {
transformed_cont <- NULL
removed_cont <- NULL

Expand All @@ -389,31 +380,12 @@ process_cont_feats <- function(features, method) {
transformed_cont <- feats$processed
removed_cont <- feats$removed
}
sapply_fn <- select_apply("sapply")
cl <- sapply_fn(transformed_cont, function(x) {
class(x)
})
missing <-
is.na(transformed_cont[, cl %in% c("integer", "numeric")])
n_missing <- sum(missing)
if (n_missing > 0) {
# impute missing data using the median value
transformed_cont <- sapply_fn(transformed_cont, function(x) {
if (class(x) %in% c("integer", "numeric")) {
m <- is.na(x)
x[m] <- stats::median(x, na.rm = TRUE)
}
return(x)
}) %>% dplyr::as_tibble()
message(
paste0(
n_missing,
" missing continuous value(s) were imputed using the median value of the feature."
)
)
if (impute_in_preprocessing) {
transformed_cont <- impute(transformed_cont)
}
}
}
}
return(list(transformed_cont = transformed_cont, removed_cont = removed_cont))
}

Expand Down Expand Up @@ -450,11 +422,10 @@ get_caret_processed_df <- function(features, method) {
#' @inheritParams process_novar_feats
#' @param full_rank whether matrix should be full rank or not (see `[caret::dummyVars])
#' @return design matrix
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' df <- data.frame(
#' outcome = c("normal", "normal", "cancer"),
#' var1 = 1:3,
Expand All @@ -463,7 +434,6 @@ get_caret_processed_df <- function(features, method) {
#' var4 = c(0, 1, 0)
#' )
#' get_caret_dummyvars_df(df, TRUE)
#' }
get_caret_dummyvars_df <- function(features, full_rank = FALSE, progbar = NULL) {
check_features(features, check_missing = FALSE)
if (!is.null(process_novar_feats(features, progbar = progbar)$novar_feats)) {
Expand All @@ -481,13 +451,11 @@ get_caret_dummyvars_df <- function(features, full_rank = FALSE, progbar = NULL)
#' @inheritParams group_correlated_features
#'
#' @return features where perfectly correlated ones are collapsed
#' @keywords internal
#' @noRd
#' @author Zena Lapp, \email{zenalapp@@umich.edu}
#'
#' @examples
#' \dontrun{
#' collapse_correlated_features(mikropml::otu_small[, 2:ncol(otu_small)])
#' }
collapse_correlated_features <- function(features, group_neg_corr = TRUE, progbar = NULL) {
feats_nocorr <- features
grp_feats <- NULL
Expand Down
43 changes: 24 additions & 19 deletions R/run_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ run_ml <-
group_partitions = NULL,
corr_thresh = 1,
seed = NA,
impute_after_split = FALSE,
...) {
check_all(
dataset,
Expand All @@ -162,7 +163,7 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}

# `future.apply` is required for `find_feature_importance()`.
# check it here to adhere to the fail fast principle.
if (find_feature_importance) {
Expand All @@ -173,20 +174,20 @@ run_ml <-
if (find_feature_importance) {
check_cat_feats(dataset %>% dplyr::select(-outcome_colname))
}

dataset <- dataset %>%
randomize_feature_order(outcome_colname) %>%
# convert tibble to dataframe to silence warning from caret::train():
# "Warning: Setting row names on a tibble is deprecated.."
as.data.frame()

outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname)

if (length(training_frac) == 1) {
training_inds <- get_partition_indices(outcomes_vctr,
training_frac = training_frac,
groups = groups,
group_partitions = group_partitions
training_frac = training_frac,
groups = groups,
group_partitions = group_partitions
)
} else {
training_inds <- training_frac
Expand All @@ -201,30 +202,34 @@ run_ml <-
}
check_training_frac(training_frac)
check_training_indices(training_inds, dataset)

train_data <- dataset[training_inds, ]
test_data <- dataset[-training_inds, ]
if (impute_after_split == TRUE) {
train_data <- impute(train_data)
test_data <- impute(test_data)
}
# train_groups & test_groups will be NULL if groups is NULL
train_groups <- groups[training_inds]
test_groups <- groups[-training_inds]

if (is.null(hyperparameters)) {
hyperparameters <- get_hyperparams_list(dataset, method)
}
tune_grid <- get_tuning_grid(hyperparameters, method)


outcome_type <- get_outcome_type(outcomes_vctr)
class_probs <- outcome_type != "continuous"

if (is.null(perf_metric_function)) {
perf_metric_function <- get_perf_metric_fn(outcome_type)
}

if (is.null(perf_metric_name)) {
perf_metric_name <- get_perf_metric_name(outcome_type)
}

if (is.null(cross_val)) {
cross_val <- define_cv(
train_data,
Expand All @@ -238,8 +243,8 @@ run_ml <-
group_partitions = group_partitions
)
}


message("Training the model...")
trained_model_caret <- train_model(
train_data = train_data,
Expand All @@ -254,7 +259,7 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}

if (calculate_performance) {
performance_tbl <- get_performance_tbl(
trained_model_caret,
Expand All @@ -269,7 +274,7 @@ run_ml <-
} else {
performance_tbl <- "Skipped calculating performance"
}

if (find_feature_importance) {
message("Finding feature importance...")
feature_importance_tbl <- get_feature_importance(
Expand All @@ -287,7 +292,7 @@ run_ml <-
} else {
feature_importance_tbl <- "Skipped feature importance"
}

return(
list(
trained_model = trained_model_caret,
Expand Down
Loading