diff --git a/NEWS.md b/NEWS.md index efa177d8..758bcec3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,6 @@ # mlr3torch dev -* feat: Add parameter `num_threads_interop` to `LearnerTorch` +* feat: Add parameter `num_interop_threads` to `LearnerTorch` # mlr3torch 0.1.2 diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R index e0355af6..bcdd2326 100644 --- a/R/LearnerTorch.R +++ b/R/LearnerTorch.R @@ -434,7 +434,7 @@ LearnerTorch = R6Class("LearnerTorch", if (identical(param_vals$seed, "random")) param_vals$seed = sample.int(.Machine$integer.max, 1) model = with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads, - num_threads_interop = param_vals$num_threads_interop, expr = { + num_interop_threads = param_vals$num_threads_interop, expr = { learner_torch_train(self, private, super, task, param_vals) }) model$task_col_info = copy(task$col_info[c(task$feature_names, task$target_names), c("id", "type", "levels")]) @@ -455,7 +455,7 @@ LearnerTorch = R6Class("LearnerTorch", private$.verify_predict_task(task, param_vals) with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, - num_threads_interop = param_vals$num_threads_interop, expr = { + num_interop_threads = param_vals$num_threads_interop, expr = { learner_torch_predict(self, private, super, task, param_vals) }) }, diff --git a/R/paramset_torchlearner.R b/R/paramset_torchlearner.R index 2131a454..f340dc8a 100644 --- a/R/paramset_torchlearner.R +++ b/R/paramset_torchlearner.R @@ -54,7 +54,7 @@ paramset_torchlearner = function(task_type) { aggr = epochs_aggr, in_tune_fn = epochs_tune_fn, disable_in_tune = list(patience = 0)), device = p_fct(tags = c("train", "predict", "required"), levels = mlr_reflections$torch$devices, init = "auto"), num_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L), - num_threads_interop = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L), + num_interop_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L), seed = p_int(tags = c("train", "predict", "required"), special_vals = list("random", NULL), init = "random"), # evaluation eval_freq = p_int(lower = 1L, tags = c("train", "required"), init = 1L), diff --git a/man-roxygen/paramset_torchlearner.R b/man-roxygen/paramset_torchlearner.R index 470aff19..cf5d7505 100644 --- a/man-roxygen/paramset_torchlearner.R +++ b/man-roxygen/paramset_torchlearner.R @@ -14,7 +14,7 @@ #' * `num_threads` :: `integer(1)`\cr #' The number of threads for intraop pararallelization (if `device` is `"cpu"`). #' This value is initialized to 1. -#' * `num_threads_interop` :: `integer(1)`\cr +#' * `num_interop_threads` :: `integer(1)`\cr #' The number of threads for intraop and interop pararallelization (if `device` is `"cpu"`). #' This value is initialized to 1. #' Note that this can only be set once during a session and changing the value within an R session will raise a warning. diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 2cd2a454..ebdfe53d 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -70,7 +70,7 @@ fall back to \code{"cpu"}. \item \code{num_threads} :: \code{integer(1)}\cr The number of threads for intraop pararallelization (if \code{device} is \code{"cpu"}). This value is initialized to 1. -\item \code{num_threads_interop} :: \code{integer(1)}\cr +\item \code{num_interop_threads} :: \code{integer(1)}\cr The number of threads for intraop and interop pararallelization (if \code{device} is \code{"cpu"}). This value is initialized to 1. Note that this can only be set once during a session and changing the value within an R session will raise a warning. diff --git a/man/mlr_pipeops_torch_model.Rd b/man/mlr_pipeops_torch_model.Rd index eb3a5737..7cac5719 100644 --- a/man/mlr_pipeops_torch_model.Rd +++ b/man/mlr_pipeops_torch_model.Rd @@ -37,7 +37,7 @@ fall back to \code{"cpu"}. \item \code{num_threads} :: \code{integer(1)}\cr The number of threads for intraop pararallelization (if \code{device} is \code{"cpu"}). This value is initialized to 1. -\item \code{num_threads_interop} :: \code{integer(1)}\cr +\item \code{num_interop_threads} :: \code{integer(1)}\cr The number of threads for intraop and interop pararallelization (if \code{device} is \code{"cpu"}). This value is initialized to 1. Note that this can only be set once during a session and changing the value within an R session will raise a warning.