Skip to content

Commit

Permalink
feat: support seeds with L'Ecuyer RngStreams
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 4, 2023
1 parent 23fe3db commit b7b6871
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 6 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Imports:
mlr3misc,
processx,
redux,
rlecuyer,
uuid
Suggests:
callr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import(checkmate)
import(data.table)
import(mlr3misc)
import(redux)
import(rlecuyer)
importFrom(jsonlite,fromJSON)
importFrom(processx,process)
importFrom(utils,str)
Expand Down
19 changes: 14 additions & 5 deletions R/Rush.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@
#' @template param_heartbeat_expire
#' @template param_lgr_thresholds
#' @template param_lgr_buffer_size
#' @template param_seed
#' @template param_data_format
#'
#'
#' @export
Rush = R6::R6Class("Rush",
public = list(
Expand Down Expand Up @@ -181,6 +183,7 @@ Rush = R6::R6Class("Rush",
lgr_thresholds = NULL,
lgr_buffer_size = 0,
max_retries = 0,
seed = NULL,
supervise = TRUE,
worker_loop = worker_loop_default,
...
Expand All @@ -202,6 +205,7 @@ Rush = R6::R6Class("Rush",
lgr_thresholds = lgr_thresholds,
lgr_buffer_size = lgr_buffer_size,
max_retries = max_retries,
seed = seed,
worker_loop = worker_loop,
...
)
Expand Down Expand Up @@ -548,8 +552,10 @@ Rush = R6::R6Class("Rush",
lg$debug("Pushing %i task(s) to the shared queue", length(xss))

keys = self$write_hashes(xs = xss, xs_extra = extra)
r$command(c("LPUSH", private$.get_key("queued_tasks"), keys))
r$command(c("SADD", private$.get_key("all_tasks"), keys))
cmds = list(
c("RPUSH", private$.get_key("all_tasks"), keys),
c("LPUSH", private$.get_key("queued_tasks"), keys))
r$pipeline(.commands = cmds)
if (terminate_workers) r$command(c("SET", private$.get_key("terminate_on_idle"), 1))

return(invisible(keys))
Expand Down Expand Up @@ -592,7 +598,7 @@ Rush = R6::R6Class("Rush",
}
})
r$pipeline(.commands = cmds)
r$command(c("SADD", private$.get_key("all_tasks"), keys))
r$command(c("RPUSH", private$.get_key("all_tasks"), keys))

return(invisible(keys))
},
Expand Down Expand Up @@ -983,7 +989,7 @@ Rush = R6::R6Class("Rush",
#' Keys of all tasks.
tasks = function() {
r = self$connector
unlist(r$SMEMBERS(private$.get_key("all_tasks")))
unlist(r$LRANGE(private$.get_key("all_tasks"), 0, -1))
},

#' @field queued_tasks (`character()`)\cr
Expand Down Expand Up @@ -1055,7 +1061,7 @@ Rush = R6::R6Class("Rush",
#' Number of all tasks.
n_tasks = function() {
r = self$connector
as.integer(r$SCARD(private$.get_key("all_tasks"))) %??% 0
as.integer(r$LLEN(private$.get_key("all_tasks"))) %??% 0
},

#' @field data ([data.table::data.table])\cr
Expand Down Expand Up @@ -1171,6 +1177,7 @@ Rush = R6::R6Class("Rush",
lgr_thresholds = NULL,
lgr_buffer_size = 0,
max_retries = 0,
seed = NULL,
worker_loop = worker_loop_default,
...
) {
Expand All @@ -1182,6 +1189,7 @@ Rush = R6::R6Class("Rush",
assert_vector(lgr_thresholds, names = "named", null.ok = TRUE)
assert_count(lgr_buffer_size)
assert_count(max_retries)
assert_int(seed, null.ok = TRUE)
assert_function(worker_loop)
dots = list(...)
r = self$connector
Expand All @@ -1201,6 +1209,7 @@ Rush = R6::R6Class("Rush",
heartbeat_expire = heartbeat_expire,
lgr_thresholds = lgr_thresholds,
lgr_buffer_size = lgr_buffer_size,
seed = seed,
max_retries = max_retries)

# arguments needed for initializing the worker
Expand Down
39 changes: 38 additions & 1 deletion R/RushWorker.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#' @template param_heartbeat_expire
#' @template param_lgr_thresholds
#' @template param_lgr_buffer_size
#' @template param_seed
#'
#' @export
RushWorker = R6::R6Class("RushWorker",
Expand Down Expand Up @@ -44,6 +45,7 @@ RushWorker = R6::R6Class("RushWorker",
heartbeat_expire = NULL,
lgr_thresholds = NULL,
lgr_buffer_size = 0,
seed = NULL,
max_retries = 0
) {
super$initialize(network_id = network_id, config = config)
Expand Down Expand Up @@ -101,6 +103,15 @@ RushWorker = R6::R6Class("RushWorker",
}
}

# initialize seed table
if (!is.null(seed)) {
private$.seed = TRUE
.lec.SetPackageSeed(seed)
walk(self$tasks, function(key) {
.lec.CreateStream(key)
})
}

# register worker ids
r$SADD(private$.get_key("worker_ids"), self$worker_id)
r$SADD(private$.get_key("running_worker_ids"), self$worker_id)
Expand Down Expand Up @@ -141,7 +152,7 @@ RushWorker = R6::R6Class("RushWorker",

keys = self$write_hashes(xs = xss, xs_extra = extra)
r$command(c("SADD", private$.get_key("running_tasks"), keys))
r$command(c("SADD", private$.get_key("all_tasks"), keys))
r$command(c("RPUSH", private$.get_key("all_tasks"), keys))

return(invisible(keys))
},
Expand Down Expand Up @@ -220,6 +231,28 @@ RushWorker = R6::R6Class("RushWorker",
return(invisible(self))
},

#' @description
#' Activates the L'Ecuyer RNG stream associated with `key` so that the
#' task evaluated under this key draws reproducible random numbers.
#' Lazily creates streams for any tasks added since the seed table was
#' last synchronized. No-op when the worker was started without a seed.
#'
#' @param key (`character(1)`)\cr
#' Key of the task.
set_seed = function(key) {
  if (!private$.seed) return(invisible(self))
  r = self$connector

  # create streams for task keys not yet present in the rlecuyer table
  n_registered = length(.lec.Random.seed.table$name)
  if (self$n_tasks > n_registered) {
    new_keys = r$LRANGE(private$.get_key("all_tasks"), n_registered, -1)
    walk(new_keys, .lec.CreateStream)
  }

  # switch the RNG to the stream belonging to this task
  .lec.CurrentStream(key)
  return(invisible(self))
},

#' @description
#' Mark the worker as terminated.
#' Last step in the worker loop before the worker terminates.
Expand Down Expand Up @@ -248,5 +281,9 @@ RushWorker = R6::R6Class("RushWorker",
r = self$connector
as.logical(r$EXISTS(private$.get_key("terminate_on_idle"))) && !as.logical(self$n_queued_tasks)
}
),

private = list(
.seed = NULL
)
)
2 changes: 2 additions & 0 deletions R/worker_loops.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ worker_loop_default = function(fun, constants = NULL, rush) {
while(!rush$terminated) {
task = rush$pop_task()
if (!is.null(task)) {
# set seed
rush$set_seed(task$key)
tryCatch({
ys = mlr3misc::invoke(fun, .args = c(task$xs, constants))
rush$push_results(task$key, yss = list(ys))
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#' @import redux
#' @import mlr3misc
#' @import checkmate
#' @import rlecuyer
#' @importFrom processx process
#' @importFrom uuid UUIDgenerate
#' @importFrom utils str
Expand Down
2 changes: 2 additions & 0 deletions man-roxygen/param_seed.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#' @param seed (`integer(1)`)\cr
#'   Initial seed for the L'Ecuyer random number generator streams.
#'   If `NULL` (default), no seed is set.
4 changes: 4 additions & 0 deletions man/Rush.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/RushWorker.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/testthat/test-Rush.R
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,26 @@ test_that("network without controller works", {

expect_rush_reset(rush)
})

# seed -------------------------------------------------------------------------

test_that("seed is set correctly on two workers", {
  skip_on_cran()
  skip_on_ci()

  config = start_flush_redis()
  rush = Rush$new(network_id = "test-rush", config = config)
  fun = function(x1, x2, ...) list(y = sample(10000, 1))
  worker_ids = rush$start_workers(fun = fun, n_workers = 2, seed = 123456, wait_for_workers = TRUE)

  xss = list(
    list(x1 = 1, x2 = 2),
    list(x1 = 2, x2 = 2),
    list(x1 = 2, x2 = 3),
    list(x1 = 2, x2 = 4))
  .keys = rush$push_tasks(xss)
  rush$wait_for_tasks(.keys)

  # each task draws from its own deterministic L'Ecuyer stream
  finished_tasks = rush$fetch_finished_tasks()
  expected = c(4492, 9223, 2926, 4937)
  for (i in seq_along(expected)) {
    expect_equal(finished_tasks[.keys[i], y, on = "keys"], expected[i])
  }

  expect_rush_reset(rush, type = "terminate")
})
30 changes: 30 additions & 0 deletions tests/testthat/test-RushWorker.R
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,9 @@ test_that("n_retries method works", {
})

test_that("terminate on idle works", {
skip_on_cran()
skip_on_ci()

config = start_flush_redis()
rush = RushWorker$new(network_id = "test-rush", config = config, host = "local")

Expand All @@ -889,3 +892,30 @@ test_that("terminate on idle works", {

expect_rush_reset(rush, type = "terminate")
})


# seed -------------------------------------------------------------------------

test_that("seed is set correctly", {
  skip_on_cran()
  skip_on_ci()

  # reset the rlecuyer state even if an expectation fails
  on.exit(.lec.exit())

  config = start_flush_redis()
  rush = RushWorker$new(network_id = "test-rush", config = config, host = "local", seed = 123456)

  # no stream exists before a task key is passed to set_seed()
  expect_null(.lec.Random.seed.table$name)

  rush$push_tasks(list(list(x1 = 1, x2 = 2)))
  task = rush$pop_task()
  rush$set_seed(task$key)

  # a stream named after the task key is now registered and active
  expect_equal(.lec.Random.seed.table$name, task$key)
  expect_equal(sample(seq(100000), 1), 86412)

  expect_rush_reset(rush, type = "terminate")
})
18 changes: 18 additions & 0 deletions tests/testthat/test-worker_loops.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,22 @@ test_that("worker_loop_default works with terminate ", {
expect_rush_reset(rush, type = "terminate")
})

test_that("seed is set correctly", {
  # requires a live Redis server, like the other seed tests — gate it
  # consistently with tests/testthat/test-Rush.R and test-RushWorker.R
  skip_on_cran()
  skip_on_ci()

  # reset the rlecuyer state even if an expectation fails
  on.exit({
    .lec.exit()
  })

  config = start_flush_redis()
  rush = RushWorker$new(network_id = "test-rush", config = config, host = "local", seed = 123456)
  xss = list(list(x1 = 1, x2 = 2))
  rush$push_tasks(xss, terminate_workers = TRUE)
  fun = function(x1, x2, ...) list(y = sample(10000, 1))

  # the loop pops the task, calls set_seed(task$key), then evaluates fun
  expect_null(worker_loop_default(fun, rush = rush))

  # deterministic draw proves the L'Ecuyer stream was activated
  expect_equal(rush$fetch_finished_tasks()$y, 4492)

  expect_rush_reset(rush, type = "terminate")
})

0 comments on commit b7b6871

Please sign in to comment.