From b7b6871c73e03aefbe0f74d05e89ad4adb26dafe Mon Sep 17 00:00:00 2001 From: be-marc Date: Mon, 4 Dec 2023 21:13:43 +0100 Subject: [PATCH] feat: support seeds with L'Ecuyer RngStreams --- DESCRIPTION | 1 + NAMESPACE | 1 + R/Rush.R | 19 +++++++++++---- R/RushWorker.R | 39 +++++++++++++++++++++++++++++- R/worker_loops.R | 2 ++ R/zzz.R | 1 + man-roxygen/param_seed.R | 2 ++ man/Rush.Rd | 4 +++ man/RushWorker.Rd | 24 ++++++++++++++++++ tests/testthat/test-Rush.R | 23 ++++++++++++++++++ tests/testthat/test-RushWorker.R | 30 +++++++++++++++++++++++ tests/testthat/test-worker_loops.R | 18 ++++++++++++++ 12 files changed, 158 insertions(+), 6 deletions(-) create mode 100644 man-roxygen/param_seed.R diff --git a/DESCRIPTION b/DESCRIPTION index e4a237d..1912a7b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,6 +18,7 @@ Imports: mlr3misc, processx, redux, + rlecuyer, uuid Suggests: callr, diff --git a/NAMESPACE b/NAMESPACE index 89b68bf..2e74e07 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -15,6 +15,7 @@ import(checkmate) import(data.table) import(mlr3misc) import(redux) +import(rlecuyer) importFrom(jsonlite,fromJSON) importFrom(processx,process) importFrom(utils,str) diff --git a/R/Rush.R b/R/Rush.R index dee44e0..46ff1d5 100644 --- a/R/Rush.R +++ b/R/Rush.R @@ -97,8 +97,10 @@ #' @template param_heartbeat_expire #' @template param_lgr_thresholds #' @template param_lgr_buffer_size +#' @template param_seed #' @template param_data_format #' +#' #' @export Rush = R6::R6Class("Rush", public = list( @@ -181,6 +183,7 @@ Rush = R6::R6Class("Rush", lgr_thresholds = NULL, lgr_buffer_size = 0, max_retries = 0, + seed = NULL, supervise = TRUE, worker_loop = worker_loop_default, ... @@ -202,6 +205,7 @@ Rush = R6::R6Class("Rush", lgr_thresholds = lgr_thresholds, lgr_buffer_size = lgr_buffer_size, max_retries = max_retries, + seed = seed, worker_loop = worker_loop, ... 
) @@ -548,8 +552,10 @@ Rush = R6::R6Class("Rush", lg$debug("Pushing %i task(s) to the shared queue", length(xss)) keys = self$write_hashes(xs = xss, xs_extra = extra) - r$command(c("LPUSH", private$.get_key("queued_tasks"), keys)) - r$command(c("SADD", private$.get_key("all_tasks"), keys)) + cmds = list( + c("RPUSH", private$.get_key("all_tasks"), keys), + c("LPUSH", private$.get_key("queued_tasks"), keys)) + r$pipeline(.commands = cmds) if (terminate_workers) r$command(c("SET", private$.get_key("terminate_on_idle"), 1)) return(invisible(keys)) @@ -592,7 +598,7 @@ Rush = R6::R6Class("Rush", } }) r$pipeline(.commands = cmds) - r$command(c("SADD", private$.get_key("all_tasks"), keys)) + r$command(c("RPUSH", private$.get_key("all_tasks"), keys)) return(invisible(keys)) }, @@ -983,7 +989,7 @@ Rush = R6::R6Class("Rush", #' Keys of all tasks. tasks = function() { r = self$connector - unlist(r$SMEMBERS(private$.get_key("all_tasks"))) + unlist(r$LRANGE(private$.get_key("all_tasks"), 0, -1)) }, #' @field queued_tasks (`character()`)\cr @@ -1055,7 +1061,7 @@ Rush = R6::R6Class("Rush", #' Number of all tasks. n_tasks = function() { r = self$connector - as.integer(r$SCARD(private$.get_key("all_tasks"))) %??% 0 + as.integer(r$LLEN(private$.get_key("all_tasks"))) %??% 0 }, #' @field data ([data.table::data.table])\cr @@ -1171,6 +1177,7 @@ Rush = R6::R6Class("Rush", lgr_thresholds = NULL, lgr_buffer_size = 0, max_retries = 0, + seed = NULL, worker_loop = worker_loop_default, ... ) { @@ -1182,6 +1189,7 @@ Rush = R6::R6Class("Rush", assert_vector(lgr_thresholds, names = "named", null.ok = TRUE) assert_count(lgr_buffer_size) assert_count(max_retries) + assert_int(seed, null.ok = TRUE) assert_function(worker_loop) dots = list(...) 
r = self$connector @@ -1201,6 +1209,7 @@ Rush = R6::R6Class("Rush", heartbeat_expire = heartbeat_expire, lgr_thresholds = lgr_thresholds, lgr_buffer_size = lgr_buffer_size, + seed = seed, max_retries = max_retries) # arguments needed for initializing the worker diff --git a/R/RushWorker.R b/R/RushWorker.R index 3ce3a9e..03060a1 100644 --- a/R/RushWorker.R +++ b/R/RushWorker.R @@ -15,6 +15,7 @@ #' @template param_heartbeat_expire #' @template param_lgr_thresholds #' @template param_lgr_buffer_size +#' @template param_seed #' #' @export RushWorker = R6::R6Class("RushWorker", @@ -44,6 +45,7 @@ RushWorker = R6::R6Class("RushWorker", heartbeat_expire = NULL, lgr_thresholds = NULL, lgr_buffer_size = 0, + seed = NULL, max_retries = 0 ) { super$initialize(network_id = network_id, config = config) @@ -101,6 +103,15 @@ RushWorker = R6::R6Class("RushWorker", } } + # initialize seed table + if (!is.null(seed)) { + private$.seed = TRUE + .lec.SetPackageSeed(seed) + walk(self$tasks, function(key) { + .lec.CreateStream(key) + }) + } + # register worker ids r$SADD(private$.get_key("worker_ids"), self$worker_id) r$SADD(private$.get_key("running_worker_ids"), self$worker_id) @@ -141,7 +152,7 @@ RushWorker = R6::R6Class("RushWorker", keys = self$write_hashes(xs = xss, xs_extra = extra) r$command(c("SADD", private$.get_key("running_tasks"), keys)) - r$command(c("SADD", private$.get_key("all_tasks"), keys)) + r$command(c("RPUSH", private$.get_key("all_tasks"), keys)) return(invisible(keys)) }, @@ -220,6 +231,28 @@ RushWorker = R6::R6Class("RushWorker", return(invisible(self)) }, + #' @description + #' Sets the seed for `key`. + #' Updates the seed table if necessary. + #' + #' @param key (`character(1)`)\cr + #' Key of the task. 
+ set_seed = function(key) { + if (!isTRUE(private$.seed)) return(invisible(self)) + r = self$connector + + # update seed table: create streams for task keys appended to all_tasks since the last call + n_streams = length(.lec.Random.seed.table$name) + if (self$n_tasks > n_streams) { + keys = r$LRANGE(private$.get_key("all_tasks"), n_streams, -1) + walk(keys, function(key) .lec.CreateStream(key)) + } + + # set seed: make the stream named after this task key the current RNG stream + .lec.CurrentStream(key) + return(invisible(self)) + }, + #' @description #' Mark the worker as terminated. #' Last step in the worker loop before the worker terminates. @@ -248,5 +281,9 @@ RushWorker = R6::R6Class("RushWorker", r = self$connector as.logical(r$EXISTS(private$.get_key("terminate_on_idle"))) && !as.logical(self$n_queued_tasks) } + ), + + private = list( + .seed = NULL ) ) diff --git a/R/worker_loops.R b/R/worker_loops.R index 8623ba2..a5fbeff 100644 --- a/R/worker_loops.R +++ b/R/worker_loops.R @@ -20,6 +20,8 @@ worker_loop_default = function(fun, constants = NULL, rush) { while(!rush$terminated) { task = rush$pop_task() if (!is.null(task)) { + # set seed + rush$set_seed(task$key) tryCatch({ ys = mlr3misc::invoke(fun, .args = c(task$xs, constants)) rush$push_results(task$key, yss = list(ys)) diff --git a/R/zzz.R b/R/zzz.R index 3968320..ae12073 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -2,6 +2,7 @@ #' @import redux #' @import mlr3misc #' @import checkmate +#' @import rlecuyer #' @importFrom processx process #' @importFrom uuid UUIDgenerate #' @importFrom utils str diff --git a/man-roxygen/param_seed.R b/man-roxygen/param_seed.R new file mode 100644 index 0000000..b4c08a9 --- /dev/null +++ b/man-roxygen/param_seed.R @@ -0,0 +1,2 @@ +#' @param seed (`integer(1)`)\cr +#' Seed for the random number generator. 
diff --git a/man/Rush.Rd b/man/Rush.Rd index e6e920e..f562711 100644 --- a/man/Rush.Rd +++ b/man/Rush.Rd @@ -316,6 +316,7 @@ This function takes the arguments \code{fun} and optionally \code{constants} whi lgr_thresholds = NULL, lgr_buffer_size = 0, max_retries = 0, + seed = NULL, supervise = TRUE, worker_loop = worker_loop_default, ... @@ -351,6 +352,9 @@ By default (\code{lgr_buffer_size = 0}), the log messages are directly saved in If \code{lgr_buffer_size > 0}, the log messages are buffered and saved in the Redis data store when the buffer is full. This improves the performance of the logging.} +\item{\code{seed}}{(\code{integer(1)})\cr +Seed for the random number generator.} + \item{\code{supervise}}{(\code{logical(1)})\cr Whether to kill the workers when the main R process is shut down.} diff --git a/man/RushWorker.Rd b/man/RushWorker.Rd index c774086..f13c4a4 100644 --- a/man/RushWorker.Rd +++ b/man/RushWorker.Rd @@ -48,6 +48,7 @@ Used in the worker loop to determine whether to continue.} \item \href{#method-RushWorker-pop_task}{\code{RushWorker$pop_task()}} \item \href{#method-RushWorker-push_results}{\code{RushWorker$push_results()}} \item \href{#method-RushWorker-push_failed}{\code{RushWorker$push_failed()}} +\item \href{#method-RushWorker-set_seed}{\code{RushWorker$set_seed()}} \item \href{#method-RushWorker-set_terminated}{\code{RushWorker$set_terminated()}} \item \href{#method-RushWorker-clone}{\code{RushWorker$clone()}} } @@ -99,6 +100,7 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. heartbeat_expire = NULL, lgr_thresholds = NULL, lgr_buffer_size = 0, + seed = NULL, max_retries = 0 )}\if{html}{\out{}} } @@ -139,6 +141,9 @@ Logger threshold on the workers e.g. \code{c(rush = "debug")}.} By default (\code{lgr_buffer_size = 0}), the log messages are directly saved in the Redis data store. If \code{lgr_buffer_size > 0}, the log messages are buffered and saved in the Redis data store when the buffer is full. 
This improves the performance of the logging.} + +\item{\code{seed}}{(\code{integer(1)})\cr +Seed for the random number generator.} } \if{html}{\out{}} } @@ -232,6 +237,25 @@ If \code{"error"} the tasks are moved to the failed tasks.} \if{html}{\out{
}}\preformatted{RushWorker$push_failed(keys, conditions)}\if{html}{\out{
}} } +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-RushWorker-set_seed}{}}} +\subsection{Method \code{set_seed()}}{ +Sets the seed for \code{key}. +Updates the seed table if necessary. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{RushWorker$set_seed(key)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{key}}{(\code{character(1)})\cr +Key of the task.} +} +\if{html}{\out{
}} +} } \if{html}{\out{
}} \if{html}{\out{}} diff --git a/tests/testthat/test-Rush.R b/tests/testthat/test-Rush.R index 87f36b0..5010de0 100644 --- a/tests/testthat/test-Rush.R +++ b/tests/testthat/test-Rush.R @@ -797,3 +797,26 @@ test_that("network without controller works", { expect_rush_reset(rush) }) + +# seed ------------------------------------------------------------------------- + +test_that("seed is set correctly on two workers", { + skip_on_cran() + skip_on_ci() + + config = start_flush_redis() + rush = Rush$new(network_id = "test-rush", config = config) + fun = function(x1, x2, ...) list(y = sample(10000, 1)) + worker_ids = rush$start_workers(fun = fun, n_workers = 2, seed = 123456, wait_for_workers = TRUE) + + .keys = rush$push_tasks(list(list(x1 = 1, x2 = 2), list(x1 = 2, x2 = 2), list(x1 = 2, x2 = 3), list(x1 = 2, x2 = 4))) + rush$wait_for_tasks(.keys) + + finished_tasks = rush$fetch_finished_tasks() + expect_equal(finished_tasks[.keys[1], y, on = "keys"], 4492) + expect_equal(finished_tasks[.keys[2], y, on = "keys"], 9223) + expect_equal(finished_tasks[.keys[3], y, on = "keys"], 2926) + expect_equal(finished_tasks[.keys[4], y, on = "keys"], 4937) + + expect_rush_reset(rush, type = "terminate") +}) diff --git a/tests/testthat/test-RushWorker.R b/tests/testthat/test-RushWorker.R index aacd6b8..0e9ac4c 100644 --- a/tests/testthat/test-RushWorker.R +++ b/tests/testthat/test-RushWorker.R @@ -877,6 +877,9 @@ test_that("n_retries method works", { }) test_that("terminate on idle works", { + skip_on_cran() + skip_on_ci() + config = start_flush_redis() rush = RushWorker$new(network_id = "test-rush", config = config, host = "local") @@ -889,3 +892,30 @@ test_that("terminate on idle works", { expect_rush_reset(rush, type = "terminate") }) + + +# seed ------------------------------------------------------------------------- + +test_that("seed is set correctly", { + skip_on_cran() + skip_on_ci() + + on.exit({ + .lec.exit() + }) + + config = start_flush_redis() + rush = 
RushWorker$new(network_id = "test-rush", config = config, host = "local", seed = 123456) + + expect_null(.lec.Random.seed.table$name) + + rush$push_tasks(list(list(x1 = 1, x2 = 2))) + task = rush$pop_task() + rush$set_seed(task$key) + + expect_equal(.lec.Random.seed.table$name, task$key) + + expect_equal(sample(seq(100000), 1), 86412) + + expect_rush_reset(rush, type = "terminate") +}) diff --git a/tests/testthat/test-worker_loops.R b/tests/testthat/test-worker_loops.R index de09f52..a3bdb94 100644 --- a/tests/testthat/test-worker_loops.R +++ b/tests/testthat/test-worker_loops.R @@ -37,4 +37,22 @@ test_that("worker_loop_default works with terminate ", { expect_rush_reset(rush, type = "terminate") }) +test_that("seed is set correctly", { + + on.exit({ + .lec.exit() + }) + + config = start_flush_redis() + rush = RushWorker$new(network_id = "test-rush", config = config, host = "local", seed = 123456) + xss = list(list(x1 = 1, x2 = 2)) + rush$push_tasks(xss, terminate_workers = TRUE) + fun = function(x1, x2, ...) list(y = sample(10000, 1)) + + expect_null(worker_loop_default(fun, rush = rush)) + + expect_equal(rush$fetch_finished_tasks()$y, 4492) + + expect_rush_reset(rush, type = "terminate") +})