Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Oct 15, 2024
1 parent 848117c commit 238c8f4
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# mlr3torch dev

* feat: Add parameter `num_threads_interop` to `LearnerTorch`
* feat: Add parameter `num_interop_threads` to `LearnerTorch`

# mlr3torch 0.1.2

Expand Down
4 changes: 2 additions & 2 deletions R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ LearnerTorch = R6Class("LearnerTorch",
if (identical(param_vals$seed, "random")) param_vals$seed = sample.int(.Machine$integer.max, 1)

model = with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads,
num_threads_interop = param_vals$num_threads_interop, expr = {
num_interop_threads = param_vals$num_threads_interop, expr = {
learner_torch_train(self, private, super, task, param_vals)
})
model$task_col_info = copy(task$col_info[c(task$feature_names, task$target_names), c("id", "type", "levels")])
Expand All @@ -455,7 +455,7 @@ LearnerTorch = R6Class("LearnerTorch",
private$.verify_predict_task(task, param_vals)

with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads,
num_threads_interop = param_vals$num_threads_interop, expr = {
num_interop_threads = param_vals$num_threads_interop, expr = {
learner_torch_predict(self, private, super, task, param_vals)
})
},
Expand Down
2 changes: 1 addition & 1 deletion R/paramset_torchlearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ paramset_torchlearner = function(task_type) {
aggr = epochs_aggr, in_tune_fn = epochs_tune_fn, disable_in_tune = list(patience = 0)),
device = p_fct(tags = c("train", "predict", "required"), levels = mlr_reflections$torch$devices, init = "auto"),
num_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L),
num_threads_interop = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L),
num_interop_threads = p_int(lower = 1L, tags = c("train", "predict", "required", "threads"), init = 1L),
seed = p_int(tags = c("train", "predict", "required"), special_vals = list("random", NULL), init = "random"),
# evaluation
eval_freq = p_int(lower = 1L, tags = c("train", "required"), init = 1L),
Expand Down
2 changes: 1 addition & 1 deletion man-roxygen/paramset_torchlearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' * `num_threads` :: `integer(1)`\cr
#' The number of threads for intraop pararallelization (if `device` is `"cpu"`).
#' This value is initialized to 1.
#' * `num_threads_interop` :: `integer(1)`\cr
#' * `num_interop_threads` :: `integer(1)`\cr
#' The number of threads for intraop and interop pararallelization (if `device` is `"cpu"`).
#' This value is initialized to 1.
#' Note that this can only be set once during a session and changing the value within an R session will raise a warning.
Expand Down
2 changes: 1 addition & 1 deletion man/mlr_learners_torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_pipeops_torch_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 238c8f4

Please sign in to comment.