diff --git a/DESCRIPTION b/DESCRIPTION index e4a237d..1912a7b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,6 +18,7 @@ Imports: mlr3misc, processx, redux, + rlecuyer, uuid Suggests: callr, diff --git a/NAMESPACE b/NAMESPACE index 89b68bf..2e74e07 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -15,6 +15,7 @@ import(checkmate) import(data.table) import(mlr3misc) import(redux) +import(rlecuyer) importFrom(jsonlite,fromJSON) importFrom(processx,process) importFrom(utils,str) diff --git a/R/Rush.R b/R/Rush.R index dee44e0..46ff1d5 100644 --- a/R/Rush.R +++ b/R/Rush.R @@ -97,8 +97,10 @@ #' @template param_heartbeat_expire #' @template param_lgr_thresholds #' @template param_lgr_buffer_size +#' @template param_seed #' @template param_data_format #' +#' #' @export Rush = R6::R6Class("Rush", public = list( @@ -181,6 +183,7 @@ Rush = R6::R6Class("Rush", lgr_thresholds = NULL, lgr_buffer_size = 0, max_retries = 0, + seed = NULL, supervise = TRUE, worker_loop = worker_loop_default, ... @@ -202,6 +205,7 @@ Rush = R6::R6Class("Rush", lgr_thresholds = lgr_thresholds, lgr_buffer_size = lgr_buffer_size, max_retries = max_retries, + seed = seed, worker_loop = worker_loop, ... ) @@ -548,8 +552,10 @@ Rush = R6::R6Class("Rush", lg$debug("Pushing %i task(s) to the shared queue", length(xss)) keys = self$write_hashes(xs = xss, xs_extra = extra) - r$command(c("LPUSH", private$.get_key("queued_tasks"), keys)) - r$command(c("SADD", private$.get_key("all_tasks"), keys)) + cmds = list( + c("RPUSH", private$.get_key("all_tasks"), keys), + c("LPUSH", private$.get_key("queued_tasks"), keys)) + r$pipeline(.commands = cmds) if (terminate_workers) r$command(c("SET", private$.get_key("terminate_on_idle"), 1)) return(invisible(keys)) @@ -592,7 +598,7 @@ Rush = R6::R6Class("Rush", } }) r$pipeline(.commands = cmds) - r$command(c("SADD", private$.get_key("all_tasks"), keys)) + r$command(c("RPUSH", private$.get_key("all_tasks"), keys)) return(invisible(keys)) }, @@ -983,7 +989,7 @@ Rush = R6::R6Class("Rush", #' Keys of all tasks. tasks = function() { r = self$connector - unlist(r$SMEMBERS(private$.get_key("all_tasks"))) + unlist(r$LRANGE(private$.get_key("all_tasks"), 0, -1)) }, #' @field queued_tasks (`character()`)\cr @@ -1055,7 +1061,7 @@ Rush = R6::R6Class("Rush", #' Number of all tasks. n_tasks = function() { r = self$connector - as.integer(r$SCARD(private$.get_key("all_tasks"))) %??% 0 + as.integer(r$LLEN(private$.get_key("all_tasks"))) %??% 0 }, #' @field data ([data.table::data.table])\cr @@ -1171,6 +1177,7 @@ Rush = R6::R6Class("Rush", lgr_thresholds = NULL, lgr_buffer_size = 0, max_retries = 0, + seed = NULL, worker_loop = worker_loop_default, ... ) { @@ -1182,6 +1189,7 @@ Rush = R6::R6Class("Rush", assert_vector(lgr_thresholds, names = "named", null.ok = TRUE) assert_count(lgr_buffer_size) assert_count(max_retries) + assert_int(seed, null.ok = TRUE) assert_function(worker_loop) dots = list(...) r = self$connector @@ -1201,6 +1209,7 @@ Rush = R6::R6Class("Rush", heartbeat_expire = heartbeat_expire, lgr_thresholds = lgr_thresholds, lgr_buffer_size = lgr_buffer_size, + seed = seed, max_retries = max_retries) # arguments needed for initializing the worker diff --git a/R/RushWorker.R b/R/RushWorker.R index 3ce3a9e..03060a1 100644 --- a/R/RushWorker.R +++ b/R/RushWorker.R @@ -15,6 +15,7 @@ #' @template param_heartbeat_expire #' @template param_lgr_thresholds #' @template param_lgr_buffer_size +#' @template param_seed #' #' @export RushWorker = R6::R6Class("RushWorker", @@ -44,6 +45,7 @@ RushWorker = R6::R6Class("RushWorker", heartbeat_expire = NULL, lgr_thresholds = NULL, lgr_buffer_size = 0, + seed = NULL, max_retries = 0 ) { super$initialize(network_id = network_id, config = config) @@ -101,6 +103,15 @@ RushWorker = R6::R6Class("RushWorker", } } + # initialize seed table + if (!is.null(seed)) { + private$.seed = TRUE + .lec.SetPackageSeed(seed) + walk(self$tasks, function(key) { + .lec.CreateStream(key) + }) + } + # register worker ids r$SADD(private$.get_key("worker_ids"), self$worker_id) r$SADD(private$.get_key("running_worker_ids"), self$worker_id) @@ -141,7 +152,7 @@ RushWorker = R6::R6Class("RushWorker", keys = self$write_hashes(xs = xss, xs_extra = extra) r$command(c("SADD", private$.get_key("running_tasks"), keys)) - r$command(c("SADD", private$.get_key("all_tasks"), keys)) + r$command(c("RPUSH", private$.get_key("all_tasks"), keys)) return(invisible(keys)) }, @@ -220,6 +231,28 @@ RushWorker = R6::R6Class("RushWorker", return(invisible(self)) }, + #' @description + #' Sets the seed for `key`. + #' Updates the seed table if necessary. + #' + #' @param key (`character(1)`)\cr + #' Key of the task. + set_seed = function(key) { + if (!private$.seed) return(invisible(self)) + r = self$connector + + # update seed table + n_streams = length(.lec.Random.seed.table$name) + if (self$n_tasks > n_streams) { + keys = r$LRANGE(private$.get_key("all_tasks"), n_streams, -1) + walk(keys, function(key) .lec.CreateStream(key)) + } + + # set seed + .lec.CurrentStream(key) + return(invisible(self)) + }, + #' @description #' Mark the worker as terminated. #' Last step in the worker loop before the worker terminates. @@ -248,5 +281,9 @@ RushWorker = R6::R6Class("RushWorker", r = self$connector as.logical(r$EXISTS(private$.get_key("terminate_on_idle"))) && !as.logical(self$n_queued_tasks) } + ), + + private = list( + .seed = NULL ) ) diff --git a/R/worker_loops.R b/R/worker_loops.R index 8623ba2..a5fbeff 100644 --- a/R/worker_loops.R +++ b/R/worker_loops.R @@ -20,6 +20,8 @@ worker_loop_default = function(fun, constants = NULL, rush) { while(!rush$terminated) { task = rush$pop_task() if (!is.null(task)) { + # set seed + rush$set_seed(task$key) tryCatch({ ys = mlr3misc::invoke(fun, .args = c(task$xs, constants)) rush$push_results(task$key, yss = list(ys)) diff --git a/R/zzz.R b/R/zzz.R index 3968320..ae12073 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -2,6 +2,7 @@ #' @import redux #' @import mlr3misc #' @import checkmate +#' @import rlecuyer #' @importFrom processx process #' @importFrom uuid UUIDgenerate #' @importFrom utils str diff --git a/man-roxygen/param_seed.R b/man-roxygen/param_seed.R new file mode 100644 index 0000000..b4c08a9 --- /dev/null +++ b/man-roxygen/param_seed.R @@ -0,0 +1,2 @@ +#' @param seed (`integer(1)`)\cr +#' Seed for the random number generator. diff --git a/man/Rush.Rd b/man/Rush.Rd index e6e920e..f562711 100644 --- a/man/Rush.Rd +++ b/man/Rush.Rd @@ -316,6 +316,7 @@ This function takes the arguments \code{fun} and optionally \code{constants} whi lgr_thresholds = NULL, lgr_buffer_size = 0, max_retries = 0, + seed = NULL, supervise = TRUE, worker_loop = worker_loop_default, ... @@ -351,6 +352,9 @@ By default (\code{lgr_buffer_size = 0}), the log messages are directly saved in If \code{lgr_buffer_size > 0}, the log messages are buffered and saved in the Redis data store when the buffer is full. This improves the performance of the logging.} +\item{\code{seed}}{(\code{integer(1)})\cr +Seed for the random number generator.} + \item{\code{supervise}}{(\code{logical(1)})\cr Whether to kill the workers when the main R process is shut down.} diff --git a/man/RushWorker.Rd b/man/RushWorker.Rd index c774086..f13c4a4 100644 --- a/man/RushWorker.Rd +++ b/man/RushWorker.Rd @@ -48,6 +48,7 @@ Used in the worker loop to determine whether to continue.} \item \href{#method-RushWorker-pop_task}{\code{RushWorker$pop_task()}} \item \href{#method-RushWorker-push_results}{\code{RushWorker$push_results()}} \item \href{#method-RushWorker-push_failed}{\code{RushWorker$push_failed()}} +\item \href{#method-RushWorker-set_seed}{\code{RushWorker$set_seed()}} \item \href{#method-RushWorker-set_terminated}{\code{RushWorker$set_terminated()}} \item \href{#method-RushWorker-clone}{\code{RushWorker$clone()}} } @@ -99,6 +100,7 @@ Creates a new instance of this \link[R6:R6Class]{R6} class. heartbeat_expire = NULL, lgr_thresholds = NULL, lgr_buffer_size = 0, + seed = NULL, max_retries = 0 )}\if{html}{\out{}} } @@ -139,6 +141,9 @@ Logger threshold on the workers e.g. \code{c(rush = "debug")}.} By default (\code{lgr_buffer_size = 0}), the log messages are directly saved in the Redis data store. If \code{lgr_buffer_size > 0}, the log messages are buffered and saved in the Redis data store when the buffer is full. This improves the performance of the logging.} + +\item{\code{seed}}{(\code{integer(1)})\cr +Seed for the random number generator.} } \if{html}{\out{}} } @@ -232,6 +237,25 @@ If \code{"error"} the tasks are moved to the failed tasks.} \if{html}{\out{