
Feat/tflog callback #290

Merged: 37 commits on Oct 18, 2024
Changes from 20 commits
Commits
2380913
TODO: write tests
cxzhang4 Sep 10, 2024
86f87c8
name -> TB. began refactoring based on last meeting with Sebastian
cxzhang4 Sep 22, 2024
400ed74
slight description change
cxzhang4 Oct 2, 2024
9e6acd8
removed extraneous comments
cxzhang4 Oct 2, 2024
fc4f2fa
added n_last_loss frequency test
cxzhang4 Oct 2, 2024
81d1ded
in progress
cxzhang4 Oct 10, 2024
cb03eb3
autotest working, accidentally used the wrong callback_generator
cxzhang4 Oct 10, 2024
78b95a5
simple and eval_freq tests pass
cxzhang4 Oct 11, 2024
a365757
changed logging methods to private
cxzhang4 Oct 11, 2024
43a8ffb
removed magrittr pipe from tests
cxzhang4 Oct 11, 2024
6b9a845
added details for callback class
cxzhang4 Oct 11, 2024
d354b2c
formatting
cxzhang4 Oct 11, 2024
b5b27b1
built docs
cxzhang4 Oct 11, 2024
565456b
Merge branch 'main' into feat/tflog-callback
cxzhang4 Oct 11, 2024
7c9f431
all tests pass, I think this is parity with the previous broken commi…
cxzhang4 Oct 11, 2024
c6c9333
implemented step logging
cxzhang4 Oct 11, 2024
43e7396
removed extraneous comments
cxzhang4 Oct 11, 2024
ec5d8fc
added tensorboard instructions
cxzhang4 Oct 11, 2024
f26a254
passes R CMD Check, minimally addresses every comment in the previous PR
cxzhang4 Oct 11, 2024
a86c946
moved newest news to bottom
cxzhang4 Oct 13, 2024
74757a7
logical -> flag, since the length of this arg must be 1
cxzhang4 Oct 15, 2024
72d23f4
"TensorBoard events" appears to be a more idiomatic phrase
cxzhang4 Oct 15, 2024
beaca43
map() -> walk(), since we don't use the return value of map()
cxzhang4 Oct 15, 2024
1bf2939
map() -> walk(), since we don't use the return value of map()
cxzhang4 Oct 15, 2024
03aad62
Apply suggestions from code review
cxzhang4 Oct 15, 2024
e209092
add a default value for the log_train_loss param
cxzhang4 Oct 15, 2024
694ea85
add package dependency
cxzhang4 Oct 15, 2024
a8f741d
" Now that we have a default no need to specify log_train_loss"
cxzhang4 Oct 15, 2024
795b3f1
better test description
cxzhang4 Oct 15, 2024
6b96aed
Apply suggestions from code review
cxzhang4 Oct 15, 2024
2113faf
unlist(map -> map_lgl
cxzhang4 Oct 15, 2024
afbb074
removed library import in test file, added tfevents:: when we use a f…
cxzhang4 Oct 16, 2024
074f73b
removed extra TODO, don't bind to variables when the value is only us…
cxzhang4 Oct 16, 2024
5f7d56f
remove more unnecessary vars, increase batch size to make tests run f…
cxzhang4 Oct 16, 2024
a1ddb7d
access path field of callback, ifelse -> if
cxzhang4 Oct 16, 2024
3957614
remove old todo
cxzhang4 Oct 18, 2024
c95d241
Merge branch 'main' into feat/tflog-callback
sebffischer Oct 18, 2024
9 changes: 8 additions & 1 deletion DESCRIPTION
@@ -25,7 +25,12 @@ Authors@R:
family = "Pfisterer",
role = "ctb",
email = "[email protected]",
comment = c(ORCID = "0000-0001-8867-762X")))
comment = c(ORCID = "0000-0001-8867-762X")),
person(given = "Carson",
family = "Zhang",
role = "ctb",
email = "[email protected]")
)
Description: Deep Learning library that extends the mlr3 framework by building
upon the 'torch' package. It allows to conveniently build, train,
and evaluate deep learning models without having to worry about low level
@@ -64,6 +69,7 @@ Suggests:
viridis,
visNetwork,
testthat (>= 3.0.0),
tfevents,
torchvision (>= 0.6.0),
waldo
Config/testthat/edition: 3
@@ -80,6 +86,7 @@ Collate:
'CallbackSetEarlyStopping.R'
'CallbackSetHistory.R'
'CallbackSetProgress.R'
'CallbackSetTB.R'
'ContextTorch.R'
'DataBackendLazy.R'
'utils.R'
1 change: 1 addition & 0 deletions NAMESPACE
@@ -63,6 +63,7 @@ export(CallbackSet)
export(CallbackSetCheckpoint)
export(CallbackSetHistory)
export(CallbackSetProgress)
export(CallbackSetTB)
export(ContextTorch)
export(DataBackendLazy)
export(DataDescriptor)
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
# mlr3torch dev

* Don't use deprecated `data_formats` anymore
* Added `CallbackSetTB`, which enables logging that can be viewed with TensorBoard.

# mlr3torch 0.1.1

85 changes: 85 additions & 0 deletions R/CallbackSetTB.R
@@ -0,0 +1,85 @@
#' @title TensorBoard Logging Callback
#'
#' @name mlr_callback_set.tb
#'
#' @description
#' Logs training loss, training measures, and validation measures as events.
#' To view them, use TensorBoard with `tensorflow::tensorboard()` (requires `tensorflow`) or the CLI.
#' @details
#' Logs events at most once per epoch.
#'
#' @param path (`character(1)`)\cr
#' The path to a folder where the events are logged.
#' Point TensorBoard to this folder to view them.
#' @param log_train_loss (`logical(1)`)\cr
#' Whether to log the training loss.
#' @family Callback
#' @export
#' @include CallbackSet.R
CallbackSetTB = R6Class("CallbackSetTB",
inherit = CallbackSet,
lock_objects = FALSE,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(path, log_train_loss) {
self$path = assert_path_for_output(path)
if (!dir.exists(path)) {
dir.create(path, recursive = TRUE)
}
self$log_train_loss = assert_flag(log_train_loss)
},
#' @description
#' Logs the training loss, training measures, and validation measures as TensorFlow events.
on_epoch_end = function() {
if (self$log_train_loss) {
private$.log_train_loss()
}

if (length(self$ctx$last_scores_train)) {
walk(names(self$ctx$measures_train), private$.log_train_score)
}

if (length(self$ctx$last_scores_valid)) {
walk(names(self$ctx$measures_valid), private$.log_valid_score)
}
}
),
private = list(
.log_score = function(prefix, measure_name, score) {
event_list = list(score, self$ctx$epoch)
names(event_list) = c(paste0(prefix, measure_name), "step")

with_logdir(self$path, {
do.call(log_event, event_list)
})
},
.log_valid_score = function(measure_name) {
valid_score = self$ctx$last_scores_valid[[measure_name]]
private$.log_score("valid.", measure_name, valid_score)
},
.log_train_score = function(measure_name) {
train_score = self$ctx$last_scores_train[[measure_name]]
private$.log_score("train.", measure_name, train_score)
},
.log_train_loss = function() {
with_logdir(self$path, {
log_event(train.loss = self$ctx$last_loss)
})
}
)
)

#' @include TorchCallback.R
mlr3torch_callbacks$add("tb", function() {
TorchCallback$new(
callback_generator = CallbackSetTB,
param_set = ps(
path = p_uty(tags = c("train", "required")),
log_train_loss = p_lgl(tags = c("train", "required"))
),
id = "tb",
label = "TensorBoard",
man = "mlr3torch::mlr_callback_set.tb"
)
})
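The callback above can be attached to a torch learner via its id `"tb"`. A minimal usage sketch follows; the learner id (`classif.mlp`), the task, and the `cb.tb.*` parameter prefix are assumptions based on mlr3torch conventions, not part of this diff.

```r
library(mlr3)
library(mlr3torch)

# Hypothetical example: train an MLP with the TensorBoard callback attached.
# Callback parameters are assumed to be exposed with the "cb.tb." prefix.
logdir = tempfile("tb_logs")
learner = lrn("classif.mlp",
  epochs = 10, batch_size = 32,
  validate = 0.2,
  measures_valid = msr("classif.acc"),
  callbacks = t_clbk("tb"),
  cb.tb.path = logdir,
  cb.tb.log_train_loss = TRUE
)
learner$train(tsk("iris"))

# Then point TensorBoard at the log directory, e.g. from a shell:
#   tensorboard --logdir <logdir>
```

One event file per run is written under `path`, so using a fresh subdirectory per training run keeps runs separated in the TensorBoard UI.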
1 change: 1 addition & 0 deletions man/TorchCallback.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/as_torch_callback.Rd
1 change: 1 addition & 0 deletions man/as_torch_callbacks.Rd
1 change: 1 addition & 0 deletions man/callback_set.Rd
1 change: 1 addition & 0 deletions man/mlr3torch-package.Rd
1 change: 1 addition & 0 deletions man/mlr3torch_callbacks.Rd
1 change: 1 addition & 0 deletions man/mlr_callback_set.Rd
1 change: 1 addition & 0 deletions man/mlr_callback_set.checkpoint.Rd
1 change: 1 addition & 0 deletions man/mlr_callback_set.progress.Rd
98 changes: 98 additions & 0 deletions man/mlr_callback_set.tb.Rd
1 change: 1 addition & 0 deletions man/mlr_context_torch.Rd
1 change: 1 addition & 0 deletions man/t_clbk.Rd
1 change: 1 addition & 0 deletions man/torch_callback.Rd