diff --git a/R/CallbackSetHistory.R b/R/CallbackSetHistory.R index 87888e87..57b82f69 100644 --- a/R/CallbackSetHistory.R +++ b/R/CallbackSetHistory.R @@ -70,7 +70,6 @@ CallbackSetHistory = R6Class("CallbackSetHistory", stopf("No eligible measures to plot for set '%s'.", set) } - epoch = score = measure = NULL if (ncol(data) == 2L) { ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = !!rlang::sym(measures))) + @@ -84,7 +83,7 @@ CallbackSetHistory = R6Class("CallbackSetHistory", theme } else { data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score") - ggplot2::ggplot(data = data, ggplot2::aes_string(x = epoch, y = score, color = measure)) + + ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = score, color = measure)) + viridis::scale_color_viridis(discrete = TRUE) + ggplot2::geom_line() + ggplot2::geom_point() + diff --git a/inst/data_scripts/tiny_imagenet.R b/data-raw/tiny_imagenet.R similarity index 65% rename from inst/data_scripts/tiny_imagenet.R rename to data-raw/tiny_imagenet.R index 6d7a23f2..a4277463 100644 --- a/inst/data_scripts/tiny_imagenet.R +++ b/data-raw/tiny_imagenet.R @@ -2,4 +2,4 @@ devtools::load_all() dir = tempfile() ci = col_info(tsk("tiny_imagenet")$backend$backend) -saveRDS(ci, "./data/col_info/tiny_imagenet.rds") +saveRDS(ci, "./inst/col_info/tiny_imagenet.rds") diff --git a/inst/col_info/tiny_imagenet.rds b/inst/col_info/tiny_imagenet.rds index f2ef6093..e5d125c5 100644 Binary files a/inst/col_info/tiny_imagenet.rds and b/inst/col_info/tiny_imagenet.rds differ diff --git a/tests/testthat/test_TaskClassif_tiny_imagenet.R b/tests/testthat/test_TaskClassif_tiny_imagenet.R index 5ad5e609..679e6864 100644 --- a/tests/testthat/test_TaskClassif_tiny_imagenet.R +++ b/tests/testthat/test_TaskClassif_tiny_imagenet.R @@ -3,10 +3,12 @@ skip_on_cran() test_that("tiny_imagenet task works", { withr::local_options(mlr3torch.cache = TRUE) task = tsk("tiny_imagenet") + dt = task$data() + expect_true("tiny-imagenet-200" %in% list.files(file.path(get_cache_dir(), "datasets", "tiny_imagenet", "raw"))) expect_true("data.rds" %in% list.files(file.path(get_cache_dir(), "datasets", "tiny_imagenet"))) - dt = task$data() + expect_equal(task$backend$nrow, 120000) expect_equal(task$backend$ncol, 4) expect_data_table(dt, ncols = 2, nrows = 100000)