Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Oct 18, 2024
1 parent b039531 commit e97ea15
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tests/testthat/test_with_torch_settings.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,19 @@ test_that("with_torch_settings leaves global state untouched", {
expect_false(torch_equal(at, bt))
})

test_that("interop threads work", {
test_that("interop threads proper warning message", {
skip_if_not_installed("callr")
# otherwise capture.output does for some reason not capture the warning message
skip_if(!running_on_mac())

result = callr::r(function() {
library(torch)
with_torch_settings = getFromNamespace("with_torch_settings", "mlr3torch")
with_torch_settings(NULL, 1, 2, invisible(NULL))
x1 = tryCatch(with_torch_settings(NULL, 1, 2, invisible(NULL)), warning = identity)$message
x2 = tryCatch(with_torch_settings(NULL, 1, 1, invisible(NULL)), warning = identity)$message
list(x1, x2, torch_get_num_interop_threads())
x1 = capture.output(with_torch_settings(NULL, 1, 2, invisible(NULL)), warning = identity)
x2 = capture.output(with_torch_settings(NULL, 1, 1, invisible(NULL)), warning = identity)
list(x1, x2)
})
expect_true(length(result[[1]]) == 0)
expect_true(grepl("keeping the previous value 2", result[[2]], fixed = TRUE))
expect_equal(result[[3]], 2)
})

0 comments on commit e97ea15

Please sign in to comment.