From 2380913c38cb4660ebf32e607954eb92d56ee646 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Tue, 10 Sep 2024 18:39:43 +0200 Subject: [PATCH 01/35] TODO: write tests --- DESCRIPTION | 10 +++- R/CallbackSetTFLog.R | 70 ++++++++++++++++++++++++++ tests/testthat/test_CallbackSetTFLog.R | 8 +++ 3 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 R/CallbackSetTFLog.R create mode 100644 tests/testthat/test_CallbackSetTFLog.R diff --git a/DESCRIPTION b/DESCRIPTION index 18aa8517..a9231e48 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,14 @@ Authors@R: family = "Pfisterer", role = "ctb", email = "pfistererf@googlemail.com", - comment = c(ORCID = "0000-0001-8867-762X"))) + comment = c(ORCID = "0000-0001-8867-762X")), + person(given = "Carson", + family = "Zhang", + role = "ctb", + email = "carsonzhang4@gmail.com") + ), + + Description: Deep Learning library that extends the mlr3 framework by building upon the 'torch' package. It allows to conveniently build, train, and evaluate deep learning models without having to worry about low level @@ -64,6 +71,7 @@ Suggests: viridis, visNetwork, testthat (>= 3.0.0), + tfevents, torchvision (>= 0.6.0), waldo Config/testthat/edition: 3 diff --git a/R/CallbackSetTFLog.R b/R/CallbackSetTFLog.R new file mode 100644 index 00000000..d3005105 --- /dev/null +++ b/R/CallbackSetTFLog.R @@ -0,0 +1,70 @@ +#' @title TensorFlow Logging Callback +#' +#' @name mlr_callback_set.tflog +#' +#' @description +#' Logs the training and validation measures for tracking via TensorBoard. +#' @details +#' TODO: add +#' +#' @param path (`character(1)`)\cr +#' The path to a folder where the events are logged. +#' Point TensorBoard to this folder to view them. +#' @family Callback +#' @export +#' @include CallbackSet.R +CallbackSetTFLog = R6Class("CallbackSetTFLog", + inherit = CallbackSet, + lock_objects = TRUE, + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + initialize = function(path = get_default_logdir()) { + self$path = assert_path_for_output(path) + set_default_logdir(path) + }, + #' @description + #' Logs the training measures as TensorFlow events. + #' Meaningful changes happen at the end of each batch, + #' since this is when the gradient step occurs. + on_batch_end = function() { + log_train_score = function(measure_name) { + train_score = list(self$ctx$last_scores_train[[measure_name]]) + names(train_score) = paste0("train.", measure_name) + do.call(log_event, train_score) + } + + if (length(self$ctx$last_scores_train)) { + map(names(self$ctx$measures_train), log_train_score) + } + }, + #' @description + #' Logs the validation measures as TensorFlow events. + #' Meaningful changes happen at the end of each epoch. + #' Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, + #' and therefore differences are due to randomness + on_epoch_end = function() { + log_valid_score = function(measure_name) { + valid_score = list(self$ctx$last_scores_valid[[measure_name]]) + names(valid_score) = paste0("valid.", measure_name) + do.call(log_event, valid_score) + } + + if (length(self$ctx$last_scores_valid)) { + map(names(self$ctx$measure_valid), log_valid_score) + } + } + ) +) + +mlr3torch_callbacks$add("tflog", function() { + TorchCallback$new( + callback_generator = CallbackSetCheckpoint, + param_set = ps( + path = p_uty(tags = c("train", "required")) + ), + id = "tflog", + label = "TFLog", + man = "mlr3torch::mlr_callback_set.tflog" + ) +}) diff --git a/tests/testthat/test_CallbackSetTFLog.R b/tests/testthat/test_CallbackSetTFLog.R new file mode 100644 index 00000000..63ac409a --- /dev/null +++ b/tests/testthat/test_CallbackSetTFLog.R @@ -0,0 +1,8 @@ +test_that("autotest", { + cb = t_clbk() + expect_torch_callback(cb) +}) + +test_that("", { + +}) \ No newline at end of file From 86f87c805d0155f2c61ac55db57fd9560c36bf0e Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Sun, 22 Sep 2024 19:43:09 +0200 Subject: [PATCH 02/35] name -> TB. began refactoring based on last meeting with Sebastian --- R/{CallbackSetTFLog.R => CallbackSetTB.R} | 70 +++++++++++++++-------- tests/testthat/test_CallbackSetTB.R | 21 +++++++ tests/testthat/test_CallbackSetTFLog.R | 8 --- 3 files changed, 67 insertions(+), 32 deletions(-) rename R/{CallbackSetTFLog.R => CallbackSetTB.R} (52%) create mode 100644 tests/testthat/test_CallbackSetTB.R delete mode 100644 tests/testthat/test_CallbackSetTFLog.R diff --git a/R/CallbackSetTFLog.R b/R/CallbackSetTB.R similarity index 52% rename from R/CallbackSetTFLog.R rename to R/CallbackSetTB.R index d3005105..c699a97c 100644 --- a/R/CallbackSetTFLog.R +++ b/R/CallbackSetTB.R @@ -1,6 +1,6 @@ -#' @title TensorFlow Logging Callback +#' @title TensorBoard Logging Callback #' -#' @name mlr_callback_set.tflog +#' @name mlr_callback_set.tb #' #' @description #' Logs the training and validation measures for tracking via TensorBoard. @@ -13,41 +13,57 @@ #' @family Callback #' @export #' @include CallbackSet.R -CallbackSetTFLog = R6Class("CallbackSetTFLog", +CallbackSetTB = R6Class("CallbackSetTB", inherit = CallbackSet, lock_objects = TRUE, public = list( #' @description #' Creates a new instance of this [R6][R6::R6Class] class. - initialize = function(path = get_default_logdir()) { + initialize = function(path = tempfile()) { self$path = assert_path_for_output(path) - set_default_logdir(path) }, - #' @description - #' Logs the training measures as TensorFlow events. - #' Meaningful changes happen at the end of each batch, - #' since this is when the gradient step occurs. - on_batch_end = function() { - log_train_score = function(measure_name) { - train_score = list(self$ctx$last_scores_train[[measure_name]]) - names(train_score) = paste0("train.", measure_name) - do.call(log_event, train_score) - } + # #' @description + # #' Logs the training measures as TensorFlow events. + # #' Meaningful changes happen at the end of each batch, + # #' since this is when the gradient step occurs. + # # TODO: change this to log last_loss + # on_batch_end = function() { + # # TODO: determine whether you can refactor this and the + # # validation one into a single function + # # need to be able to access self$ctx - if (length(self$ctx$last_scores_train)) { - map(names(self$ctx$measures_train), log_train_score) - } - }, + # # TODO: pass in the appropriate step from the context + # log_event(last_loss = self$ctx$last_loss) + # }, #' @description #' Logs the validation measures as TensorFlow events. #' Meaningful changes happen at the end of each epoch. #' Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, #' and therefore differences are due to randomness + # TODO: log last_scores_train here + # TODO: display the appropriate x axis with its label in TensorBoard + # relevant when we log different scores at different times on_epoch_end = function() { log_valid_score = function(measure_name) { valid_score = list(self$ctx$last_scores_valid[[measure_name]]) names(valid_score) = paste0("valid.", measure_name) - do.call(log_event, valid_score) + with_logdir(temp, { + do.call(log_event, valid_score) + }) + } + + log_train_score = function(measure_name) { + # TODO: change this to use last_loss. I don't recall why we wanted to do that. + train_score = list(self$ctx$last_scores_train[[measure_name]]) + names(train_score) = paste0("train.", measure_name) + with_logdir(temp, { + do.call(log_event, valid_score) + }) + } + + if (length(self$ctx$last_scores_train)) { + # TODO: decide whether we should put the temporary logdir modification here instead. + map(names(self$ctx$measures_train), log_train_score) } if (length(self$ctx$last_scores_valid)) { @@ -55,16 +71,22 @@ CallbackSetTFLog = R6Class("CallbackSetTFLog", } } ) + # private = list( + # log_score = function(prefix, measure_name, score) { + + # } + # ) ) -mlr3torch_callbacks$add("tflog", function() { + +mlr3torch_callbacks$add("tb", function() { TorchCallback$new( callback_generator = CallbackSetCheckpoint, param_set = ps( path = p_uty(tags = c("train", "required")) ), - id = "tflog", - label = "TFLog", - man = "mlr3torch::mlr_callback_set.tflog" + id = "tb", + label = "TensorBoard", + man = "mlr3torch::mlr_callback_set.tb" ) }) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R new file mode 100644 index 00000000..f7387f18 --- /dev/null +++ b/tests/testthat/test_CallbackSetTB.R @@ -0,0 +1,21 @@ +test_that("autotest", { + cb = t_clbk("tb") + expect_torch_callback(cb) +}) + +# TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) +test_that("a simple example works", { + # using a temp dir + + # check that directory doesn't exist + + # check that directory was created + + # check that default logging directory is the directory name we passed in + + # check that the correct training measure name was logged at the correct time (correct epoch) + + # check that the correct validation measure name was logged + + # check that logging happens at the same frequency as eval_freq +}) \ No newline at end of file diff --git a/tests/testthat/test_CallbackSetTFLog.R b/tests/testthat/test_CallbackSetTFLog.R deleted file mode 100644 index 63ac409a..00000000 --- a/tests/testthat/test_CallbackSetTFLog.R +++ /dev/null @@ -1,8 +0,0 @@ -test_that("autotest", { - cb = t_clbk() - expect_torch_callback(cb) -}) - -test_that("", { - -}) \ No newline at end of file From 400ed74e926e3ffc1ae886b73ed6342b74db7919 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Wed, 2 Oct 2024 19:39:24 +0200 Subject: [PATCH 03/35] slight description change --- R/CallbackSetTB.R | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index c699a97c..658d4094 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -3,7 +3,7 @@ #' @name mlr_callback_set.tb #' #' @description -#' Logs the training and validation measures for tracking via TensorBoard. +#' Logs training loss and validation measures as events that can be tracked using TensorBoard. #' @details #' TODO: add #' @@ -47,17 +47,21 @@ CallbackSetTB = R6Class("CallbackSetTB", log_valid_score = function(measure_name) { valid_score = list(self$ctx$last_scores_valid[[measure_name]]) names(valid_score) = paste0("valid.", measure_name) - with_logdir(temp, { + with_logdir(self$path, { do.call(log_event, valid_score) }) } log_train_score = function(measure_name) { - # TODO: change this to use last_loss. I don't recall why we wanted to do that. - train_score = list(self$ctx$last_scores_train[[measure_name]]) - names(train_score) = paste0("train.", measure_name) - with_logdir(temp, { - do.call(log_event, valid_score) + # OLD: previously logged the elements in last_scores_train + # train_score = list(self$ctx$last_scores_train[[measure_name]]) + # names(train_score) = paste0("train.", measure_name) + # with_logdir(temp, { + # do.call(log_event, train_score) + # }) + # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures + with_logdir(self$path, { + log_event(train.loss = self$ctx$last_loss) }) } From 9e6acd8301e1fb67157cdf915c79ef89de3b712a Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Wed, 2 Oct 2024 19:53:46 +0200 Subject: [PATCH 04/35] removed extraneous comments --- R/CallbackSetTB.R | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 658d4094..4543cd26 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -22,25 +22,11 @@ CallbackSetTB = R6Class("CallbackSetTB", initialize = function(path = tempfile()) { self$path = assert_path_for_output(path) }, - # #' @description - # #' Logs the training measures as TensorFlow events. - # #' Meaningful changes happen at the end of each batch, - # #' since this is when the gradient step occurs. - # # TODO: change this to log last_loss - # on_batch_end = function() { - # # TODO: determine whether you can refactor this and the - # # validation one into a single function - # # need to be able to access self$ctx - - # # TODO: pass in the appropriate step from the context - # log_event(last_loss = self$ctx$last_loss) - # }, #' @description - #' Logs the validation measures as TensorFlow events. + #' Logs the training loss and validation measures as TensorFlow events. #' Meaningful changes happen at the end of each epoch. #' Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, #' and therefore differences are due to randomness - # TODO: log last_scores_train here # TODO: display the appropriate x axis with its label in TensorBoard # relevant when we log different scores at different times on_epoch_end = function() { @@ -53,13 +39,8 @@ CallbackSetTB = R6Class("CallbackSetTB", } log_train_score = function(measure_name) { - # OLD: previously logged the elements in last_scores_train - # train_score = list(self$ctx$last_scores_train[[measure_name]]) - # names(train_score) = paste0("train.", measure_name) - # with_logdir(temp, { - # do.call(log_event, train_score) - # }) # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures + # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train with_logdir(self$path, { log_event(train.loss = self$ctx$last_loss) }) From fc4f2faab84c4cf58b3a8903ffa837f7f60dff63 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Wed, 2 Oct 2024 21:24:48 +0200 Subject: [PATCH 05/35] added n_last_loss frequency test --- tests/testthat/test_CallbackSetTB.R | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index f7387f18..4e11e3d0 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -6,16 +6,35 @@ test_that("autotest", { # TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) test_that("a simple example works", { # using a temp dir + cb = t_clbk("tb") # check that directory doesn't exist + expect_false(dir.exists(cb$path)) - # check that directory was created + # check that the correct training measure name was logged at the correct time (correct epoch) + task = tsk("iris") - # check that default logging directory is the directory name we passed in + n_epochs = 10 + batch_size = 50 + neurons = 200 + mlp = lrn("classif.mlp", + callbacks = cb, + epochs = n_epochs, batch_size = batch_size, neurons = neurons, + validate = 0.2, + measures_valid = msrs(c("classif.acc", "classif.ce")), + measures_train = msrs(c("classif.acc", "classif.ce")) + ) - # check that the correct training measure name was logged at the correct time (correct epoch) + mlp$train(task) - # check that the correct validation measure name was logged + events = collect_events(cb$path)$summary %>% + mlr3misc::map(unlist) + # TODO: this but for the validation measures + n_last_loss = mlr3misc::map(\(x) x["tag"] == "last_loss") %>% + unlist() %>% + sum() + expect_equal(n_last_loss, n_epochs) + # check that logging happens at the same frequency as eval_freq }) \ No newline at end of file From 81d1dedc717516d97f5991eec461788baaf89334 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Thu, 10 Oct 2024 23:06:26 +0200 Subject: [PATCH 06/35] in progress --- DESCRIPTION | 2 -- R/CallbackSetTB.R | 14 +++++++++----- tests/testthat/test_CallbackSetTB.R | 23 +++++++++++------------ 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index a9231e48..cb1acb1b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -31,8 +31,6 @@ Authors@R: role = "ctb", email = "carsonzhang4@gmail.com") ), - - Description: Deep Learning library that extends the mlr3 framework by building upon the 'torch' package. It allows to conveniently build, train, and evaluate deep learning models without having to worry about low level diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 4543cd26..8eeadf0d 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -8,19 +8,23 @@ #' TODO: add #' #' @param path (`character(1)`)\cr -#' The path to a folder where the events are logged. +#' The path to a folder where the events are logged. #' Point TensorBoard to this folder to view them. #' @family Callback #' @export #' @include CallbackSet.R CallbackSetTB = R6Class("CallbackSetTB", inherit = CallbackSet, - lock_objects = TRUE, + lock_objects = FALSE, public = list( + path = NULL, #' @description #' Creates a new instance of this [R6][R6::R6Class] class. - initialize = function(path = tempfile()) { - self$path = assert_path_for_output(path) + initialize = function(path) { + self$path = assert_path_for_output(path) + if (!dir.exists(path)) { + dir.create(path, recursive = TRUE) + } }, #' @description #' Logs the training loss and validation measures as TensorFlow events. @@ -58,7 +62,7 @@ CallbackSetTB = R6Class("CallbackSetTB", ) # private = list( # log_score = function(prefix, measure_name, score) { - + # } # ) ) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 4e11e3d0..72e389ec 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -1,6 +1,6 @@ -test_that("autotest", { - cb = t_clbk("tb") - expect_torch_callback(cb) +test_that("basic", { + cb = t_clbk("tb", path = tempfile()) + expect_torch_callback(cb, check_man = FALSE) }) # TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) @@ -11,30 +11,29 @@ test_that("a simple example works", { # check that directory doesn't exist expect_false(dir.exists(cb$path)) - # check that the correct training measure name was logged at the correct time (correct epoch) task = tsk("iris") - n_epochs = 10 batch_size = 50 neurons = 200 - mlp = lrn("classif.mlp", + mlp = lrn("classif.mlp", callbacks = cb, epochs = n_epochs, batch_size = batch_size, neurons = neurons, - validate = 0.2, - measures_valid = msrs(c("classif.acc", "classif.ce")), + validate = 0.2, + measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")) ) - mlp$train(task) events = collect_events(cb$path)$summary %>% mlr3misc::map(unlist) - # TODO: this but for the validation measures n_last_loss = mlr3misc::map(\(x) x["tag"] == "last_loss") %>% unlist() %>% sum() expect_equal(n_last_loss, n_epochs) - + + # TODO: check that the correct training measure name was logged at the correct time (correct epoch) + # TODO: check that the correct validation measure name was logged at the correct time (correct epoch) + # check that logging happens at the same frequency as eval_freq -}) \ No newline at end of file +}) From cb03eb3719c55e90599ed9fc43ef0bfccd0f8e58 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Thu, 10 Oct 2024 23:17:18 +0200 Subject: [PATCH 07/35] autotest working, accidentally used the wrong callback_generator --- R/CallbackSetTB.R | 3 +-- tests/testthat/test_CallbackSetTB.R | 11 ++++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 8eeadf0d..fd016c2c 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -17,7 +17,6 @@ CallbackSetTB = R6Class("CallbackSetTB", inherit = CallbackSet, lock_objects = FALSE, public = list( - path = NULL, #' @description #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(path) { @@ -70,7 +69,7 @@ CallbackSetTB = R6Class("CallbackSetTB", mlr3torch_callbacks$add("tb", function() { TorchCallback$new( - callback_generator = CallbackSetCheckpoint, + callback_generator = CallbackSetTB, param_set = ps( path = p_uty(tags = c("train", "required")) ), diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 72e389ec..61684153 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -8,13 +8,13 @@ test_that("a simple example works", { # using a temp dir cb = t_clbk("tb") - # check that directory doesn't exist - expect_false(dir.exists(cb$path)) - task = tsk("iris") n_epochs = 10 batch_size = 50 neurons = 200 + + pth0 = tempfile() + mlp = lrn("classif.mlp", callbacks = cb, epochs = n_epochs, batch_size = batch_size, neurons = neurons, @@ -22,6 +22,11 @@ test_that("a simple example works", { measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")) ) + mlp$param_set$set_values(cb.tb.path = pth0) + + # check that directory doesn't exist + expect_false(dir.exists(mlp$param_set$get_values(path))) + mlp$train(task) events = collect_events(cb$path)$summary %>% From 78b95a5f6567b85e0132891138043c97126d7dc3 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 08:16:37 +0200 Subject: [PATCH 08/35] simple and eval_freq tests pass --- R/CallbackSetTB.R | 9 ++-- tests/testthat/test_CallbackSetTB.R | 75 +++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 15 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index fd016c2c..0efe792d 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -41,7 +41,7 @@ CallbackSetTB = R6Class("CallbackSetTB", }) } - log_train_score = function(measure_name) { + log_train_score = function() { # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train with_logdir(self$path, { @@ -49,13 +49,10 @@ CallbackSetTB = R6Class("CallbackSetTB", }) } - if (length(self$ctx$last_scores_train)) { - # TODO: decide whether we should put the temporary logdir modification here instead. - map(names(self$ctx$measures_train), log_train_score) - } + log_train_score() if (length(self$ctx$last_scores_valid)) { - map(names(self$ctx$measure_valid), log_valid_score) + map(names(self$ctx$measures_valid), log_valid_score) } } ) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 61684153..4cbeabf5 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -1,3 +1,5 @@ +library(tfevents) + test_that("basic", { cb = t_clbk("tb", path = tempfile()) expect_torch_callback(cb, check_man = FALSE) @@ -24,21 +26,76 @@ test_that("a simple example works", { ) mlp$param_set$set_values(cb.tb.path = pth0) - # check that directory doesn't exist - expect_false(dir.exists(mlp$param_set$get_values(path))) - mlp$train(task) - events = collect_events(cb$path)$summary %>% + events = collect_events(pth0)$summary %>% mlr3misc::map(unlist) - n_last_loss = mlr3misc::map(\(x) x["tag"] == "last_loss") %>% + n_last_loss_events = mlr3misc::map(events, \(x) x["tag"] == "train.loss") %>% unlist() %>% sum() - expect_equal(n_last_loss, n_epochs) - # TODO: check that the correct training measure name was logged at the correct time (correct epoch) - # TODO: check that the correct validation measure name was logged at the correct time (correct epoch) + n_valid_acc_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc") %>% + unlist() %>% + sum() + + n_valid_ce_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce") %>% + unlist() %>% + sum() + + # TODO: refactor to expect a specific ordering of the events list + expect_equal(n_last_loss_events, n_epochs) + expect_equal(n_valid_acc_events, n_epochs) + expect_equal(n_valid_ce_events, n_epochs) +}) + +test_that("eval_freq works", { + # using a temp dir + cb = t_clbk("tb") + + task = tsk("iris") + n_epochs = 9 + batch_size = 50 + neurons = 200 + eval_freq = 4 + + pth0 = tempfile() + + mlp = lrn("classif.mlp", + callbacks = cb, + epochs = n_epochs, batch_size = batch_size, neurons = neurons, + validate = 0.2, + measures_valid = msrs(c("classif.acc", "classif.ce")), + measures_train = msrs(c("classif.acc", "classif.ce")), + eval_freq = eval_freq + ) + mlp$param_set$set_values(cb.tb.path = pth0) + + mlp$train(task) + + events = collect_events(pth0)$summary %>% + mlr3misc::map(unlist) + + n_last_loss_events = mlr3misc::map(events, \(x) x["tag"] == "train.loss") %>% + unlist() %>% + sum() + + n_valid_acc_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc") %>% + unlist() %>% + sum() + + n_valid_ce_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce") %>% + unlist() %>% + sum() + + expect_equal(n_last_loss_events, n_epochs) + expect_equal(n_valid_acc_events, ceiling(n_epochs / 4)) + expect_equal(n_valid_ce_events, ceiling(n_epochs / 4)) +}) - # check that logging happens at the same frequency as eval_freq +test_that("throws an error when using existing directory", { + path = tempfile() + dir.create(path) + cb = t_clbk("tb", path = path) + expect_error(cb$generate(), "already exists") }) From a365757ee3bc63d9da2908441fb13f9e8327f762 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 08:31:01 +0200 Subject: [PATCH 09/35] changed logging methods to private --- R/CallbackSetTB.R | 45 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 0efe792d..10a7bc33 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -33,34 +33,33 @@ CallbackSetTB = R6Class("CallbackSetTB", # TODO: display the appropriate x axis with its label in TensorBoard # relevant when we log different scores at different times on_epoch_end = function() { - log_valid_score = function(measure_name) { - valid_score = list(self$ctx$last_scores_valid[[measure_name]]) - names(valid_score) = paste0("valid.", measure_name) - with_logdir(self$path, { - do.call(log_event, valid_score) - }) - } - - log_train_score = function() { - # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures - # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train - with_logdir(self$path, { - log_event(train.loss = self$ctx$last_loss) - }) - } - - log_train_score() + private$log_train_score() if (length(self$ctx$last_scores_valid)) { - map(names(self$ctx$measures_valid), log_valid_score) + map(names(self$ctx$measures_valid), private$log_valid_score) } } + ), + private = list( + # TODO: refactor into a single function with the following signature + # log_score = function(prefix, measure_name, score) { + # + # }, + log_valid_score = function(measure_name) { + valid_score = list(self$ctx$last_scores_valid[[measure_name]]) + names(valid_score) = paste0("valid.", measure_name) + with_logdir(self$path, { + do.call(log_event, valid_score) + }) + }, + log_train_score = function() { + # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures + # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train + with_logdir(self$path, { + log_event(train.loss = self$ctx$last_loss) + }) + } ) - # private = list( - # log_score = function(prefix, measure_name, score) { - - # } - # ) ) From 43a8ffb4d9f48177600fed1df28846aaea46acb2 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 08:38:33 +0200 Subject: [PATCH 10/35] removed magrittr pipe from tests --- R/CallbackSetTB.R | 108 ++++++++++++++-------------- tests/testthat/test_CallbackSetTB.R | 92 ++++++++++-------------- 2 files changed, 93 insertions(+), 107 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 10a7bc33..07593a50 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -14,63 +14,63 @@ #' @export #' @include CallbackSet.R CallbackSetTB = R6Class("CallbackSetTB", - inherit = CallbackSet, - lock_objects = FALSE, - public = list( - #' @description - #' Creates a new instance of this [R6][R6::R6Class] class. - initialize = function(path) { - self$path = assert_path_for_output(path) - if (!dir.exists(path)) { - dir.create(path, recursive = TRUE) - } - }, - #' @description - #' Logs the training loss and validation measures as TensorFlow events. - #' Meaningful changes happen at the end of each epoch. - #' Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, - #' and therefore differences are due to randomness - # TODO: display the appropriate x axis with its label in TensorBoard - # relevant when we log different scores at different times - on_epoch_end = function() { - private$log_train_score() + inherit = CallbackSet, + lock_objects = FALSE, + public = list( + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + initialize = function(path) { + self$path = assert_path_for_output(path) + if (!dir.exists(path)) { + dir.create(path, recursive = TRUE) + } + }, + #' @description + #' Logs the training loss and validation measures as TensorFlow events. + #' Meaningful changes happen at the end of each epoch. + #' Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, + #' and therefore differences are due to randomness + # TODO: display the appropriate x axis with its label in TensorBoard + # relevant when we log different scores at different times + on_epoch_end = function() { + private$log_train_score() - if (length(self$ctx$last_scores_valid)) { - map(names(self$ctx$measures_valid), private$log_valid_score) - } - } - ), - private = list( - # TODO: refactor into a single function with the following signature - # log_score = function(prefix, measure_name, score) { - # - # }, - log_valid_score = function(measure_name) { - valid_score = list(self$ctx$last_scores_valid[[measure_name]]) - names(valid_score) = paste0("valid.", measure_name) - with_logdir(self$path, { - do.call(log_event, valid_score) - }) - }, - log_train_score = function() { - # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures - # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train - with_logdir(self$path, { - log_event(train.loss = self$ctx$last_loss) - }) - } - ) + if (length(self$ctx$last_scores_valid)) { + map(names(self$ctx$measures_valid), private$log_valid_score) + } + } + ), + private = list( + # TODO: refactor into a single function with the following signature + # log_score = function(prefix, measure_name, score) { + # + # }, + log_valid_score = function(measure_name) { + valid_score = list(self$ctx$last_scores_valid[[measure_name]]) + names(valid_score) = paste0("valid.", measure_name) + with_logdir(self$path, { + do.call(log_event, valid_score) + }) + }, + log_train_score = function() { + # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures + # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train + with_logdir(self$path, { + log_event(train.loss = self$ctx$last_loss) + }) + } + ) ) mlr3torch_callbacks$add("tb", function() { - TorchCallback$new( - callback_generator = CallbackSetTB, - param_set = ps( - path = p_uty(tags = c("train", "required")) - ), - id = "tb", - label = "TensorBoard", - man = "mlr3torch::mlr_callback_set.tb" - ) + TorchCallback$new( + callback_generator = CallbackSetTB, + param_set = ps( + path = p_uty(tags = c("train", "required")) + ), + id = "tb", + label = "TensorBoard", + man = "mlr3torch::mlr_callback_set.tb" + ) }) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 4cbeabf5..8ae879f1 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -1,52 +1,45 @@ library(tfevents) test_that("basic", { - cb = t_clbk("tb", path = tempfile()) - expect_torch_callback(cb, check_man = FALSE) + cb = t_clbk("tb", path = tempfile()) + expect_torch_callback(cb, check_man = FALSE) }) # TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) test_that("a simple example works", { - # using a temp dir - cb = t_clbk("tb") - - task = tsk("iris") - n_epochs = 10 - batch_size = 50 - neurons = 200 - - pth0 = tempfile() - - mlp = lrn("classif.mlp", - callbacks = cb, - epochs = n_epochs, batch_size = batch_size, neurons = neurons, - validate = 0.2, - measures_valid = msrs(c("classif.acc", "classif.ce")), - measures_train = msrs(c("classif.acc", "classif.ce")) - ) - mlp$param_set$set_values(cb.tb.path = pth0) - - mlp$train(task) - - events = collect_events(pth0)$summary %>% - mlr3misc::map(unlist) - - n_last_loss_events = mlr3misc::map(events, \(x) x["tag"] == "train.loss") %>% - unlist() %>% - sum() - - n_valid_acc_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc") %>% - unlist() %>% - sum() - - n_valid_ce_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce") %>% - unlist() %>% - sum() - - # TODO: refactor to expect a specific ordering of the events list - expect_equal(n_last_loss_events, n_epochs) - expect_equal(n_valid_acc_events, n_epochs) - expect_equal(n_valid_ce_events, n_epochs) + # using a temp dir + cb = t_clbk("tb") + + task = tsk("iris") + n_epochs = 10 + batch_size = 50 + neurons = 200 + + pth0 = tempfile() + + mlp = lrn("classif.mlp", + callbacks = cb, + epochs = n_epochs, batch_size = batch_size, neurons = neurons, + validate = 0.2, + measures_valid = msrs(c("classif.acc", "classif.ce")), + measures_train = msrs(c("classif.acc", "classif.ce")) + ) + mlp$param_set$set_values(cb.tb.path = pth0) + + mlp$train(task) + + events = mlr3misc::map(collect_events(pth0)$summary, unlist) + + n_last_loss_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "train.loss"))) + + n_valid_acc_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc"))) + + n_valid_ce_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce"))) + + # TODO: refactor to expect a specific ordering of the events list + expect_equal(n_last_loss_events, n_epochs) + expect_equal(n_valid_acc_events, n_epochs) + expect_equal(n_valid_ce_events, n_epochs) }) test_that("eval_freq works", { @@ -73,20 +66,13 @@ test_that("eval_freq works", { mlp$train(task) - events = collect_events(pth0)$summary %>% - mlr3misc::map(unlist) + events = mlr3misc::map(collect_events(pth0)$summary, unlist) - n_last_loss_events = mlr3misc::map(events, \(x) x["tag"] == "train.loss") %>% - unlist() %>% - sum() + n_last_loss_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "train.loss"))) - n_valid_acc_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc") %>% - unlist() %>% - sum() + n_valid_acc_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc"))) - n_valid_ce_events = mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce") %>% - unlist() %>% - sum() + n_valid_ce_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce"))) expect_equal(n_last_loss_events, n_epochs) expect_equal(n_valid_acc_events, ceiling(n_epochs / 4)) From 6b9a8453d1fa3f295c580203d905e31efe5f9616 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 08:40:46 +0200 Subject: [PATCH 11/35] added details for callback class --- R/CallbackSetTB.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 07593a50..4bdbf491 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -5,7 +5,7 @@ #' @description #' Logs training loss and validation measures as events that can be tracked using TensorBoard. #' @details -#' TODO: add +#' Logs at most every epoch. #' #' @param path (`character(1)`)\cr #' The path to a folder where the events are logged. @@ -53,7 +53,6 @@ CallbackSetTB = R6Class("CallbackSetTB", }) }, log_train_score = function() { - # TODO: figure out what self$ctx$last_loss looks like when there are multiple train measures # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train with_logdir(self$path, { log_event(train.loss = self$ctx$last_loss) From d354b2c09e91c2242be6bc19f2ba60eb6bd25c54 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 08:42:25 +0200 Subject: [PATCH 12/35] formatting --- tests/testthat/test_CallbackSetTB.R | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 8ae879f1..440d8ccf 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -31,9 +31,7 @@ test_that("a simple example works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) n_last_loss_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "train.loss"))) - n_valid_acc_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc"))) - n_valid_ce_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce"))) # TODO: refactor to expect a specific ordering of the events list @@ -69,9 +67,7 @@ test_that("eval_freq works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) n_last_loss_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "train.loss"))) - n_valid_acc_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc"))) - n_valid_ce_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce"))) expect_equal(n_last_loss_events, n_epochs) From b5b27b13f86386bf90f2f6074fecd11f0de6bcf5 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 10:59:14 +0200 Subject: [PATCH 13/35] built docs --- DESCRIPTION | 3 +- NAMESPACE | 1 + man/TorchCallback.Rd | 1 + man/as_torch_callback.Rd | 1 + man/as_torch_callbacks.Rd | 1 + man/callback_set.Rd | 1 + man/mlr3torch-package.Rd | 1 + man/mlr3torch_callbacks.Rd | 1 + man/mlr_callback_set.Rd | 1 + man/mlr_callback_set.checkpoint.Rd | 1 + man/mlr_callback_set.progress.Rd | 1 + man/mlr_callback_set.tb.Rd | 97 +++++++++++++++++++++++++++++ man/mlr_context_torch.Rd | 1 + man/mlr_learners.torchvision.Rd | 6 +- man/t_clbk.Rd | 1 + man/torch_callback.Rd | 1 + tests/testthat/test_CallbackSetTB.R | 2 +- 17 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 man/mlr_callback_set.tb.Rd diff --git a/DESCRIPTION b/DESCRIPTION index cb1acb1b..cd08ca4a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -30,7 +30,7 @@ Authors@R: family = "Zhang", role = "ctb", email = "carsonzhang4@gmail.com") - ), + ) Description: Deep Learning library that extends the mlr3 framework by building upon the 'torch' package. It allows to conveniently build, train, and evaluate deep learning models without having to worry about low level @@ -86,6 +86,7 @@ Collate: 'CallbackSetEarlyStopping.R' 'CallbackSetHistory.R' 'CallbackSetProgress.R' + 'CallbackSetTB.R' 'ContextTorch.R' 'DataBackendLazy.R' 'utils.R' diff --git a/NAMESPACE b/NAMESPACE index ef61f4e3..d3f3593c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -63,6 +63,7 @@ export(CallbackSet) export(CallbackSetCheckpoint) export(CallbackSetHistory) export(CallbackSetProgress) +export(CallbackSetTB) export(ContextTorch) export(DataBackendLazy) export(DataDescriptor) diff --git a/man/TorchCallback.Rd b/man/TorchCallback.Rd index 32d30ecc..d2d5bc9e 100644 --- a/man/TorchCallback.Rd +++ b/man/TorchCallback.Rd @@ -61,6 +61,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/as_torch_callback.Rd b/man/as_torch_callback.Rd index db5feeb8..51416f7a 100644 --- a/man/as_torch_callback.Rd +++ b/man/as_torch_callback.Rd @@ -31,6 +31,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/as_torch_callbacks.Rd b/man/as_torch_callbacks.Rd index 563a4251..e3fb8442 100644 --- a/man/as_torch_callbacks.Rd +++ b/man/as_torch_callbacks.Rd @@ -31,6 +31,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/callback_set.Rd b/man/callback_set.Rd index 4ad98f06..4fb2d46b 100644 --- a/man/callback_set.Rd +++ b/man/callback_set.Rd @@ -81,6 +81,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/mlr3torch-package.Rd b/man/mlr3torch-package.Rd index 77aeb3c0..31cf3ec2 100644 --- a/man/mlr3torch-package.Rd +++ b/man/mlr3torch-package.Rd @@ -39,6 +39,7 @@ Other contributors: \item Bernd Bischl \email{bernd_bischl@gmx.net} (\href{https://orcid.org/0000-0001-6002-6980}{ORCID}) [contributor] \item Lukas Burk \email{github@quantenbrot.de} (\href{https://orcid.org/0000-0001-7528-3795}{ORCID}) [contributor] \item Florian Pfisterer \email{pfistererf@googlemail.com} (\href{https://orcid.org/0000-0001-8867-762X}{ORCID}) [contributor] + \item Carson Zhang \email{carsonzhang4@gmail.com} [contributor] } } diff --git a/man/mlr3torch_callbacks.Rd b/man/mlr3torch_callbacks.Rd index 3fec81d6..ef923915 100644 --- a/man/mlr3torch_callbacks.Rd +++ b/man/mlr3torch_callbacks.Rd @@ -34,6 +34,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/mlr_callback_set.Rd b/man/mlr_callback_set.Rd index 54afcbe2..9b5e6be0 100644 --- a/man/mlr_callback_set.Rd +++ b/man/mlr_callback_set.Rd @@ -68,6 +68,7 @@ Other Callback: \code{\link{mlr3torch_callbacks}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/mlr_callback_set.checkpoint.Rd b/man/mlr_callback_set.checkpoint.Rd index 92da34ce..fcb846ac 100644 --- a/man/mlr_callback_set.checkpoint.Rd +++ b/man/mlr_callback_set.checkpoint.Rd @@ -21,6 +21,7 @@ Other Callback: \code{\link{mlr3torch_callbacks}}, \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/mlr_callback_set.progress.Rd b/man/mlr_callback_set.progress.Rd index 93b6af7c..8927f654 100644 --- a/man/mlr_callback_set.progress.Rd +++ b/man/mlr_callback_set.progress.Rd @@ -16,6 +16,7 @@ Other Callback: \code{\link{mlr3torch_callbacks}}, \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} diff --git a/man/mlr_callback_set.tb.Rd b/man/mlr_callback_set.tb.Rd new file mode 100644 index 00000000..cfa1cbb4 --- /dev/null +++ b/man/mlr_callback_set.tb.Rd @@ -0,0 +1,97 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CallbackSetTB.R +\name{mlr_callback_set.tb} +\alias{mlr_callback_set.tb} +\alias{CallbackSetTB} +\title{TensorBoard Logging Callback} +\description{ +Logs training loss and validation measures as events that can be tracked using TensorBoard. +} +\details{ +Logs at most every epoch. +} +\seealso{ +Other Callback: +\code{\link{TorchCallback}}, +\code{\link{as_torch_callback}()}, +\code{\link{as_torch_callbacks}()}, +\code{\link{callback_set}()}, +\code{\link{mlr3torch_callbacks}}, +\code{\link{mlr_callback_set}}, +\code{\link{mlr_callback_set.checkpoint}}, +\code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_context_torch}}, +\code{\link{t_clbk}()}, +\code{\link{torch_callback}()} +} +\concept{Callback} +\section{Super class}{ +\code{\link[mlr3torch:CallbackSet]{mlr3torch::CallbackSet}} -> \code{CallbackSetTB} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-CallbackSetTB-new}{\code{CallbackSetTB$new()}} +\item \href{#method-CallbackSetTB-on_epoch_end}{\code{CallbackSetTB$on_epoch_end()}} +\item \href{#method-CallbackSetTB-clone}{\code{CallbackSetTB$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CallbackSetTB-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CallbackSetTB$new(path)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{path}}{(\code{character(1)})\cr +The path to a folder where the events are logged. +Point TensorBoard to this folder to view them.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CallbackSetTB-on_epoch_end}{}}} +\subsection{Method \code{on_epoch_end()}}{ +Logs the training loss and validation measures as TensorFlow events. +Meaningful changes happen at the end of each epoch. +Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, +and therefore differences are due to randomness +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CallbackSetTB$on_epoch_end()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CallbackSetTB-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CallbackSetTB$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/mlr_context_torch.Rd b/man/mlr_context_torch.Rd index c0913f20..474a4fff 100644 --- a/man/mlr_context_torch.Rd +++ b/man/mlr_context_torch.Rd @@ -19,6 +19,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{t_clbk}()}, \code{\link{torch_callback}()} } diff --git a/man/mlr_learners.torchvision.Rd b/man/mlr_learners.torchvision.Rd index dd0123ad..e1f4f9ba 100644 --- a/man/mlr_learners.torchvision.Rd +++ b/man/mlr_learners.torchvision.Rd @@ -89,9 +89,9 @@ Krizhevsky, Alex, Sutskever, Ilya, Hinton, E. G (2017). Sandler, Mark, Howard, Andrew, Zhu, Menglong, Zhmoginov, Andrey, Chen, Liang-Chieh (2018). \dQuote{Mobilenetv2: Inverted residuals and linear bottlenecks.} In \emph{Proceedings of the IEEE conference on computer vision and pattern recognition}, 4510--4520. -He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian (2016 ). -\dQuote{Deep residual learning for image recognition .} -In \emph{Proceedings of the IEEE conference on computer vision and pattern recognition }, 770--778 . +He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian (2016). +\dQuote{Deep residual learning for image recognition.} +In \emph{Proceedings of the IEEE conference on computer vision and pattern recognition}, 770--778. Simonyan, Karen, Zisserman, Andrew (2014). \dQuote{Very deep convolutional networks for large-scale image recognition.} \emph{arXiv preprint arXiv:1409.1556}.} diff --git a/man/t_clbk.Rd b/man/t_clbk.Rd index c329ab46..fdac6ea2 100644 --- a/man/t_clbk.Rd +++ b/man/t_clbk.Rd @@ -43,6 +43,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{torch_callback}()} diff --git a/man/torch_callback.Rd b/man/torch_callback.Rd index ae54583a..2d5568e3 100644 --- a/man/torch_callback.Rd +++ b/man/torch_callback.Rd @@ -149,6 +149,7 @@ Other Callback: \code{\link{mlr_callback_set}}, \code{\link{mlr_callback_set.checkpoint}}, \code{\link{mlr_callback_set.progress}}, +\code{\link{mlr_callback_set.tb}}, \code{\link{mlr_context_torch}}, \code{\link{t_clbk}()} } diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 440d8ccf..cbf031cb 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -34,7 +34,7 @@ test_that("a simple example works", { n_valid_acc_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc"))) n_valid_ce_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce"))) - # TODO: refactor to expect a specific ordering of the events list + # TODO: refactor to expect a specific ordering of the events list, not just the right counts expect_equal(n_last_loss_events, n_epochs) expect_equal(n_valid_acc_events, n_epochs) expect_equal(n_valid_ce_events, n_epochs) From 7c9f431d3e368d838ec081d0fa37d15dd33831fb Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 17:37:06 +0200 Subject: [PATCH 14/35] all tests pass, I think this is parity with the previous broken commit. still need to incorporate the step logging --- R/CallbackSetTB.R | 35 ++++++++---- man/mlr_callback_set.tb.Rd | 5 +- man/mlr_learners.mlp.Rd | 1 - man/mlr_learners.tab_resnet.Rd | 1 - man/mlr_learners.torch_featureless.Rd | 1 - man/mlr_learners.torchvision.Rd | 1 - man/mlr_learners_torch.Rd | 1 - man/mlr_learners_torch_image.Rd | 1 - man/mlr_learners_torch_model.Rd | 1 - tests/testthat/test_CallbackSetTB.R | 82 ++++++++++++++++++++++----- 10 files changed, 96 insertions(+), 33 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 4bdbf491..7424ba7d 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -10,6 +10,8 @@ #' @param path (`character(1)`)\cr #' The path to a folder where the events are logged. #' Point TensorBoard to this folder to view them. +#' @param log_train_loss (`logical(1)`)\cr +#' Whether we log the training loss. #' @family Callback #' @export #' @include CallbackSet.R @@ -19,11 +21,12 @@ CallbackSetTB = R6Class("CallbackSetTB", public = list( #' @description #' Creates a new instance of this [R6][R6::R6Class] class. - initialize = function(path) { + initialize = function(path, log_train_loss) { self$path = assert_path_for_output(path) if (!dir.exists(path)) { dir.create(path, recursive = TRUE) } + self$log_train_loss = assert_logical(log_train_loss) }, #' @description #' Logs the training loss and validation measures as TensorFlow events. @@ -33,26 +36,35 @@ CallbackSetTB = R6Class("CallbackSetTB", # TODO: display the appropriate x axis with its label in TensorBoard # relevant when we log different scores at different times on_epoch_end = function() { - private$log_train_score() + if (self$log_train_loss) { + private$.log_train_loss() + } + + if (length(self$ctx$last_scores_train)) { + map(names(self$ctx$measures_train), private$.log_train_score) + } if (length(self$ctx$last_scores_valid)) { - map(names(self$ctx$measures_valid), private$log_valid_score) + map(names(self$ctx$measures_valid), private$.log_valid_score) } } ), private = list( - # TODO: refactor into a single function with the following signature - # log_score = function(prefix, measure_name, score) { - # - # }, - log_valid_score = function(measure_name) { + .log_valid_score = function(measure_name) { valid_score = list(self$ctx$last_scores_valid[[measure_name]]) names(valid_score) = paste0("valid.", measure_name) with_logdir(self$path, { do.call(log_event, valid_score) }) }, - log_train_score = function() { + .log_train_score = function(measure_name) { + train_score = list(self$ctx$last_scores_train[[measure_name]]) + names(train_score) = paste0("train.", measure_name) + with_logdir(self$path, { + do.call(log_event, train_score) + }) + }, + .log_train_loss = function() { # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train with_logdir(self$path, { log_event(train.loss = self$ctx$last_loss) @@ -61,12 +73,13 @@ CallbackSetTB = R6Class("CallbackSetTB", ) ) - +#' @include TorchCallback.R mlr3torch_callbacks$add("tb", function() { TorchCallback$new( callback_generator = CallbackSetTB, param_set = ps( - path = p_uty(tags = c("train", "required")) + path = p_uty(tags = c("train", "required")), + log_train_loss = p_lgl(tags = c("train", "required")) ), id = "tb", label = "TensorBoard", diff --git a/man/mlr_callback_set.tb.Rd b/man/mlr_callback_set.tb.Rd index cfa1cbb4..a5f226a5 100644 --- a/man/mlr_callback_set.tb.Rd +++ b/man/mlr_callback_set.tb.Rd @@ -51,7 +51,7 @@ Other Callback: \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{CallbackSetTB$new(path)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{CallbackSetTB$new(path, log_train_loss)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -60,6 +60,9 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. \item{\code{path}}{(\code{character(1)})\cr The path to a folder where the events are logged. Point TensorBoard to this folder to view them.} + +\item{\code{log_train_loss}}{(\code{logical(1)})\cr +Whether we log the training loss.} } \if{html}{\out{}} } diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd index 6eb586aa..e6809199 100644 --- a/man/mlr_learners.mlp.Rd +++ b/man/mlr_learners.mlp.Rd @@ -100,7 +100,6 @@ Other Learner:
Inherited methods
  • mlr3::Learner$base_learner()
  • -
  • mlr3::Learner$encapsulate()
  • mlr3::Learner$help()
  • mlr3::Learner$predict()
  • mlr3::Learner$predict_newdata()
  • diff --git a/man/mlr_learners.tab_resnet.Rd b/man/mlr_learners.tab_resnet.Rd index 1d8b9e8d..b78be86c 100644 --- a/man/mlr_learners.tab_resnet.Rd +++ b/man/mlr_learners.tab_resnet.Rd @@ -102,7 +102,6 @@ Other Learner:
    Inherited methods
    • mlr3::Learner$base_learner()
    • -
    • mlr3::Learner$encapsulate()
    • mlr3::Learner$help()
    • mlr3::Learner$predict()
    • mlr3::Learner$predict_newdata()
    • diff --git a/man/mlr_learners.torch_featureless.Rd b/man/mlr_learners.torch_featureless.Rd index 1fb6274a..58f316e3 100644 --- a/man/mlr_learners.torch_featureless.Rd +++ b/man/mlr_learners.torch_featureless.Rd @@ -86,7 +86,6 @@ Other Learner:
      Inherited methods
      • mlr3::Learner$base_learner()
      • -
      • mlr3::Learner$encapsulate()
      • mlr3::Learner$help()
      • mlr3::Learner$predict()
      • mlr3::Learner$predict_newdata()
      • diff --git a/man/mlr_learners.torchvision.Rd b/man/mlr_learners.torchvision.Rd index 87883dd9..e1f4f9ba 100644 --- a/man/mlr_learners.torchvision.Rd +++ b/man/mlr_learners.torchvision.Rd @@ -42,7 +42,6 @@ number of classes inferred from the \code{\link[mlr3:Task]{Task}}.
        Inherited methods
        • mlr3::Learner$base_learner()
        • -
        • mlr3::Learner$encapsulate()
        • mlr3::Learner$help()
        • mlr3::Learner$predict()
        • mlr3::Learner$predict_newdata()
        • diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 5797f246..748b1b06 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -265,7 +265,6 @@ which are varied systematically during tuning (parameter values).}
          Inherited methods
          • mlr3::Learner$base_learner()
          • -
          • mlr3::Learner$encapsulate()
          • mlr3::Learner$help()
          • mlr3::Learner$predict()
          • mlr3::Learner$predict_newdata()
          • diff --git a/man/mlr_learners_torch_image.Rd b/man/mlr_learners_torch_image.Rd index af2b854d..723a4e22 100644 --- a/man/mlr_learners_torch_image.Rd +++ b/man/mlr_learners_torch_image.Rd @@ -36,7 +36,6 @@ Other Learner:
            Inherited methods
            • mlr3::Learner$base_learner()
            • -
            • mlr3::Learner$encapsulate()
            • mlr3::Learner$help()
            • mlr3::Learner$predict()
            • mlr3::Learner$predict_newdata()
            • diff --git a/man/mlr_learners_torch_model.Rd b/man/mlr_learners_torch_model.Rd index da6fa008..505b8fbd 100644 --- a/man/mlr_learners_torch_model.Rd +++ b/man/mlr_learners_torch_model.Rd @@ -92,7 +92,6 @@ The ingress tokens. Must be non-\code{NULL} when calling \verb{$train()}.}
              Inherited methods
              • mlr3::Learner$base_learner()
              • -
              • mlr3::Learner$encapsulate()
              • mlr3::Learner$help()
              • mlr3::Learner$predict()
              • mlr3::Learner$predict_newdata()
              • diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index cbf031cb..f2c3818f 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -1,8 +1,8 @@ library(tfevents) -test_that("basic", { - cb = t_clbk("tb", path = tempfile()) - expect_torch_callback(cb, check_man = FALSE) +test_that("autotest", { + cb = t_clbk("tb", path = tempfile(), log_train_loss = TRUE) + expect_torch_callback(cb, check_man = TRUE) }) # TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) @@ -17,6 +17,8 @@ test_that("a simple example works", { pth0 = tempfile() + log_train_loss = TRUE + mlp = lrn("classif.mlp", callbacks = cb, epochs = n_epochs, batch_size = batch_size, neurons = neurons, @@ -26,16 +28,26 @@ test_that("a simple example works", { ) mlp$param_set$set_values(cb.tb.path = pth0) + mlp$param_set$set_values(cb.tb.log_train_loss = log_train_loss) + mlp$train(task) events = mlr3misc::map(collect_events(pth0)$summary, unlist) - n_last_loss_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "train.loss"))) - n_valid_acc_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc"))) - n_valid_ce_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce"))) + event_tag_is = function(event, tag_name) { + ifelse(is.null(event), FALSE, event["tag"] == tag_name) + } + + n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) + n_train_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.acc"))) + n_train_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.ce"))) + n_valid_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.acc"))) + n_valid_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.ce"))) # TODO: refactor to expect a specific ordering of the events list, not just the right counts - expect_equal(n_last_loss_events, n_epochs) + expect_equal(n_train_loss_events, n_epochs) + expect_equal(n_train_acc_events, n_epochs) + expect_equal(n_train_ce_events, n_epochs) expect_equal(n_valid_acc_events, n_epochs) expect_equal(n_valid_ce_events, n_epochs) }) @@ -52,6 +64,8 @@ test_that("eval_freq works", { pth0 = tempfile() + log_train_loss = TRUE + mlp = lrn("classif.mlp", callbacks = cb, epochs = n_epochs, batch_size = batch_size, neurons = neurons, @@ -61,23 +75,63 @@ test_that("eval_freq works", { eval_freq = eval_freq ) mlp$param_set$set_values(cb.tb.path = pth0) + mlp$param_set$set_values(cb.tb.log_train_loss = log_train_loss) + + mlp$train(task) + + events = mlr3misc::map(collect_events(pth0)$summary, unlist) + + event_tag_is = function(event, tag_name) { + ifelse(is.null(event), FALSE, event["tag"] == tag_name) + } + + n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) + n_train_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.acc"))) + n_train_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.ce"))) + n_valid_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.acc"))) + n_valid_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.ce"))) + + expect_equal(n_train_loss_events, n_epochs) + expect_equal(n_train_acc_events, ceiling(n_epochs / eval_freq)) + expect_equal(n_train_ce_events, ceiling(n_epochs / eval_freq)) + expect_equal(n_valid_acc_events, ceiling(n_epochs / eval_freq)) + expect_equal(n_valid_ce_events, ceiling(n_epochs / eval_freq)) +}) + +test_that("the flag for tracking the train loss works", { + cb = t_clbk("tb") + + task = tsk("iris") + n_epochs = 10 + batch_size = 50 + neurons = 200 + + log_train_loss = FALSE + + pth0 = tempfile() + + mlp = lrn("classif.mlp", + callbacks = cb, + epochs = n_epochs, batch_size = batch_size, neurons = neurons, + validate = 0.2, + measures_valid = msrs(c("classif.acc", "classif.ce")), + measures_train = msrs(c("classif.acc", "classif.ce")) + ) + mlp$param_set$set_values(cb.tb.path = pth0) + mlp$param_set$set_values(cb.tb.log_train_loss = log_train_loss) mlp$train(task) events = mlr3misc::map(collect_events(pth0)$summary, unlist) - n_last_loss_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "train.loss"))) - n_valid_acc_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.acc"))) - n_valid_ce_events = sum(unlist(mlr3misc::map(events, \(x) x["tag"] == "valid.classif.ce"))) + n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) - expect_equal(n_last_loss_events, n_epochs) - expect_equal(n_valid_acc_events, ceiling(n_epochs / 4)) - expect_equal(n_valid_ce_events, ceiling(n_epochs / 4)) + expect_equal(n_train_loss_events, 0) }) test_that("throws an error when using existing directory", { path = tempfile() dir.create(path) - cb = t_clbk("tb", path = path) + cb = t_clbk("tb", path = path, log_train_loss = TRUE) expect_error(cb$generate(), "already exists") }) From c6c93336d42b794abc36a974b6657352a3f3d514 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 17:57:22 +0200 Subject: [PATCH 15/35] implemented step logging --- R/CallbackSetTB.R | 23 +++++++++++------------ tests/testthat/test_CallbackSetTB.R | 4 ++++ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 7424ba7d..55bbe7cb 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -30,9 +30,6 @@ CallbackSetTB = R6Class("CallbackSetTB", }, #' @description #' Logs the training loss and validation measures as TensorFlow events. - #' Meaningful changes happen at the end of each epoch. - #' Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, - #' and therefore differences are due to randomness # TODO: display the appropriate x axis with its label in TensorBoard # relevant when we log different scores at different times on_epoch_end = function() { @@ -50,19 +47,21 @@ CallbackSetTB = R6Class("CallbackSetTB", } ), private = list( - .log_valid_score = function(measure_name) { - valid_score = list(self$ctx$last_scores_valid[[measure_name]]) - names(valid_score) = paste0("valid.", measure_name) + .log_score = function(prefix, measure_name, score) { + event_list = list(score, self$ctx$epoch) + names(event_list) = c(paste0(prefix, measure_name), "step") + with_logdir(self$path, { - do.call(log_event, valid_score) + do.call(log_event, event_list) }) }, + .log_valid_score = function(measure_name) { + valid_score = self$ctx$last_scores_valid[[measure_name]] + private$.log_score("valid.", measure_name, valid_score) + }, .log_train_score = function(measure_name) { - train_score = list(self$ctx$last_scores_train[[measure_name]]) - names(train_score) = paste0("train.", measure_name) - with_logdir(self$path, { - do.call(log_event, train_score) - }) + train_score = self$ctx$last_scores_train[[measure_name]] + private$.log_score("train.", measure_name, train_score) }, .log_train_loss = function() { # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index f2c3818f..5f7b56fb 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -124,6 +124,10 @@ test_that("the flag for tracking the train loss works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) + event_tag_is = function(event, tag_name) { + ifelse(is.null(event), FALSE, event["tag"] == tag_name) + } + n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) expect_equal(n_train_loss_events, 0) From 43e7396e08820e3c9c9ee8e5417d8deabab8aa4f Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 18:01:46 +0200 Subject: [PATCH 16/35] removed extraneous comments --- NEWS.md | 1 + R/CallbackSetTB.R | 3 --- tests/testthat/test_CallbackSetTB.R | 2 -- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index b072d959..9bf2927e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # mlr3torch dev +* Added `CallbackSetTB`, which allows logging that can be viewed by TensorBoard. * Don't use deprecated `data_formats` anymore # mlr3torch 0.1.1 diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 55bbe7cb..b50b0a35 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -30,8 +30,6 @@ CallbackSetTB = R6Class("CallbackSetTB", }, #' @description #' Logs the training loss and validation measures as TensorFlow events. - # TODO: display the appropriate x axis with its label in TensorBoard - # relevant when we log different scores at different times on_epoch_end = function() { if (self$log_train_loss) { private$.log_train_loss() @@ -64,7 +62,6 @@ CallbackSetTB = R6Class("CallbackSetTB", private$.log_score("train.", measure_name, train_score) }, .log_train_loss = function() { - # TODO: remind ourselves why we wanted to display last_loss and not last_scores_train with_logdir(self$path, { log_event(train.loss = self$ctx$last_loss) }) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 5f7b56fb..c89b07e2 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -7,7 +7,6 @@ test_that("autotest", { # TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) test_that("a simple example works", { - # using a temp dir cb = t_clbk("tb") task = tsk("iris") @@ -53,7 +52,6 @@ test_that("a simple example works", { }) test_that("eval_freq works", { - # using a temp dir cb = t_clbk("tb") task = tsk("iris") From ec5d8fc8bc2ddcaa5d25742bb4f0d4fff7e9849d Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 18:07:49 +0200 Subject: [PATCH 17/35] added tensorboard instructions --- R/CallbackSetTB.R | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index b50b0a35..8c2a8351 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -3,9 +3,10 @@ #' @name mlr_callback_set.tb #' #' @description -#' Logs training loss and validation measures as events that can be tracked using TensorBoard. +#' Logs training loss, training measures, and validation measures as events. +#' To view them, use TensorBoard with `tensorflow::tensorboard()` (requires `tensorflow`) or the CLI. #' @details -#' Logs at most every epoch. +#' Logs events at most every epoch. #' #' @param path (`character(1)`)\cr #' The path to a folder where the events are logged. @@ -29,7 +30,7 @@ CallbackSetTB = R6Class("CallbackSetTB", self$log_train_loss = assert_logical(log_train_loss) }, #' @description - #' Logs the training loss and validation measures as TensorFlow events. + #' Logs the training loss, training measures, and validation measures as TensorFlow events. on_epoch_end = function() { if (self$log_train_loss) { private$.log_train_loss() From f26a2544065b6bf52d83265d359d2e1c384d47aa Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Fri, 11 Oct 2024 18:40:54 +0200 Subject: [PATCH 18/35] passes R CMD Check, minimally addresses every comment in the previous PR --- man/mlr_callback_set.tb.Rd | 10 ++++------ man/mlr_learners.mlp.Rd | 1 + man/mlr_learners.tab_resnet.Rd | 1 + man/mlr_learners.torch_featureless.Rd | 1 + man/mlr_learners.torchvision.Rd | 1 + man/mlr_learners_torch.Rd | 1 + man/mlr_learners_torch_image.Rd | 1 + man/mlr_learners_torch_model.Rd | 1 + tests/testthat/test_CallbackSetTB.R | 16 ++++------------ 9 files changed, 15 insertions(+), 18 deletions(-) diff --git a/man/mlr_callback_set.tb.Rd b/man/mlr_callback_set.tb.Rd index a5f226a5..ab0fb6ae 100644 --- a/man/mlr_callback_set.tb.Rd +++ b/man/mlr_callback_set.tb.Rd @@ -5,10 +5,11 @@ \alias{CallbackSetTB} \title{TensorBoard Logging Callback} \description{ -Logs training loss and validation measures as events that can be tracked using TensorBoard. +Logs training loss, training measures, and validation measures as events. +To view them, use TensorBoard with \code{tensorflow::tensorboard()} (requires \code{tensorflow}) or the CLI. } \details{ -Logs at most every epoch. +Logs events at most every epoch. } \seealso{ Other Callback: @@ -71,10 +72,7 @@ Whether we log the training loss.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-CallbackSetTB-on_epoch_end}{}}} \subsection{Method \code{on_epoch_end()}}{ -Logs the training loss and validation measures as TensorFlow events. -Meaningful changes happen at the end of each epoch. -Notably NOT on_batch_valid_end, since there are no gradient steps between validation batches, -and therefore differences are due to randomness +Logs the training loss, training measures, and validation measures as TensorFlow events. \subsection{Usage}{ \if{html}{\out{
                }}\preformatted{CallbackSetTB$on_epoch_end()}\if{html}{\out{
                }} } diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd index e6809199..6eb586aa 100644 --- a/man/mlr_learners.mlp.Rd +++ b/man/mlr_learners.mlp.Rd @@ -100,6 +100,7 @@ Other Learner:
                Inherited methods
                • mlr3::Learner$base_learner()
                • +
                • mlr3::Learner$encapsulate()
                • mlr3::Learner$help()
                • mlr3::Learner$predict()
                • mlr3::Learner$predict_newdata()
                • diff --git a/man/mlr_learners.tab_resnet.Rd b/man/mlr_learners.tab_resnet.Rd index b78be86c..1d8b9e8d 100644 --- a/man/mlr_learners.tab_resnet.Rd +++ b/man/mlr_learners.tab_resnet.Rd @@ -102,6 +102,7 @@ Other Learner:
                  Inherited methods
                  • mlr3::Learner$base_learner()
                  • +
                  • mlr3::Learner$encapsulate()
                  • mlr3::Learner$help()
                  • mlr3::Learner$predict()
                  • mlr3::Learner$predict_newdata()
                  • diff --git a/man/mlr_learners.torch_featureless.Rd b/man/mlr_learners.torch_featureless.Rd index 58f316e3..1fb6274a 100644 --- a/man/mlr_learners.torch_featureless.Rd +++ b/man/mlr_learners.torch_featureless.Rd @@ -86,6 +86,7 @@ Other Learner:
                    Inherited methods
                    • mlr3::Learner$base_learner()
                    • +
                    • mlr3::Learner$encapsulate()
                    • mlr3::Learner$help()
                    • mlr3::Learner$predict()
                    • mlr3::Learner$predict_newdata()
                    • diff --git a/man/mlr_learners.torchvision.Rd b/man/mlr_learners.torchvision.Rd index e1f4f9ba..87883dd9 100644 --- a/man/mlr_learners.torchvision.Rd +++ b/man/mlr_learners.torchvision.Rd @@ -42,6 +42,7 @@ number of classes inferred from the \code{\link[mlr3:Task]{Task}}.
                      Inherited methods
                      • mlr3::Learner$base_learner()
                      • +
                      • mlr3::Learner$encapsulate()
                      • mlr3::Learner$help()
                      • mlr3::Learner$predict()
                      • mlr3::Learner$predict_newdata()
                      • diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 748b1b06..5797f246 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -265,6 +265,7 @@ which are varied systematically during tuning (parameter values).}
                        Inherited methods
                        • mlr3::Learner$base_learner()
                        • +
                        • mlr3::Learner$encapsulate()
                        • mlr3::Learner$help()
                        • mlr3::Learner$predict()
                        • mlr3::Learner$predict_newdata()
                        • diff --git a/man/mlr_learners_torch_image.Rd b/man/mlr_learners_torch_image.Rd index 723a4e22..af2b854d 100644 --- a/man/mlr_learners_torch_image.Rd +++ b/man/mlr_learners_torch_image.Rd @@ -36,6 +36,7 @@ Other Learner:
                          Inherited methods
                          • mlr3::Learner$base_learner()
                          • +
                          • mlr3::Learner$encapsulate()
                          • mlr3::Learner$help()
                          • mlr3::Learner$predict()
                          • mlr3::Learner$predict_newdata()
                          • diff --git a/man/mlr_learners_torch_model.Rd b/man/mlr_learners_torch_model.Rd index 505b8fbd..da6fa008 100644 --- a/man/mlr_learners_torch_model.Rd +++ b/man/mlr_learners_torch_model.Rd @@ -92,6 +92,7 @@ The ingress tokens. Must be non-\code{NULL} when calling \verb{$train()}.}
                            Inherited methods
                            • mlr3::Learner$base_learner()
                            • +
                            • mlr3::Learner$encapsulate()
                            • mlr3::Learner$help()
                            • mlr3::Learner$predict()
                            • mlr3::Learner$predict_newdata()
                            • diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index c89b07e2..8a894ec4 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -1,5 +1,9 @@ library(tfevents) +event_tag_is = function(event, tag_name) { + ifelse(is.null(event), FALSE, event["tag"] == tag_name) +} + test_that("autotest", { cb = t_clbk("tb", path = tempfile(), log_train_loss = TRUE) expect_torch_callback(cb, check_man = TRUE) @@ -33,10 +37,6 @@ test_that("a simple example works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) - event_tag_is = function(event, tag_name) { - ifelse(is.null(event), FALSE, event["tag"] == tag_name) - } - n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) n_train_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.acc"))) n_train_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.ce"))) @@ -79,10 +79,6 @@ test_that("eval_freq works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) - event_tag_is = function(event, tag_name) { - ifelse(is.null(event), FALSE, event["tag"] == tag_name) - } - n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) n_train_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.acc"))) n_train_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.ce"))) @@ -122,10 +118,6 @@ test_that("the flag for tracking the train loss works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) - event_tag_is = function(event, tag_name) { - ifelse(is.null(event), FALSE, event["tag"] == tag_name) - } - n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) expect_equal(n_train_loss_events, 0) From a86c9461569716b550076b0518f858c81e311639 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Sun, 13 Oct 2024 21:37:26 +0200 Subject: [PATCH 19/35] moved newest news to bottom --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 9bf2927e..f2d36348 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,7 @@ # mlr3torch dev -* Added `CallbackSetTB`, which allows logging that can be viewed by TensorBoard. * Don't use deprecated `data_formats` anymore +* Added `CallbackSetTB`, which allows logging that can be viewed by TensorBoard. # mlr3torch 0.1.1 From 74757a7b09c37afbef10f590cb1a9214684ef1cd Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:33:43 +0200 Subject: [PATCH 20/35] logical -> flag, since the length of this arg must be 1 Co-authored-by: Sebastian Fischer --- R/CallbackSetTB.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 8c2a8351..902e20aa 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -27,7 +27,7 @@ CallbackSetTB = R6Class("CallbackSetTB", if (!dir.exists(path)) { dir.create(path, recursive = TRUE) } - self$log_train_loss = assert_logical(log_train_loss) + self$log_train_loss = assert_flag(log_train_loss) }, #' @description #' Logs the training loss, training measures, and validation measures as TensorFlow events. From 72d23f42425789541d7ad0008e4e3506060a26ab Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:34:20 +0200 Subject: [PATCH 21/35] "TensorBoard events" appears to be a more idiomatic phrase Co-authored-by: Sebastian Fischer --- R/CallbackSetTB.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 902e20aa..6f68e1d3 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -30,7 +30,7 @@ CallbackSetTB = R6Class("CallbackSetTB", self$log_train_loss = assert_flag(log_train_loss) }, #' @description - #' Logs the training loss, training measures, and validation measures as TensorFlow events. + #' Logs the training loss, training measures, and validation measures as TensorBoard events. on_epoch_end = function() { if (self$log_train_loss) { private$.log_train_loss() From beaca4316a6fe019d38a826acc3d8299d27088ba Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:34:50 +0200 Subject: [PATCH 22/35] map() -> walk(), since we don't use the return value of map() Co-authored-by: Sebastian Fischer --- R/CallbackSetTB.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 6f68e1d3..8eb02e85 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -37,7 +37,7 @@ CallbackSetTB = R6Class("CallbackSetTB", } if (length(self$ctx$last_scores_train)) { - map(names(self$ctx$measures_train), private$.log_train_score) + walk(names(self$ctx$measures_train), private$.log_train_score) } if (length(self$ctx$last_scores_valid)) { From 1bf2939ddbabfb8073ab01307118cee390424aa3 Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:35:13 +0200 Subject: [PATCH 23/35] map() -> walk(), since we don't use the return value of map() Co-authored-by: Sebastian Fischer --- R/CallbackSetTB.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 8eb02e85..ba88475a 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -41,7 +41,7 @@ CallbackSetTB = R6Class("CallbackSetTB", } if (length(self$ctx$last_scores_valid)) { - map(names(self$ctx$measures_valid), private$.log_valid_score) + walk(names(self$ctx$measures_valid), private$.log_valid_score) } } ), From 03aad62b100132ac2fbae628169b5670abf266cd Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:41:02 +0200 Subject: [PATCH 24/35] Apply suggestions from code review setnames() is cleaner, allows us to create a list and set the names in a single line Co-authored-by: Sebastian Fischer --- R/CallbackSetTB.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index ba88475a..1d5468f8 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -47,8 +47,7 @@ CallbackSetTB = R6Class("CallbackSetTB", ), private = list( .log_score = function(prefix, measure_name, score) { - event_list = list(score, self$ctx$epoch) - names(event_list) = c(paste0(prefix, measure_name), "step") + event_list = set_names(list(score, self$ctx$epoch), c(paste0(prefix, measure_name), "step")) with_logdir(self$path, { do.call(log_event, event_list) From e20909245b3677e77e88c72726608da011f97544 Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:41:45 +0200 Subject: [PATCH 25/35] add a default value for the log_train_loss param Co-authored-by: Sebastian Fischer --- R/CallbackSetTB.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 1d5468f8..c10be46e 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -75,7 +75,7 @@ mlr3torch_callbacks$add("tb", function() { callback_generator = CallbackSetTB, param_set = ps( path = p_uty(tags = c("train", "required")), - log_train_loss = p_lgl(tags = c("train", "required")) + log_train_loss = p_lgl(tags = c("train", "required"), init = FALSE) ), id = "tb", label = "TensorBoard", From 694ea85aa7022bd048872876215271a502e7efd8 Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:42:06 +0200 Subject: [PATCH 26/35] add package dependency Co-authored-by: Sebastian Fischer --- R/CallbackSetTB.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index c10be46e..79a9cc9f 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -78,6 +78,7 @@ mlr3torch_callbacks$add("tb", function() { log_train_loss = p_lgl(tags = c("train", "required"), init = FALSE) ), id = "tb", + packages = "tfevents", label = "TensorBoard", man = "mlr3torch::mlr_callback_set.tb" ) From a8f741d9d79065612b3c69ca1bc0e08de2f634a0 Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:42:42 +0200 Subject: [PATCH 27/35] " Now that we have a default no need to specify log_train_loss" Co-authored-by: Sebastian Fischer --- tests/testthat/test_CallbackSetTB.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 8a894ec4..23b29b9c 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -5,7 +5,7 @@ event_tag_is = function(event, tag_name) { } test_that("autotest", { - cb = t_clbk("tb", path = tempfile(), log_train_loss = TRUE) + cb = t_clbk("tb", path = tempfile()) expect_torch_callback(cb, check_man = TRUE) }) From 795b3f155d4f516a9de5b26c4ecf48f78e69c351 Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:43:06 +0200 Subject: [PATCH 28/35] better test description Co-authored-by: Sebastian Fischer --- tests/testthat/test_CallbackSetTB.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 23b29b9c..a02fc0a0 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -10,7 +10,7 @@ test_that("autotest", { }) # TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) -test_that("a simple example works", { +test_that("metrics are logged correctly", { cb = t_clbk("tb") task = tsk("iris") From 6b96aed7ef40fba03a2ac1be045cd621110aff74 Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Tue, 15 Oct 2024 18:43:41 +0200 Subject: [PATCH 29/35] Apply suggestions from code review Make test more efficient Co-authored-by: Sebastian Fischer --- tests/testthat/test_CallbackSetTB.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index a02fc0a0..3361900e 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -15,8 +15,8 @@ test_that("metrics are logged correctly", { task = tsk("iris") n_epochs = 10 - batch_size = 50 - neurons = 200 + batch_size = 150 + neurons = 10 pth0 = tempfile() From 2113faf2adda95738c96467d424309dcd469dbc9 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Tue, 15 Oct 2024 18:48:56 +0200 Subject: [PATCH 30/35] unlist(map -> map_lgl --- tests/testthat/test_CallbackSetTB.R | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 3361900e..ea7ed1c0 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -37,11 +37,11 @@ test_that("metrics are logged correctly", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) - n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) - n_train_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.acc"))) - n_train_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.ce"))) - n_valid_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.acc"))) - n_valid_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.ce"))) + n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) + n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc")) + n_train_ce_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.ce")) + n_valid_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "valid.classif.acc")) + n_valid_ce_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "valid.classif.ce")) # TODO: refactor to expect a specific ordering of the events list, not just the right counts expect_equal(n_train_loss_events, n_epochs) @@ -79,11 +79,11 @@ test_that("eval_freq works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) - n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) - n_train_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.acc"))) - n_train_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.classif.ce"))) - n_valid_acc_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.acc"))) - n_valid_ce_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "valid.classif.ce"))) + n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) + n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc")) + n_train_ce_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.ce")) + n_valid_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "valid.classif.acc")) + n_valid_ce_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "valid.classif.ce")) expect_equal(n_train_loss_events, n_epochs) expect_equal(n_train_acc_events, ceiling(n_epochs / eval_freq)) @@ -118,7 +118,7 @@ test_that("the flag for tracking the train loss works", { events = mlr3misc::map(collect_events(pth0)$summary, unlist) - n_train_loss_events = sum(unlist(mlr3misc::map(events, event_tag_is, tag_name = "train.loss"))) + n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) expect_equal(n_train_loss_events, 0) }) From afbb0742725fe291a12e272d3d8a82bcd4c93748 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Wed, 16 Oct 2024 16:48:01 +0200 Subject: [PATCH 31/35] removed library import in test file, added tfevents:: when we use a function from tfevents --- R/CallbackSetTB.R | 8 ++++---- man/mlr_callback_set.tb.Rd | 2 +- tests/testthat/test_CallbackSetTB.R | 8 +++----- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/R/CallbackSetTB.R b/R/CallbackSetTB.R index 79a9cc9f..ed404009 100644 --- a/R/CallbackSetTB.R +++ b/R/CallbackSetTB.R @@ -49,8 +49,8 @@ CallbackSetTB = R6Class("CallbackSetTB", .log_score = function(prefix, measure_name, score) { event_list = set_names(list(score, self$ctx$epoch), c(paste0(prefix, measure_name), "step")) - with_logdir(self$path, { - do.call(log_event, event_list) + tfevents::with_logdir(self$path, { + do.call(tfevents::log_event, event_list) }) }, .log_valid_score = function(measure_name) { @@ -62,8 +62,8 @@ CallbackSetTB = R6Class("CallbackSetTB", private$.log_score("train.", measure_name, train_score) }, .log_train_loss = function() { - with_logdir(self$path, { - log_event(train.loss = self$ctx$last_loss) + tfevents::with_logdir(self$path, { + tfevents::log_event(train.loss = self$ctx$last_loss) }) } ) diff --git a/man/mlr_callback_set.tb.Rd b/man/mlr_callback_set.tb.Rd index ab0fb6ae..63ca837d 100644 --- a/man/mlr_callback_set.tb.Rd +++ b/man/mlr_callback_set.tb.Rd @@ -72,7 +72,7 @@ Whether we log the training loss.} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-CallbackSetTB-on_epoch_end}{}}} \subsection{Method \code{on_epoch_end()}}{ -Logs the training loss, training measures, and validation measures as TensorFlow events. +Logs the training loss, training measures, and validation measures as TensorBoard events. \subsection{Usage}{ \if{html}{\out{
                              }}\preformatted{CallbackSetTB$on_epoch_end()}\if{html}{\out{
                              }} } diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index ea7ed1c0..8be5c2e4 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -1,5 +1,3 @@ -library(tfevents) - event_tag_is = function(event, tag_name) { ifelse(is.null(event), FALSE, event["tag"] == tag_name) } @@ -35,7 +33,7 @@ test_that("metrics are logged correctly", { mlp$train(task) - events = mlr3misc::map(collect_events(pth0)$summary, unlist) + events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist) n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc")) @@ -77,7 +75,7 @@ test_that("eval_freq works", { mlp$train(task) - events = mlr3misc::map(collect_events(pth0)$summary, unlist) + events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist) n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc")) @@ -116,7 +114,7 @@ test_that("the flag for tracking the train loss works", { mlp$train(task) - events = mlr3misc::map(collect_events(pth0)$summary, unlist) + events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist) n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) From 074f73b8c246600ae5522fc4a602a0395fee80df Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Wed, 16 Oct 2024 16:49:52 +0200 Subject: [PATCH 32/35] removed extra TODO, don't bind to variables when the value is only used once --- tests/testthat/test_CallbackSetTB.R | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 8be5c2e4..66e01ca2 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -16,10 +16,6 @@ test_that("metrics are logged correctly", { batch_size = 150 neurons = 10 - pth0 = tempfile() - - log_train_loss = TRUE - mlp = lrn("classif.mlp", callbacks = cb, epochs = n_epochs, batch_size = batch_size, neurons = neurons, @@ -27,9 +23,8 @@ test_that("metrics are logged correctly", { measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")) ) - mlp$param_set$set_values(cb.tb.path = pth0) - - mlp$param_set$set_values(cb.tb.log_train_loss = log_train_loss) + mlp$param_set$set_values(cb.tb.path = tempfile()) + mlp$param_set$set_values(cb.tb.log_train_loss = TRUE) mlp$train(task) @@ -41,7 +36,6 @@ test_that("metrics are logged correctly", { n_valid_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "valid.classif.acc")) n_valid_ce_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "valid.classif.ce")) - # TODO: refactor to expect a specific ordering of the events list, not just the right counts expect_equal(n_train_loss_events, n_epochs) expect_equal(n_train_acc_events, n_epochs) expect_equal(n_train_ce_events, n_epochs) From 5f7d56f26e4feceec0ba2460fc7d60901ce16f35 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Wed, 16 Oct 2024 16:53:03 +0200 Subject: [PATCH 33/35] remove more unnecessary vars, increase batch size to make tests run faster --- tests/testthat/test_CallbackSetTB.R | 42 +++++++---------------------- 1 file changed, 10 insertions(+), 32 deletions(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 66e01ca2..4f80f3b9 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -12,13 +12,10 @@ test_that("metrics are logged correctly", { cb = t_clbk("tb") task = tsk("iris") - n_epochs = 10 - batch_size = 150 - neurons = 10 mlp = lrn("classif.mlp", callbacks = cb, - epochs = n_epochs, batch_size = batch_size, neurons = neurons, + epochs = 10, batch_size = 150, neurons = 10, validate = 0.2, measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")) @@ -44,28 +41,18 @@ test_that("metrics are logged correctly", { }) test_that("eval_freq works", { - cb = t_clbk("tb") - task = tsk("iris") - n_epochs = 9 - batch_size = 50 - neurons = 200 - eval_freq = 4 - - pth0 = tempfile() - - log_train_loss = TRUE mlp = lrn("classif.mlp", - callbacks = cb, - epochs = n_epochs, batch_size = batch_size, neurons = neurons, + callbacks = t_clbk("tb"), + epochs = 9, batch_size = 150, neurons = 200, validate = 0.2, measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")), - eval_freq = eval_freq + eval_freq = 4 ) - mlp$param_set$set_values(cb.tb.path = pth0) - mlp$param_set$set_values(cb.tb.log_train_loss = log_train_loss) + mlp$param_set$set_values(cb.tb.path = tempfile()) + mlp$param_set$set_values(cb.tb.log_train_loss = TRUE) mlp$train(task) @@ -85,26 +72,17 @@ test_that("eval_freq works", { }) test_that("the flag for tracking the train loss works", { - cb = t_clbk("tb") - task = tsk("iris") - n_epochs = 10 - batch_size = 50 - neurons = 200 - - log_train_loss = FALSE - - pth0 = tempfile() mlp = lrn("classif.mlp", - callbacks = cb, - epochs = n_epochs, batch_size = batch_size, neurons = neurons, + callbacks = t_clbk("tb"), + epochs = 10, batch_size = 150, neurons = 200, validate = 0.2, measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")) ) - mlp$param_set$set_values(cb.tb.path = pth0) - mlp$param_set$set_values(cb.tb.log_train_loss = log_train_loss) + mlp$param_set$set_values(cb.tb.path = tempfile() + mlp$param_set$set_values(cb.tb.log_train_loss = FALSE) mlp$train(task) From a1ddb7d2ab0b16b65d15ba6e6d22856900f5ae23 Mon Sep 17 00:00:00 2001 From: Carson Zhang Date: Wed, 16 Oct 2024 17:08:07 +0200 Subject: [PATCH 34/35] access path field of callback, ifelse -> if --- tests/testthat/test_CallbackSetTB.R | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 4f80f3b9..685a80bf 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -1,5 +1,5 @@ event_tag_is = function(event, tag_name) { - ifelse(is.null(event), FALSE, event["tag"] == tag_name) + if (is.null(event)) FALSE else event["tag"] == tag_name } test_that("autotest", { @@ -13,9 +13,11 @@ test_that("metrics are logged correctly", { task = tsk("iris") + n_epochs = 10 + mlp = lrn("classif.mlp", callbacks = cb, - epochs = 10, batch_size = 150, neurons = 10, + epochs = n_epochs, batch_size = 150, neurons = 10, validate = 0.2, measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")) @@ -25,7 +27,7 @@ test_that("metrics are logged correctly", { mlp$train(task) - events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist) + events = mlr3misc::map(tfevents::collect_events(mlp$param_set$get_values()$cb.tb.path)$summary, unlist) n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc")) @@ -43,20 +45,23 @@ test_that("metrics are logged correctly", { test_that("eval_freq works", { task = tsk("iris") + n_epochs = 9 + eval_freq = 4 + mlp = lrn("classif.mlp", callbacks = t_clbk("tb"), - epochs = 9, batch_size = 150, neurons = 200, + epochs = n_epochs, batch_size = 150, neurons = 200, validate = 0.2, measures_valid = msrs(c("classif.acc", "classif.ce")), measures_train = msrs(c("classif.acc", "classif.ce")), - eval_freq = 4 + eval_freq = eval_freq ) mlp$param_set$set_values(cb.tb.path = tempfile()) mlp$param_set$set_values(cb.tb.log_train_loss = TRUE) mlp$train(task) - events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist) + events = mlr3misc::map(tfevents::collect_events(mlp$param_set$get_values()$cb.tb.path)$summary, unlist) n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) n_train_acc_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.classif.acc")) @@ -71,22 +76,19 @@ test_that("eval_freq works", { expect_equal(n_valid_ce_events, ceiling(n_epochs / eval_freq)) }) -test_that("the flag for tracking the train loss works", { +test_that("we can disable training loss tracking", { task = tsk("iris") mlp = lrn("classif.mlp", callbacks = t_clbk("tb"), - epochs = 10, batch_size = 150, neurons = 200, - validate = 0.2, - measures_valid = msrs(c("classif.acc", "classif.ce")), - measures_train = msrs(c("classif.acc", "classif.ce")) + epochs = 10, batch_size = 150, neurons = 200 ) - mlp$param_set$set_values(cb.tb.path = tempfile() + mlp$param_set$set_values(cb.tb.path = tempfile()) mlp$param_set$set_values(cb.tb.log_train_loss = FALSE) mlp$train(task) - events = mlr3misc::map(tfevents::collect_events(pth0)$summary, unlist) + events = mlr3misc::map(tfevents::collect_events(mlp$param_set$get_values()$cb.tb.path)$summary, unlist) n_train_loss_events = sum(mlr3misc::map_lgl(events, event_tag_is, tag_name = "train.loss")) From 39576149b5765c40252e5a40e408f9fc24588e1d Mon Sep 17 00:00:00 2001 From: cxzhang4 Date: Fri, 18 Oct 2024 11:09:55 +0200 Subject: [PATCH 35/35] remove old todo --- tests/testthat/test_CallbackSetTB.R | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/testthat/test_CallbackSetTB.R b/tests/testthat/test_CallbackSetTB.R index 685a80bf..2faa457c 100644 --- a/tests/testthat/test_CallbackSetTB.R +++ b/tests/testthat/test_CallbackSetTB.R @@ -7,7 +7,6 @@ test_that("autotest", { expect_torch_callback(cb, check_man = TRUE) }) -# TODO: investigate what's happening when there is only a single epoch (why don't we log anything?) test_that("metrics are logged correctly", { cb = t_clbk("tb")