Skip to content

Commit

Permalink
fix: don't use tasks' divide method (deprecated soon)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Aug 18, 2024
1 parent 9fc3e2e commit fef4cdb
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 13 deletions.
3 changes: 1 addition & 2 deletions R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
#' * `NULL`: no validation
#' * `ratio`: only proportion `1 - ratio` of the task is used for training and `ratio` is used for validation.
#' * `"test"` means that the `"test"` task of a resampling is used and is not possible when calling `$train()` manually.
#' * `"predefined"`: This will use the predefined `$internal_valid_task` of a [`mlr3::Task`], which can e.g.
#' be created using the `$divide()` method of `Task`.
#' * `"predefined"`: This will use the predefined `$internal_valid_task` of a [`mlr3::Task`].
#'
#' This validation data can also be used for early stopping, see the description of the `Learner`'s parameters.
#'
Expand Down
6 changes: 3 additions & 3 deletions man/mlr_learners.torchvision.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions man/mlr_learners_torch.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test_CallbackSetHistory.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_that("Autotest", {
test_that("CallbackSetHistory works", {
cb = t_clbk("history")
task = tsk("iris")
task$divide(ids = 2)
task$internal_valid_task = task$clone(deep = TRUE)$filter(2)
task$filter(1)

learner = lrn("classif.mlp", epochs = 3, batch_size = 1, callbacks = t_clbk("history"), validate = "predefined")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_CallbackSetProgress.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ test_that("manual test", {
drop_last = FALSE, shuffle = TRUE, validate = "predefined"
)
task = tsk("iris")
task$divide(ids = 2)
task$internal_valid_task = task$clone(deep = TRUE)$filter(2)
task$filter(1)

# Because the validation is so short, it does not show in the example
Expand Down
4 changes: 1 addition & 3 deletions tests/testthat/test_LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,7 @@ test_that("train parameters do what they should: classification and regression",
)

# first we test everything with validation

task$divide(ratio = 2 / 3)
learner$validate = 0.3
learner$train(task)

internals = learner$model$callbacks$internals
Expand Down Expand Up @@ -287,7 +286,6 @@ test_that("train parameters do what they should: classification and regression",
expect_permutation(c("epoch", ids(measures_valid)), colnames(learner$model$callbacks$history$valid))

# now without validation
task$internal_valid_task = NULL
learner$validate = NULL

learner$state = NULL
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test_learner_torch_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ test_that("torch_network_predict works", {

test_that("Validation Task is respected", {
task = tsk("iris")
task$divide(ids = 1:10)
task$internal_valid_task = task$clone(deep = TRUE)$filter(1:10)
task$row_roles$use = 1:10

learner = lrn("classif.torch_featureless", epochs = 2, batch_size = 1, measures_train = msrs(c("classif.acc")),
callbacks = t_clbk("history"), validate = "predefined"
Expand Down

0 comments on commit fef4cdb

Please sign in to comment.