From daa03bb5e8e51403d6b915c3a65d99ba714ecc3d Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Wed, 6 Nov 2024 16:47:01 +0100 Subject: [PATCH] feat: add new callback stage to result method (#263) * feat: add new callback stage to result method * ... * ... * ... * ... --- NEWS.md | 4 + R/CallbackAsync.R | 105 ++++++++++++------ R/CallbackBatch.R | 80 ++++++++++---- R/ContextAsync.R | 80 ++++++++++---- R/ContextBatch.R | 48 ++++++-- R/OptimInstance.R | 5 +- R/OptimInstanceAsync.R | 10 +- R/OptimInstanceAsyncMultiCrit.R | 33 ++++-- R/OptimInstanceAsyncSingleCrit.R | 33 ++++-- R/OptimInstanceBatch.R | 1 + R/OptimInstanceBatchMultiCrit.R | 33 ++++-- R/OptimInstanceBatchSingleCrit.R | 33 ++++-- R/OptimizerAsync.R | 2 +- man/CallbackAsync.Rd | 8 +- man/CallbackBatch.Rd | 8 +- man/ContextAsync.Rd | 31 ++++-- man/ContextBatch.Rd | 19 +++- man/callback_async.Rd | 30 ++++- man/callback_batch.Rd | 25 ++++- tests/testthat/test_CallbackAsync.R | 160 +++++++++++++++++++++++++-- tests/testthat/test_CallbackBatch.R | 166 +++++++++++++++++++++++++--- 21 files changed, 734 insertions(+), 180 deletions(-) diff --git a/NEWS.md b/NEWS.md index 5aad46a9..f32c412e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # bbotk (development version) +* feat: Add new stage `on_result_begin` to `CallbackAsyncTuning` and `CallbackBatchTuning`. +* refactor: Rename stage `on_result` to `on_result_end` in `CallbackAsyncTuning` and `CallbackBatchTuning`. +* docs: Extend the `CallbackAsyncTuning` and `CallbackBatchTuning` documentation. + # bbotk 1.2.0 * feat: `ContextBatch` and `ContextAsync` have a `result_extra` field now to access additional results passed to the instance. diff --git a/R/CallbackAsync.R b/R/CallbackAsync.R index efa04570..33d4e7c0 100644 --- a/R/CallbackAsync.R +++ b/R/CallbackAsync.R @@ -37,10 +37,15 @@ CallbackAsync = R6Class("CallbackAsync", #' Called in the worker loop. on_worker_end = NULL, - #' @field on_result (`function()`)\cr - #' Stage called after result are written. + #' @field on_result_begin (`function()`)\cr + #' Stage called before the results are written. #' Called in `OptimInstance$assign_result()`. - on_result = NULL, + on_result_begin = NULL, + + #' @field on_result_end (`function()`)\cr + #' Stage called after the results are written. + #' Called in `OptimInstance$assign_result()`. + on_result_end = NULL, #' @field on_optimization_end (`function()`)\cr #' Stage called at the end of the optimization in the main process. @@ -68,7 +73,8 @@ CallbackAsync = R6Class("CallbackAsync", #' End Optimization on Worker #' - on_worker_end #' End Worker -#' - on_result +#' - on_result_begin +#' - on_result_end #' - on_optimization_end #' End Optimization #' ``` @@ -81,40 +87,56 @@ CallbackAsync = R6Class("CallbackAsync", #' The [ContextAsync] allows to modify the instance, archive, optimizer and final result. #' #' @param id (`character(1)`)\cr -#' Identifier for the new instance. +#' Identifier for the new instance. #' @param label (`character(1)`)\cr -#' Label for the new instance. +#' Label for the new instance. #' @param man (`character(1)`)\cr -#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. -#' The referenced help package can be opened via method `$help()`. +#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. +#' The referenced help package can be opened via method `$help()`. +#' #' @param on_optimization_begin (`function()`)\cr -#' Stage called at the beginning of the optimization in the main process. -#' Called in `Optimizer$optimize()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called at the beginning of the optimization in the main process. +#' Called in `Optimizer$optimize()`. +#' The functions must have two arguments named `callback` and `context`. #' @param on_worker_begin (`function()`)\cr -#' Stage called at the beginning of the optimization on the worker. -#' Called in the worker loop. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called at the beginning of the optimization on the worker. +#' Called in the worker loop. +#' The functions must have two arguments named `callback` and `context`. #' @param on_optimizer_before_eval (`function()`)\cr -#' Stage called after the optimizer proposes points. -#' Called in `OptimInstance$eval_point()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called after the optimizer proposes points. +#' Called in `OptimInstance$.eval_point()`. +#' The functions must have two arguments named `callback` and `context`. +#' The argument of `instance$.eval_point(xs)` and `xs_trafoed` and `extra` are available in the `context`. +#' Or `xs` and `xs_trafoed` of `instance$.eval_queue()` are available in the `context`. #' @param on_optimizer_after_eval (`function()`)\cr -#' Stage called after points are evaluated. -#' Called in `OptimInstance$eval_point()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called after points are evaluated. +#' Called in `OptimInstance$.eval_point()`. +#' The functions must have two arguments named `callback` and `context`. +#' The outcome `y` is available in the `context`. #' @param on_worker_end (`function()`)\cr -#' Stage called at the end of the optimization on the worker. -#' Called in the worker loop. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called at the end of the optimization on the worker. +#' Called in the worker loop. +#' The functions must have two arguments named `callback` and `context`. +#' @param on_result_begin (`function()`)\cr +#' Stage called before result are written. +#' Called in `OptimInstance$assign_result()`. +#' The functions must have two arguments named `callback` and `context`. +#' The arguments of `$.assign_result(xdt, y, extra)` are available in the `context`. +#' @param on_result_end (`function()`)\cr +#' Stage called after result are written. +#' Called in `OptimInstance$assign_result()`. +#' The functions must have two arguments named `callback` and `context`. +#' The final result `instance$result` is available in the `context`. #' @param on_result (`function()`)\cr -#' Stage called after result are written. -#' Called in `OptimInstance$assign_result()`. -#' The functions must have two arguments named `callback` and `context`. +#' Deprecated. +#' Use `on_result_end` instead. +#' Stage called after result are written. +#' Called in `OptimInstance$assign_result()`. +#' The functions must have two arguments named `callback` and `context`. #' @param on_optimization_end (`function()`)\cr -#' Stage called at the end of the optimization in the main process. -#' Called in `Optimizer$optimize()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called at the end of the optimization in the main process. +#' Called in `Optimizer$optimize()`. +#' The functions must have two arguments named `callback` and `context`. #' #' @export callback_async = function( @@ -126,6 +148,8 @@ callback_async = function( on_optimizer_before_eval = NULL, on_optimizer_after_eval = NULL, on_worker_end = NULL, + on_result_begin = NULL, + on_result_end = NULL, on_result = NULL, on_optimization_end = NULL ) { @@ -135,15 +159,26 @@ callback_async = function( on_optimizer_before_eval, on_optimizer_after_eval, on_worker_end, + on_result_begin, + on_result_end, on_result, on_optimization_end), c("on_optimization_begin", - "on_worker_begin", - "on_optimizer_before_eval", - "on_optimizer_after_eval", - "on_worker_end", - "on_result", - "on_optimization_end")), is.null) + "on_worker_begin", + "on_optimizer_before_eval", + "on_optimizer_after_eval", + "on_worker_end", + "on_result_begin", + "on_result_end", + "on_result", + "on_optimization_end")), is.null) + + if ("on_result" %in% names(stages)) { + .Deprecated(old = "on_result", new = "on_result_end") + stages$on_result_end = stages$on_result + stages$on_result = NULL + } + walk(stages, function(stage) assert_function(stage, args = c("callback", "context"))) callback = CallbackAsync$new(id, label, man) iwalk(stages, function(stage, name) callback[[name]] = stage) diff --git a/R/CallbackBatch.R b/R/CallbackBatch.R index 9d9c9394..e62a13dc 100644 --- a/R/CallbackBatch.R +++ b/R/CallbackBatch.R @@ -34,10 +34,15 @@ CallbackBatch = R6Class("CallbackBatch", #' Called in `OptimInstance$eval_batch()`. on_optimizer_after_eval = NULL, - #' @field on_result (`function()`)\cr - #' Stage called after result are written. + #' @field on_result_begin (`function()`)\cr + #' Stage called before the results are written. #' Called in `OptimInstance$assign_result()`. - on_result = NULL, + on_result_begin = NULL, + + #' @field on_result_end (`function()`)\cr + #' Stage called after the results are written. + #' Called in `OptimInstance$assign_result()`. + on_result_end = NULL, #' @field on_optimization_end (`function()`)\cr #' Stage called at the end of the optimization. @@ -61,7 +66,8 @@ CallbackBatch = R6Class("CallbackBatch", #' - on_optimizer_before_eval #' - on_optimizer_after_eval #' End Optimizer Batch -#' - on_result +#' - on_result_begin +#' - on_result_end #' - on_optimization_end #' End Optimization #' ``` @@ -75,32 +81,47 @@ CallbackBatch = R6Class("CallbackBatch", #' #' #' @param id (`character(1)`)\cr -#' Identifier for the new instance. +#' Identifier for the new instance. #' @param label (`character(1)`)\cr -#' Label for the new instance. +#' Label for the new instance. #' @param man (`character(1)`)\cr -#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. -#' The referenced help package can be opened via method `$help()`. +#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. +#' The referenced help package can be opened via method `$help()`. +#' #' @param on_optimization_begin (`function()`)\cr -#' Stage called at the beginning of the optimization. -#' Called in `Optimizer$optimize()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called at the beginning of the optimization. +#' Called in `Optimizer$optimize()`. +#' The functions must have two arguments named `callback` and `context`. #' @param on_optimizer_before_eval (`function()`)\cr -#' Stage called after the optimizer proposes points. -#' Called in `OptimInstance$eval_batch()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called after the optimizer proposes points. +#' Called in `OptimInstance$eval_batch()`. +#' The functions must have two arguments named `callback` and `context`. +#' The argument of `$eval_batch(xdt)` is available in `context`. #' @param on_optimizer_after_eval (`function()`)\cr -#' Stage called after points are evaluated. -#' Called in `OptimInstance$eval_batch()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called after points are evaluated. +#' Called in `OptimInstance$eval_batch()`. +#' The functions must have two arguments named `callback` and `context`. +#' The new points and outcomes in `instance$archive` are available in `context`. +#' @param on_result_begin (`function()`)\cr +#' Stage called before result are written to the instance. +#' Called in `OptimInstance$assign_result()`. +#' The functions must have two arguments named `callback` and `context`. +#' The arguments of `$assign_result(xdt, y, extra)` are available in `context`. +#' @param on_result_end (`function()`)\cr +#' Stage called after result are written to the instance. +#' Called in `OptimInstance$assign_result()`. +#' The functions must have two arguments named `callback` and `context`. +#' The final result `instance$result` is available in `context`. #' @param on_result (`function()`)\cr -#' Stage called after result are written. -#' Called in `OptimInstance$assign_result()`. -#' The functions must have two arguments named `callback` and `context`. +#' Deprecated. +#' Use `on_result_end` instead. +#' Stage called after result are written. +#' Called in `OptimInstance$assign_result()`. +#' The functions must have two arguments named `callback` and `context`. #' @param on_optimization_end (`function()`)\cr -#' Stage called at the end of the optimization. -#' Called in `Optimizer$optimize()`. -#' The functions must have two arguments named `callback` and `context`. +#' Stage called at the end of the optimization. +#' Called in `Optimizer$optimize()`. +#' The functions must have two arguments named `callback` and `context`. #' #' @export #' @inherit CallbackBatch examples @@ -111,6 +132,8 @@ callback_batch = function( on_optimization_begin = NULL, on_optimizer_before_eval = NULL, on_optimizer_after_eval = NULL, + on_result_begin = NULL, + on_result_end = NULL, on_result = NULL, on_optimization_end = NULL ) { @@ -118,14 +141,25 @@ callback_batch = function( on_optimization_begin, on_optimizer_before_eval, on_optimizer_after_eval, + on_result_begin, + on_result_end, on_result, on_optimization_end), c( "on_optimization_begin", "on_optimizer_before_eval", "on_optimizer_after_eval", + "on_result_begin", + "on_result_end", "on_result", "on_optimization_end")), is.null) + + if ("on_result" %in% names(stages)) { + .Deprecated(old = "on_result", new = "on_result_end") + stages$on_result_end = stages$on_result + stages$on_result = NULL + } + walk(stages, function(stage) assert_function(stage, args = c("callback", "context"))) callback = CallbackBatch$new(id, label, man) iwalk(stages, function(stage, name) callback[[name]] = stage) diff --git a/R/ContextAsync.R b/R/ContextAsync.R index 47d455a0..1330b061 100644 --- a/R/ContextAsync.R +++ b/R/ContextAsync.R @@ -32,28 +32,8 @@ ContextAsync = R6Class("ContextAsync", active = list( - #' @field result ([data.table::data.table])\cr - #' The result of the optimization. - result = function(rhs) { - if (missing(rhs)) { - get_private(self$instance)$.result - } else { - get_private(self$instance, ".result") = rhs - } - }, - - #' @field result_extra ([data.table::data.table])\cr - #' Additional information about the result. - result_extra = function(rhs) { - if (missing(rhs)) { - get_private(self$instance)$.result_extra - } else { - get_private(self$instance, ".result_extra") = rhs - } - }, - #' @field xs (list())\cr - #' The point to be evaluated. + #' The point to be evaluated in `instance$.eval_point()`. xs = function(rhs) { if (missing(rhs)) { get_private(self$instance)$.xs @@ -63,7 +43,7 @@ ContextAsync = R6Class("ContextAsync", }, #' @field xs_trafoed (list())\cr - #' The transformed point to be evaluated. + #' The transformed point to be evaluated in `instance$.eval_point()`. xs_trafoed = function(rhs) { if (missing(rhs)) { get_private(self$instance)$.xs_trafoed @@ -73,7 +53,7 @@ ContextAsync = R6Class("ContextAsync", }, #' @field extra (list())\cr - #' Additional information. + #' Additional information of the point to be evaluated in `instance$.eval_point()`. extra = function(rhs) { if (missing(rhs)) { get_private(self$instance)$.extra @@ -83,13 +63,65 @@ ContextAsync = R6Class("ContextAsync", }, #' @field ys (list())\cr - #' The result of the evaluation. + #' The result of the evaluation in `instance$.eval_point()`. ys = function(rhs) { if (missing(rhs)) { get_private(self$instance)$.ys } else { get_private(self$instance, ".ys") = rhs } + }, + + #' @field result_xdt ([data.table::data.table])\cr + #' The xdt passed to `instance$assign_result()`. + result_xdt = function(rhs) { + if (missing(rhs)) { + return(get_private(self$instance)$.result_xdt) + } else { + self$instance$.__enclos_env__$private$.result_xdt = rhs + } + }, + + #' @field result_y (`numeric(1)`)\cr + #' The y passed to `instance$assign_result()`. + #' Only available for single criterion optimization. + result_y = function(rhs) { + if (missing(rhs)) { + return(get_private(self$instance)$.result_y) + } else { + self$instance$.__enclos_env__$private$.result_y = rhs + } + }, + + #' @field result_ydt ([data.table::data.table])\cr + #' The ydt passed to `instance$assign_result()`. + #' Only available for multi criterion optimization. + result_ydt = function(rhs) { + if (missing(rhs)) { + return(get_private(self$instance)$.result_ydt) + } else { + self$instance$.__enclos_env__$private$.result_ydt = rhs + } + }, + + #' @field result_extra ([data.table::data.table])\cr + #' Additional information about the result passed to `instance$assign_result()`. + result_extra = function(rhs) { + if (missing(rhs)) { + get_private(self$instance)$.result_extra + } else { + get_private(self$instance, ".result_extra") = rhs + } + }, + + #' @field result ([data.table::data.table])\cr + #' The result of the optimization in `instance$assign_result()`. + result = function(rhs) { + if (missing(rhs)) { + get_private(self$instance)$.result + } else { + get_private(self$instance, ".result") = rhs + } } ) ) diff --git a/R/ContextBatch.R b/R/ContextBatch.R index f262b9f9..97096d5a 100644 --- a/R/ContextBatch.R +++ b/R/ContextBatch.R @@ -30,8 +30,8 @@ ContextBatch = R6Class("ContextBatch", active = list( #' @field xdt ([data.table::data.table])\cr - #' The points of the latest batch. - #' Contains the values in the search space i.e. transformations are not yet applied. + #' The points of the latest batch in `instance$eval_batch()`. + #' Contains the values in the search space i.e. transformations are not yet applied. xdt = function(rhs) { if (missing(rhs)) { get_private(self$instance)$.xdt @@ -40,24 +40,56 @@ ContextBatch = R6Class("ContextBatch", } }, - #' @field result ([data.table::data.table])\cr - #' The result of the optimization. - result = function(rhs) { + #' @field result_xdt ([data.table::data.table])\cr + #' The xdt passed to `instance$assign_result()`. + result_xdt = function(rhs) { if (missing(rhs)) { - get_private(self$instance)$.result + return(get_private(self$instance)$.result_xdt) } else { - get_private(self$instance, ".result") = rhs + self$instance$.__enclos_env__$private$.result_xdt = rhs + } + }, + + #' @field result_y (`numeric(1)`)\cr + #' The y passed to `instance$assign_result()`. + #' Only available for single criterion optimization. + result_y = function(rhs) { + if (missing(rhs)) { + return(get_private(self$instance)$.result_y) + } else { + self$instance$.__enclos_env__$private$.result_y = rhs + } + }, + + #' @field result_ydt ([data.table::data.table])\cr + #' The ydt passed to `instance$assign_result()`. + #' Only available for multi criterion optimization. + result_ydt = function(rhs) { + if (missing(rhs)) { + return(get_private(self$instance)$.result_ydt) + } else { + self$instance$.__enclos_env__$private$.result_ydt = rhs } }, #' @field result_extra ([data.table::data.table])\cr - #' Additional information about the result. + #' Additional information about the result passed to `instance$assign_result()`. result_extra = function(rhs) { if (missing(rhs)) { get_private(self$instance)$.result_extra } else { get_private(self$instance, ".result_extra") = rhs } + }, + + #' @field result ([data.table::data.table])\cr + #' The result of the optimization in `instance$assign_result()`. + result = function(rhs) { + if (missing(rhs)) { + get_private(self$instance)$.result + } else { + get_private(self$instance, ".result") = rhs + } } ) ) diff --git a/R/OptimInstance.R b/R/OptimInstance.R index ab7d5543..689ce7a0 100644 --- a/R/OptimInstance.R +++ b/R/OptimInstance.R @@ -154,8 +154,11 @@ OptimInstance = R6Class("OptimInstance", ), private = list( - .result = NULL, + # intermediate objects + .result_xdt = NULL, .result_extra = NULL, + .result = NULL, + .label = NULL, .man = NULL, diff --git a/R/OptimInstanceAsync.R b/R/OptimInstanceAsync.R index 017ff6f9..0d37565d 100644 --- a/R/OptimInstanceAsync.R +++ b/R/OptimInstanceAsync.R @@ -94,6 +94,7 @@ OptimInstanceAsync = R6Class("OptimInstanceAsync", ), private = list( + # intermediate objects .xs = NULL, .xs_trafoed = NULL, .extra = NULL, @@ -123,17 +124,18 @@ OptimInstanceAsync = R6Class("OptimInstanceAsync", while (!self$is_terminated && self$archive$n_queued) { task = self$archive$pop_point() if (!is.null(task)) { - private$.xs = task$xs - # transpose point + private$.xs = task$xs private$.xs_trafoed = trafo_xs(private$.xs, self$search_space) - # eval call_back("on_optimizer_before_eval", self$objective$callbacks, self$objective$context) + + # eval private$.ys = self$objective$eval(private$.xs_trafoed) - # push reuslt call_back("on_optimizer_after_eval", self$objective$callbacks, self$objective$context) + + # push reuslt self$archive$push_result(task$key, private$.ys, x_domain = private$.xs_trafoed) } } diff --git a/R/OptimInstanceAsyncMultiCrit.R b/R/OptimInstanceAsyncMultiCrit.R index 70902a8e..244e32a4 100644 --- a/R/OptimInstanceAsyncMultiCrit.R +++ b/R/OptimInstanceAsyncMultiCrit.R @@ -56,16 +56,26 @@ OptimInstanceAsyncMultiCrit = R6Class("OptimInstanceAsyncMultiCrit", #' @param ... (`any`)\cr #' ignored. assign_result = function(xdt, ydt, extra = NULL, ...) { - # FIXME: We could have one way that just lets us put a 1xn DT as result directly. - assert_data_table(xdt) - assert_names(names(xdt), must.include = self$search_space$ids()) - assert_data_table(ydt) - assert_names(names(ydt), permutation.of = self$objective$codomain$ids()) - private$.result_extra = assert_data_table(extra, null.ok = TRUE) - x_domain = transform_xdt_to_xss(xdt, self$search_space) + # assign for callbacks + private$.result_xdt = xdt + private$.result_ydt = ydt + private$.result_extra = extra + + call_back("on_result_begin", self$objective$callbacks, self$objective$context) + + # assert inputs + assert_data_table(private$.result_xdt) + assert_names(names(private$.result_xdt), must.include = self$search_space$ids()) + assert_data_table(private$.result_ydt) + assert_names(names(private$.result_ydt), permutation.of = self$objective$codomain$ids()) + assert_data_table(private$.result_extra, null.ok = TRUE) + + # add x_domain to result + x_domain = transform_xdt_to_xss(private$.result_xdt, self$search_space) if (length(x_domain) == 0) x_domain = list(list()) - private$.result = cbind(xdt, x_domain = x_domain, ydt) - call_back("on_result", self$objective$callbacks, self$objective$context) + + private$.result = cbind(private$.result_xdt, x_domain = x_domain, private$.result_ydt) + call_back("on_result_end", self$objective$callbacks, self$objective$context) } ), @@ -82,5 +92,10 @@ OptimInstanceAsyncMultiCrit = R6Class("OptimInstanceAsyncMultiCrit", result_y = function() { private$.result[, self$objective$codomain$ids(), with = FALSE] } + ), + + private = list( + # intermediate objects + .result_ydt = NULL ) ) diff --git a/R/OptimInstanceAsyncSingleCrit.R b/R/OptimInstanceAsyncSingleCrit.R index 85194054..df833b61 100644 --- a/R/OptimInstanceAsyncSingleCrit.R +++ b/R/OptimInstanceAsyncSingleCrit.R @@ -57,16 +57,26 @@ OptimInstanceAsyncSingleCrit = R6Class("OptimInstanceAsyncSingleCrit", #' @param ... (`any`)\cr #' ignored. assign_result = function(xdt, y, extra = NULL, ...) { - # FIXME: We could have one way that just lets us put a 1xn DT as result directly. - assert_data_table(xdt) - assert_names(names(xdt), must.include = self$search_space$ids()) - assert_number(y) - assert_names(names(y), permutation.of = self$objective$codomain$target_ids) - private$.result_extra = assert_data_table(extra, null.ok = TRUE) - x_domain = unlist(transform_xdt_to_xss(xdt, self$search_space), recursive = FALSE) + # assign for callbacks + private$.result_xdt = xdt + private$.result_y = y + private$.result_extra = extra + + call_back("on_result_begin", self$objective$callbacks, self$objective$context) + + # assert inputs + assert_names(names(private$.result_xdt), must.include = self$search_space$ids()) + assert_data_table(private$.result_xdt) + assert_number(private$.result_y) + assert_names(names(private$.result_y), permutation.of = self$objective$codomain$target_ids) + assert_data_table(private$.result_extra, null.ok = TRUE) + + # add x_domain to result + x_domain = unlist(transform_xdt_to_xss(private$.result_xdt, self$search_space), recursive = FALSE) if (is.null(x_domain)) x_domain = list() - private$.result = cbind(xdt, x_domain = list(x_domain), t(y)) # t(y) so the name of y stays - call_back("on_result", self$objective$callbacks, self$objective$context) + + private$.result = cbind(private$.result_xdt, x_domain = list(x_domain), t(private$.result_y)) # t(y) so the name of y stays + call_back("on_result_end", self$objective$callbacks, self$objective$context) } ), @@ -83,5 +93,10 @@ OptimInstanceAsyncSingleCrit = R6Class("OptimInstanceAsyncSingleCrit", result_y = function() { unlist(private$.result[, self$objective$codomain$ids(), with = FALSE]) } + ), + + private = list( + # intermediate objects + .result_y = NULL ) ) diff --git a/R/OptimInstanceBatch.R b/R/OptimInstanceBatch.R index 2f07e471..bd416a42 100644 --- a/R/OptimInstanceBatch.R +++ b/R/OptimInstanceBatch.R @@ -156,6 +156,7 @@ OptimInstanceBatch = R6Class("OptimInstanceBatch", ), private = list( + # intermediate objects .xdt = NULL, .objective_function = NULL, diff --git a/R/OptimInstanceBatchMultiCrit.R b/R/OptimInstanceBatchMultiCrit.R index 393a0807..7c4bde46 100644 --- a/R/OptimInstanceBatchMultiCrit.R +++ b/R/OptimInstanceBatchMultiCrit.R @@ -50,16 +50,26 @@ OptimInstanceBatchMultiCrit = R6Class("OptimInstanceBatchMultiCrit", #' @param ... (`any`)\cr #' ignored. assign_result = function(xdt, ydt, extra = NULL, ...) { - # FIXME: We could have one way that just lets us put a 1xn DT as result directly. - assert_data_table(xdt) - assert_names(names(xdt), must.include = self$search_space$ids()) - assert_data_table(ydt) - assert_names(names(ydt), permutation.of = self$objective$codomain$ids()) - private$.result_extra = assert_data_table(extra, null.ok = TRUE) - x_domain = transform_xdt_to_xss(xdt, self$search_space) + # assign for callbacks + private$.result_xdt = xdt + private$.result_ydt = ydt + private$.result_extra = extra + + call_back("on_result_begin", self$objective$callbacks, self$objective$context) + + # assert inputs + assert_data_table(private$.result_xdt) + assert_names(names(private$.result_xdt), must.include = self$search_space$ids()) + assert_data_table(private$.result_ydt) + assert_names(names(private$.result_ydt), permutation.of = self$objective$codomain$ids()) + assert_data_table(private$.result_extra, null.ok = TRUE) + + # add x_domain to result + x_domain = transform_xdt_to_xss(private$.result_xdt, self$search_space) if (length(x_domain) == 0) x_domain = list(list()) - private$.result = cbind(xdt, x_domain = x_domain, ydt) - call_back("on_result", self$objective$callbacks, self$objective$context) + + private$.result = cbind(private$.result_xdt, x_domain = x_domain, private$.result_ydt) + call_back("on_result_end", self$objective$callbacks, self$objective$context) } ), @@ -75,5 +85,10 @@ OptimInstanceBatchMultiCrit = R6Class("OptimInstanceBatchMultiCrit", result_y = function() { private$.result[, self$objective$codomain$ids(), with = FALSE] } + ), + + private = list( + # intermediate objects + .result_ydt = NULL ) ) diff --git a/R/OptimInstanceBatchSingleCrit.R b/R/OptimInstanceBatchSingleCrit.R index 598f3658..0331fbaa 100644 --- a/R/OptimInstanceBatchSingleCrit.R +++ b/R/OptimInstanceBatchSingleCrit.R @@ -55,16 +55,31 @@ OptimInstanceBatchSingleCrit = R6Class("OptimInstanceBatchSingleCrit", #' @param ... (`any`)\cr #' ignored. assign_result = function(xdt, y, extra = NULL, ...) { - # FIXME: We could have one way that just lets us put a 1xn DT as result directly. - assert_data_table(xdt) - assert_names(names(xdt), must.include = self$search_space$ids()) - assert_number(y) - assert_names(names(y), permutation.of = self$objective$codomain$target_ids) - private$.result_extra = assert_data_table(extra, null.ok = TRUE) - x_domain = unlist(transform_xdt_to_xss(xdt, self$search_space), recursive = FALSE) + # assign for callbacks + private$.result_xdt = xdt + private$.result_y = y + private$.result_extra = extra + + call_back("on_result_begin", self$objective$callbacks, self$objective$context) + + # assert inputs + assert_names(names(private$.result_xdt), must.include = self$search_space$ids()) + assert_data_table(private$.result_xdt) + assert_number(private$.result_y) + assert_names(names(private$.result_y), permutation.of = self$objective$codomain$target_ids) + assert_data_table(private$.result_extra, null.ok = TRUE) + + # add x_domain to result + x_domain = unlist(transform_xdt_to_xss(private$.result_xdt, self$search_space), recursive = FALSE) if (is.null(x_domain)) x_domain = list() - private$.result = cbind(xdt, x_domain = list(x_domain), t(y)) # t(y) so the name of y stays - call_back("on_result", self$objective$callbacks, self$objective$context) + + private$.result = cbind(private$.result_xdt, x_domain = list(x_domain), t(private$.result_y)) # t(y) so the name of y stays + call_back("on_result_end", self$objective$callbacks, self$objective$context) } + ), + + private = list( + # intermediate objects + .result_y = NULL ) ) diff --git a/R/OptimizerAsync.R b/R/OptimizerAsync.R index 6b9d9eff..b8f7a94a 100644 --- a/R/OptimizerAsync.R +++ b/R/OptimizerAsync.R @@ -141,7 +141,7 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers get_private(optimizer)$.assign_result(instance) lg$info("Finished optimizing after %i evaluation(s)", instance$archive$n_evals) lg$info("Result:") - lg$info(capture.output(print(instance$result, lass = FALSE, row.names = FALSE, print.keys = FALSE))) + lg$info(capture.output(print(instance$result, class = FALSE, row.names = FALSE, print.keys = FALSE))) call_back("on_optimization_end", instance$objective$callbacks, instance$objective$context) return(instance$result) diff --git a/man/CallbackAsync.Rd b/man/CallbackAsync.Rd index e1c29934..26526f46 100644 --- a/man/CallbackAsync.Rd +++ b/man/CallbackAsync.Rd @@ -36,8 +36,12 @@ Called in \code{OptimInstance$.eval_point()}.} Stage called at the end of the optimization on the worker. Called in the worker loop.} -\item{\code{on_result}}{(\verb{function()})\cr -Stage called after result are written. +\item{\code{on_result_begin}}{(\verb{function()})\cr +Stage called before the results are written. +Called in \code{OptimInstance$assign_result()}.} + +\item{\code{on_result_end}}{(\verb{function()})\cr +Stage called after the results are written. Called in \code{OptimInstance$assign_result()}.} \item{\code{on_optimization_end}}{(\verb{function()})\cr diff --git a/man/CallbackBatch.Rd b/man/CallbackBatch.Rd index 5f414ec2..8e671472 100644 --- a/man/CallbackBatch.Rd +++ b/man/CallbackBatch.Rd @@ -36,8 +36,12 @@ Called in \code{OptimInstance$eval_batch()}.} Stage called after points are evaluated. Called in \code{OptimInstance$eval_batch()}.} -\item{\code{on_result}}{(\verb{function()})\cr -Stage called after result are written. +\item{\code{on_result_begin}}{(\verb{function()})\cr +Stage called before the results are written. +Called in \code{OptimInstance$assign_result()}.} + +\item{\code{on_result_end}}{(\verb{function()})\cr +Stage called after the results are written. Called in \code{OptimInstance$assign_result()}.} \item{\code{on_optimization_end}}{(\verb{function()})\cr diff --git a/man/ContextAsync.Rd b/man/ContextAsync.Rd index 4b125fcc..9b0398c2 100644 --- a/man/ContextAsync.Rd +++ b/man/ContextAsync.Rd @@ -26,23 +26,34 @@ Changes to \verb{$instance} and \verb{$optimizer} in the stages executed on the \section{Active bindings}{ \if{html}{\out{
}} \describe{ -\item{\code{result}}{(\link[data.table:data.table]{data.table::data.table})\cr -The result of the optimization.} - -\item{\code{result_extra}}{(\link[data.table:data.table]{data.table::data.table})\cr -Additional information about the result.} - \item{\code{xs}}{(list())\cr -The point to be evaluated.} +The point to be evaluated in \code{instance$.eval_point()}.} \item{\code{xs_trafoed}}{(list())\cr -The transformed point to be evaluated.} +The transformed point to be evaluated in \code{instance$.eval_point()}.} \item{\code{extra}}{(list())\cr -Additional information.} +Additional information of the point to be evaluated in \code{instance$.eval_point()}.} \item{\code{ys}}{(list())\cr -The result of the evaluation.} +The result of the evaluation in \code{instance$.eval_point()}.} + +\item{\code{result_xdt}}{(\link[data.table:data.table]{data.table::data.table})\cr +The xdt passed to \code{instance$assign_result()}.} + +\item{\code{result_y}}{(\code{numeric(1)})\cr +The y passed to \code{instance$assign_result()}. +Only available for single criterion optimization.} + +\item{\code{result_ydt}}{(\link[data.table:data.table]{data.table::data.table})\cr +The ydt passed to \code{instance$assign_result()}. +Only available for multi criterion optimization.} + +\item{\code{result_extra}}{(\link[data.table:data.table]{data.table::data.table})\cr +Additional information about the result passed to \code{instance$assign_result()}.} + +\item{\code{result}}{(\link[data.table:data.table]{data.table::data.table})\cr +The result of the optimization in \code{instance$assign_result()}.} } \if{html}{\out{
}} } diff --git a/man/ContextBatch.Rd b/man/ContextBatch.Rd index 2c1b42f5..f7b28389 100644 --- a/man/ContextBatch.Rd +++ b/man/ContextBatch.Rd @@ -24,14 +24,25 @@ See \code{\link[=callback_batch]{callback_batch()}} for a list of stages which t \if{html}{\out{
}} \describe{ \item{\code{xdt}}{(\link[data.table:data.table]{data.table::data.table})\cr -The points of the latest batch. +The points of the latest batch in \code{instance$eval_batch()}. Contains the values in the search space i.e. transformations are not yet applied.} -\item{\code{result}}{(\link[data.table:data.table]{data.table::data.table})\cr -The result of the optimization.} +\item{\code{result_xdt}}{(\link[data.table:data.table]{data.table::data.table})\cr +The xdt passed to \code{instance$assign_result()}.} + +\item{\code{result_y}}{(\code{numeric(1)})\cr +The y passed to \code{instance$assign_result()}. +Only available for single criterion optimization.} + +\item{\code{result_ydt}}{(\link[data.table:data.table]{data.table::data.table})\cr +The ydt passed to \code{instance$assign_result()}. +Only available for multi criterion optimization.} \item{\code{result_extra}}{(\link[data.table:data.table]{data.table::data.table})\cr -Additional information about the result.} +Additional information about the result passed to \code{instance$assign_result()}.} + +\item{\code{result}}{(\link[data.table:data.table]{data.table::data.table})\cr +The result of the optimization in \code{instance$assign_result()}.} } \if{html}{\out{
}} } diff --git a/man/callback_async.Rd b/man/callback_async.Rd index 308b5989..65bcca59 100644 --- a/man/callback_async.Rd +++ b/man/callback_async.Rd @@ -13,6 +13,8 @@ callback_async( on_optimizer_before_eval = NULL, on_optimizer_after_eval = NULL, on_worker_end = NULL, + on_result_begin = NULL, + on_result_end = NULL, on_result = NULL, on_optimization_end = NULL ) @@ -40,20 +42,37 @@ The functions must have two arguments named \code{callback} and \code{context}.} \item{on_optimizer_before_eval}{(\verb{function()})\cr Stage called after the optimizer proposes points. -Called in \code{OptimInstance$eval_point()}. -The functions must have two arguments named \code{callback} and \code{context}.} +Called in \code{OptimInstance$.eval_point()}. +The functions must have two arguments named \code{callback} and \code{context}. +The argument of \code{instance$.eval_point(xs)} and \code{xs_trafoed} and \code{extra} are available in the \code{context}. +Or \code{xs} and \code{xs_trafoed} of \code{instance$.eval_queue()} are available in the \code{context}.} \item{on_optimizer_after_eval}{(\verb{function()})\cr Stage called after points are evaluated. -Called in \code{OptimInstance$eval_point()}. -The functions must have two arguments named \code{callback} and \code{context}.} +Called in \code{OptimInstance$.eval_point()}. +The functions must have two arguments named \code{callback} and \code{context}. +The outcome \code{y} is available in the \code{context}.} \item{on_worker_end}{(\verb{function()})\cr Stage called at the end of the optimization on the worker. Called in the worker loop. The functions must have two arguments named \code{callback} and \code{context}.} +\item{on_result_begin}{(\verb{function()})\cr +Stage called before result are written. +Called in \code{OptimInstance$assign_result()}. +The functions must have two arguments named \code{callback} and \code{context}. +The arguments of \verb{$.assign_result(xdt, y, extra)} are available in the \code{context}.} + +\item{on_result_end}{(\verb{function()})\cr +Stage called after result are written. +Called in \code{OptimInstance$assign_result()}. +The functions must have two arguments named \code{callback} and \code{context}. +The final result \code{instance$result} is available in the \code{context}.} + \item{on_result}{(\verb{function()})\cr +Deprecated. +Use \code{on_result_end} instead. Stage called after result are written. Called in \code{OptimInstance$assign_result()}. The functions must have two arguments named \code{callback} and \code{context}.} @@ -79,7 +98,8 @@ The stages are prefixed with \verb{on_*}. End Optimization on Worker - on_worker_end End Worker - - on_result + - on_result_begin + - on_result_end - on_optimization_end End Optimization }\if{html}{\out{}} diff --git a/man/callback_batch.Rd b/man/callback_batch.Rd index a654a5cd..d4855788 100644 --- a/man/callback_batch.Rd +++ b/man/callback_batch.Rd @@ -11,6 +11,8 @@ callback_batch( on_optimization_begin = NULL, on_optimizer_before_eval = NULL, on_optimizer_after_eval = NULL, + on_result_begin = NULL, + on_result_end = NULL, on_result = NULL, on_optimization_end = NULL ) @@ -34,14 +36,30 @@ The functions must have two arguments named \code{callback} and \code{context}.} \item{on_optimizer_before_eval}{(\verb{function()})\cr Stage called after the optimizer proposes points. Called in \code{OptimInstance$eval_batch()}. -The functions must have two arguments named \code{callback} and \code{context}.} +The functions must have two arguments named \code{callback} and \code{context}. +The argument of \verb{$eval_batch(xdt)} is available in \code{context}.} \item{on_optimizer_after_eval}{(\verb{function()})\cr Stage called after points are evaluated. Called in \code{OptimInstance$eval_batch()}. -The functions must have two arguments named \code{callback} and \code{context}.} +The functions must have two arguments named \code{callback} and \code{context}. +The new points and outcomes in \code{instance$archive} are available in \code{context}.} + +\item{on_result_begin}{(\verb{function()})\cr +Stage called before result are written to the instance. +Called in \code{OptimInstance$assign_result()}. +The functions must have two arguments named \code{callback} and \code{context}. +The arguments of \verb{$assign_result(xdt, y, extra)} are available in \code{context}.} + +\item{on_result_end}{(\verb{function()})\cr +Stage called after result are written to the instance. +Called in \code{OptimInstance$assign_result()}. +The functions must have two arguments named \code{callback} and \code{context}. +The final result \code{instance$result} is available in \code{context}.} \item{on_result}{(\verb{function()})\cr +Deprecated. +Use \code{on_result_end} instead. Stage called after result are written. Called in \code{OptimInstance$assign_result()}. The functions must have two arguments named \code{callback} and \code{context}.} @@ -63,7 +81,8 @@ The stages are prefixed with \verb{on_*}. - on_optimizer_before_eval - on_optimizer_after_eval End Optimizer Batch - - on_result + - on_result_begin + - on_result_end - on_optimization_end End Optimization }\if{html}{\out{}} diff --git a/tests/testthat/test_CallbackAsync.R b/tests/testthat/test_CallbackAsync.R index 3978201d..37239d28 100644 --- a/tests/testthat/test_CallbackAsync.R +++ b/tests/testthat/test_CallbackAsync.R @@ -1,3 +1,5 @@ +# stages in $optimize() -------------------------------------------------------- + test_that("on_optimization_begin works", { skip_on_cran() skip_if_not_installed("rush") @@ -23,6 +25,32 @@ test_that("on_optimization_begin works", { expect_equal(instance$terminator$param_set$values$n_evals, 20) }) +test_that("on_optimization_end works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + callback = callback_async(id = "test", + on_optimization_end = function(callback, context) { + context$instance$terminator$param_set$values$n_evals = 200 + } + ) + + rush::rush_plan(n_workers = 2) + instance = oi_async( + objective = OBJ_1D, + search_space = PS_1D, + terminator = trm("evals", n_evals = 10), + callbacks = list(callback) + ) + + optimizer = opt("async_random_search") + optimizer$optimize(instance) + expect_equal(instance$terminator$param_set$values$n_evals, 200) +}) + +# stager in worker_loop() ------------------------------------------------------ + test_that("on_worker_begin works", { skip_on_cran() skip_if_not_installed("rush") @@ -49,6 +77,7 @@ test_that("on_worker_begin works", { expect_subset(1, instance$archive$data$x) }) + test_that("on_worker_end works", { skip_on_cran() skip_if_not_installed("rush") @@ -75,6 +104,8 @@ test_that("on_worker_end works", { expect_subset(1, instance$archive$data$x) }) +# stages in $.eval_point() ----------------------------------------------------- + test_that("on_optimizer_before_eval and on_optimizer_after_eval works", { skip_on_cran() skip_if_not_installed("rush") @@ -107,14 +138,17 @@ test_that("on_optimizer_before_eval and on_optimizer_after_eval works", { expect_equal(unique(unlist(instance$archive$data$x_domain)), 0) }) -test_that("on_result works", { +# stages in $assign_result() in OptimInstanceAsyncSingleCrit ------------------- + +test_that("on_result_begin in OptimInstanceAsyncSingleCrit works", { skip_on_cran() skip_if_not_installed("rush") flush_redis() callback = callback_async(id = "test", - on_result = function(callback, context) { - context$result = 2 + on_result_begin = function(callback, context) { + context$result_xdt = data.table(x = 1) + context$result_y = c(y = 2) } ) @@ -129,17 +163,18 @@ test_that("on_result works", { optimizer = opt("async_random_search") optimizer$optimize(instance) - expect_equal(instance$result, 2) + expect_equal(instance$result$x, 1) + expect_equal(instance$result$y, 2) }) -test_that("on_optimization_end works", { +test_that("on_result_end in OptimInstanceAsyncSingleCrit works", { skip_on_cran() skip_if_not_installed("rush") flush_redis() callback = callback_async(id = "test", - on_optimization_end = function(callback, context) { - context$instance$terminator$param_set$values$n_evals = 200 + on_result_end = function(callback, context) { + context$result$y = 2 } ) @@ -153,5 +188,114 @@ test_that("on_optimization_end works", { optimizer = opt("async_random_search") optimizer$optimize(instance) - expect_equal(instance$terminator$param_set$values$n_evals, 200) + + expect_equal(instance$result$y, 2) +}) + +test_that("on_result in OptimInstanceAsyncSingleCrit works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + expect_warning({callback = callback_async(id = "test", + on_result = function(callback, context) { + context$result = 2 + } + )}, "deprecated") + + rush::rush_plan(n_workers = 2) + instance = oi_async( + objective = OBJ_1D, + search_space = PS_1D, + terminator = trm("evals", n_evals = 10), + callbacks = list(callback) + ) + + optimizer = opt("async_random_search") + optimizer$optimize(instance) + + expect_equal(instance$result, 2) }) + +# stages in $assign_result() in OptimInstanceAsyncMultiCrit -------------------- + +test_that("on_result_begin in OptimInstanceAsyncMultiCrit works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + callback = callback_async(id = "test", + on_result_begin = function(callback, context) { + context$result_xdt = data.table(x1 = 1, x2 = 1) + context$result_ydt = data.table(y1 = 2, y2 = 2) + } + ) + + rush::rush_plan(n_workers = 2) + instance = oi_async( + objective = OBJ_2D_2D, + search_space = PS_2D, + terminator = trm("evals", n_evals = 10), + callbacks = list(callback) + ) + + optimizer = opt("async_random_search") + optimizer$optimize(instance) + expect_equal(instance$result$x1, 1) + expect_equal(instance$result$x2, 1) + expect_equal(unique(instance$result$y1), 2) + expect_equal(unique(instance$result$y2), 2) +}) + +test_that("on_result_end in OptimInstanceAsyncMultiCrit works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + callback = callback_async(id = "test", + on_result_end = function(callback, context) { + set(context$result, j = "y1", value = 2) + set(context$result, j = "y2", value = 3) + } + ) + + rush::rush_plan(n_workers = 2) + instance = oi_async( + objective = OBJ_2D_2D, + search_space = PS_2D, + terminator = trm("evals", n_evals = 10), + callbacks = list(callback) + ) + + optimizer = opt("async_random_search") + optimizer$optimize(instance) + expect_equal(unique(instance$result$y1), 2) + expect_equal(unique(instance$result$y2), 3) +}) + +test_that("on_result in OptimInstanceAsyncMultiCrit works", { + skip_on_cran() + skip_if_not_installed("rush") + flush_redis() + + expect_warning({callback = callback_async(id = "test", + on_result = function(callback, context) { + set(context$result, j = "y1", value = 2) + set(context$result, j = "y2", value = 3) + } + )}, "deprecated") + + rush::rush_plan(n_workers = 2) + instance = oi_async( + objective = OBJ_2D_2D, + search_space = PS_2D, + terminator = trm("evals", n_evals = 10), + callbacks = list(callback) + ) + + optimizer = opt("async_random_search") + optimizer$optimize(instance) + expect_equal(unique(instance$result$y1), 2) + expect_equal(unique(instance$result$y2), 3) +}) + diff --git a/tests/testthat/test_CallbackBatch.R b/tests/testthat/test_CallbackBatch.R index a163b839..62a558f0 100644 --- a/tests/testthat/test_CallbackBatch.R +++ b/tests/testthat/test_CallbackBatch.R @@ -1,3 +1,5 @@ +# stages in $optimize() -------------------------------------------------------- + test_that("on_optimization_begin works", { callback = callback_batch(id = "test", on_optimization_begin = function(callback, context) { @@ -5,11 +7,11 @@ test_that("on_optimization_begin works", { } ) - instance = OptimInstanceBatchSingleCrit$new( + instance = oi( objective = OBJ_1D, search_space = PS_1D, terminator = trm("evals", n_evals = 10), - callbacks = list(callback) + callbacks = callback ) optimizer = opt("random_search") @@ -25,11 +27,11 @@ test_that("on_optimization_end works", { } ) - instance = OptimInstanceBatchSingleCrit$new( + instance = oi( objective = OBJ_1D, search_space = PS_1D, terminator = trm("evals", n_evals = 10), - callbacks = list(callback) + callbacks = callback ) optimizer = opt("random_search") @@ -38,43 +40,179 @@ test_that("on_optimization_end works", { expect_equal(instance$terminator$param_set$values$n_evals, 20) }) -test_that("on_result in OptimInstanceBatchSingleCrit works", { +# stages in $eval_batch() ------------------------------------------------------ + +test_that("on_optimizer_before_eval works", { callback = callback_batch(id = "test", - on_result = function(callback, context) { - context$result$y = 2 + on_optimizer_before_eval = function(callback, context) { + set(context$xdt, j = "x", value = 1) + } + ) + + instance = oi( + objective = OBJ_1D, + search_space = PS_1D, + terminator = trm("evals", n_evals = 10), + callbacks = callback + ) + + optimizer = opt("random_search") + optimizer$optimize(instance) + expect_class(instance$objective$context, "ContextBatch") + expect_equal(unique(instance$archive$data$x), 1) +}) + +test_that("on_optimizer_after_eval works", { + callback = callback_batch(id = "test", + on_optimizer_after_eval = function(callback, context) { + set(context$instance$archive$data, j = "y", value = 0.5) + } + ) + + instance = oi( + objective = OBJ_1D, + search_space = PS_1D, + terminator = trm("evals", n_evals = 10), + callbacks = callback + ) + + optimizer = opt("random_search") + optimizer$optimize(instance) + expect_class(instance$objective$context, "ContextBatch") + expect_equal(unique(instance$archive$data$y), 0.5) +}) + +# stages in $assign_result() in OptimInstanceBatchSingleCrit ------------------- + +test_that("on_result_begin in OptimInstanceBatchSingleCrit works", { + callback = callback_batch(id = "test", + on_result_begin = function(callback, context) { + context$result_xdt = data.table(x = 1) + context$result_y = c(y = 2) } ) - instance = OptimInstanceBatchSingleCrit$new( + instance = oi( objective = OBJ_1D, search_space = PS_1D, terminator = trm("evals", n_evals = 10), - callbacks = list(callback) + callbacks = callback ) optimizer = opt("random_search") optimizer$optimize(instance) expect_class(instance$objective$context, "ContextBatch") + expect_equal(instance$result$x, 1) expect_equal(instance$result$y, 2) }) -test_that("on_result in OptimInstanceBatchMultiCrit works", { +test_that("on_result_end in OptimInstanceBatchSingleCrit works", { callback = callback_batch(id = "test", + on_result_end = function(callback, context) { + context$result$y = 2 + } + ) + + instance = oi( + objective = OBJ_1D, + search_space = PS_1D, + terminator = trm("evals", n_evals = 10), + callbacks = callback + ) + + optimizer = opt("random_search") + optimizer$optimize(instance) + expect_class(instance$objective$context, "ContextBatch") + expect_equal(instance$result$y, 2) +}) + +test_that("on_result in OptimInstanceBatchSingleCrit works", { + expect_warning({callback = callback_batch(id = "test", on_result = function(callback, context) { - context$result$y1 = 2 - context$result$y2 = 2 + context$result$y = 2 } + )}, "deprecated") + + instance = oi( + objective = OBJ_1D, + search_space = PS_1D, + terminator = trm("evals", n_evals = 10), + callbacks = callback ) - instance = OptimInstanceBatchMultiCrit$new( + optimizer = opt("random_search") + optimizer$optimize(instance) + expect_class(instance$objective$context, "ContextBatch") + expect_equal(instance$result$y, 2) +}) + +# stages in $assign_result() in OptimInstanceBatchMultiCrit -------------------- + +test_that("on_result_begin in OptimInstanceBatchMultiCrit works", { + callback = callback_batch(id = "test", + on_result_begin = function(callback, context) { + context$result_xdt = data.table(x1 = 1, x2 = 1) + context$result_ydt = data.table(y1 = 2, y2 = 2) + } + ) + + instance = oi( objective = OBJ_2D_2D, search_space = PS_2D, terminator = trm("evals", n_evals = 10), - callbacks = list(callback) + callbacks = callback ) optimizer = opt("random_search") optimizer$optimize(instance) + expect_class(instance$objective$context, "ContextBatch") + expect_equal(instance$result$x1, 1) + expect_equal(instance$result$x2, 1) expect_equal(unique(instance$result$y1), 2) expect_equal(unique(instance$result$y2), 2) }) + +test_that("on_result_end in OptimInstanceBatchMultiCrit works", { + callback = callback_batch(id = "test", + on_result_end = function(callback, context) { + set(context$result, j = "y1", value = 2) + set(context$result, j = "y2", value = 3) + } + ) + + instance = oi( + objective = OBJ_2D_2D, + search_space = PS_2D, + terminator = trm("evals", n_evals = 10), + callbacks = callback + ) + + optimizer = opt("random_search") + optimizer$optimize(instance) + expect_class(instance$objective$context, "ContextBatch") + expect_equal(unique(instance$result$y1), 2) + expect_equal(unique(instance$result$y2), 3) +}) + +test_that("on_result in OptimInstanceBatchMultiCrit works", { + expect_warning({callback = callback_batch(id = "test", + on_result = function(callback, context) { + set(context$result, j = "y1", value = 2) + set(context$result, j = "y2", value = 3) + } + )}, "deprecated") + + instance = oi( + objective = OBJ_2D_2D, + search_space = PS_2D, + terminator = trm("evals", n_evals = 10), + callbacks = callback + ) + + optimizer = opt("random_search") + optimizer$optimize(instance) + expect_class(instance$objective$context, "ContextBatch") + expect_equal(unique(instance$result$y1), 2) + expect_equal(unique(instance$result$y2), 3) +}) +