Commit 8486a92
refactor: read functions
be-marc committed Jan 22, 2024
1 parent 2434d2c commit 8486a92
Showing 6 changed files with 218 additions and 168 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
@@ -18,6 +18,7 @@ import(mlr3misc)
import(redux)
importFrom(jsonlite,fromJSON)
importFrom(parallel,nextRNGStream)
importFrom(parallel,nextRNGSubStream)
importFrom(processx,process)
importFrom(utils,str)
importFrom(uuid,UUIDgenerate)
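
For context on the new import: a minimal sketch of deriving L'Ecuyer-CMRG sub-stream seeds with `parallel::nextRNGSubStream`, which presumably backs the new `next_seed` option of `retry_tasks()` below (the connection is an assumption; the sketch itself is plain base R):

RNGkind("L'Ecuyer-CMRG")                     # switch to the stream-capable RNG
seed = .Random.seed                          # current L'Ecuyer-CMRG state
sub_seed = parallel::nextRNGSubStream(seed)  # seed for an independent sub-stream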
79 changes: 48 additions & 31 deletions R/Rush.R
@@ -532,15 +532,6 @@ Rush = R6::R6Class("Rush",
tab[]
},

read_tasks = function(keys, fields = c("status", "seed", "timeout", "max_retries", "n_retries")) {
r = self$connector
map(keys, function(key) {
task = setNames(map(r$HMGET(key, fields), safe_bin_to_object), fields)
task$key = key
task
})
},

#' @description
#' Pushes a task to the queue.
#' Task is added to queued tasks.
@@ -673,13 +664,15 @@
#'
#' @param keys (`character()`)\cr
#' Keys of the tasks to be retried.
#' @param ignore_max_retires (`logical(1)`)\cr
#' Whether to ignore the maximum number of retries.
#' @param next_seed (`logical(1)`)\cr
#' Whether to change the seed of the task.
retry_tasks = function(keys, ignore_max_retires = FALSE, next_seed = FALSE) {
assert_character(keys)
assert_flag(ignore_max_retires)
assert_flag(next_seed)
tasks = self$read_tasks(keys, fields = c("seed", "max_retries", "n_retries"))
tasks = self$read_hashes(keys, fields = c("seed", "max_retries", "n_retries"), flatten = FALSE)
keys = map_chr(tasks, "key")
seeds = map(tasks, "seed")
n_retries = map_int(tasks, function(task) task$n_retries %??% 0L)
@@ -995,26 +988,24 @@
},

#' @description
#' Reads Redis hashes and combines the values of the fields into a list.
#' The function reads the values of the `fields` in the hashes stored at `keys`.
#' The values of a hash are deserialized and combined into a single list.
#'
#' Reads R objects from Redis hashes.
#' The function reads the field-value pairs of the hashes stored at `keys`.
#' The values of a hash are deserialized and combined into a list.
#' If `flatten` is `TRUE`, the values are flattened into a single list, e.g. `list(xs = list(x1 = 1, x2 = 2), ys = list(y = 3))` becomes `list(x1 = 1, x2 = 2, y = 3)`.
#' The reading functions combine the hashes into a table where the names of the inner lists are the column names.
#' For example, `xs = list(list(x1 = 1, x2 = 2), list(x1 = 3, x2 = 4)), ys = list(list(y = 3), list(y = 7))` becomes `data.table(x1 = c(1, 3), x2 = c(2, 4), y = c(3, 7))`.
#' Vectors in list columns must be wrapped in lists.
#' Otherwise, `$read_values()` will expand the table by the length of the vectors.
#' For example, `xs = list(list(x1 = 1, x2 = 2)), xs_extra = list(list(extra = c("A", "B", "C")))` does not work.
#' Pass `xs_extra = list(list(extra = list(c("A", "B", "C"))))` instead.
#'
#' @param keys (`character()`)\cr
#' Keys of the hashes.
#' @param fields (`character()`)\cr
#' Fields to be read from the hashes.
#' @param flatten (`logical(1)`)\cr
#' Whether to flatten the list.
#'
#' @return (list of `list()`)\cr
#' The outer list contains one element for each key.
#' The inner list is the combination of the lists stored at the different fields.
read_hashes = function(keys, fields) {
read_hashes = function(keys, fields, flatten = TRUE) {

lg$debug("Reading %i hash(es) with %i field(s)", length(keys), length(fields))

@@ -1027,18 +1018,44 @@ Rush = R6::R6Class("Rush",
# the values of the fields are serialized lists and atomics
hashes = self$connector$pipeline(.commands = cmds)

# unserialize lists of the second level
# combine elements of the third level to one list
# using mapply instead of pmap is somehow faster
map(hashes, function(hash) unlist(.mapply(function(bin_value, field) {
value = safe_bin_to_object(bin_value)
if (is.atomic(value) && !is.null(value)) {
# list column or column with type of value
if (length(value) > 1) value = list(value)
value = setNames(list(value), field)
}
value
}, list(bin_value = hash, field = fields), NULL), recursive = FALSE))
if (flatten) {
# unserialize elements of the second level
# flatten elements of the third level to one list
# using mapply instead of pmap is faster
map(hashes, function(hash) unlist(.mapply(function(bin_value, field) {
# unserialize value
value = safe_bin_to_object(bin_value)
# wrap atomic values in list and name by field
if (is.atomic(value) && !is.null(value)) {
# list column or column with type of value
if (length(value) > 1) value = list(value)
value = setNames(list(value), field)
}
value
}, list(bin_value = hash, field = fields), NULL), recursive = FALSE))
} else {
# unserialize elements of the second level
map(hashes, function(hash) setNames(map(hash, function(bin_value) {
safe_bin_to_object(bin_value)
}), fields))
}
},

#' @description
#' Reads a single Redis hash and returns the values as a list named by the fields.
#'
#' @param key (`character(1)`)\cr
#' Key of the hash.
#' @param fields (`character()`)\cr
#' Fields to be read from the hash.
#'
#' @return (`list()`)\cr
#' The values of the fields, named by the fields.
read_hash = function(key, fields) {
lg$debug("Reading hash with %i field(s)", length(fields))

setNames(map(self$connector$HMGET(key, fields), safe_bin_to_object), fields)
},

#' @description
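For illustration, a minimal usage sketch of the new readers, mirroring the tests further down (it assumes a running Redis instance and a constructed Rush object named `rush`):

key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), ys = list(list(y = 3)))
rush$read_hashes(key, c("xs", "ys"))                   # list(list(x1 = 1, x2 = 2, y = 3))
rush$read_hashes(key, c("xs", "ys"), flatten = FALSE)  # list(list(xs = list(x1 = 1, x2 = 2), ys = list(y = 3)))
rush$read_hash(key, c("xs", "ys"))                     # list(xs = list(x1 = 1, x2 = 2), ys = list(y = 3))
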
2 changes: 1 addition & 1 deletion R/zzz.R
@@ -6,7 +6,7 @@
#' @importFrom uuid UUIDgenerate
#' @importFrom utils str
#' @importFrom jsonlite fromJSON
#' @importFrom parallel nextRNGStream
#' @importFrom parallel nextRNGStream nextRNGSubStream
"_PACKAGE"

.onLoad = function(libname, pkgname) {
32 changes: 17 additions & 15 deletions man/Rush.Rd


151 changes: 151 additions & 0 deletions tests/testthat/test-Rush.R
@@ -261,6 +261,157 @@ test_that("a remote worker is killed via the heartbeat", {
expect_rush_reset(rush)
})

# low level read and write -----------------------------------------------------

test_that("reading and writing a hash works with flatten", {
skip_on_cran()
skip_on_ci()

config = start_flush_redis()
rush = RushWorker$new(network_id = "test-rush", config = config, host = "local")

# one field with list
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)))
expect_equal(rush$read_hashes(key, "xs"), list(list(x1 = 1, x2 = 2)))

# one field with atomic
key = rush$write_hashes(timeout = 1)
expect_equal(rush$read_hashes(key, "timeout"), list(list(timeout = 1)))

# two fields with lists
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), ys = list(list(y = 3)))
expect_equal(rush$read_hashes(key, c("xs", "ys")), list(list(x1 = 1, x2 = 2, y = 3)))

# two fields with list and empty list
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), ys = list())
expect_equal(rush$read_hashes(key, c("xs", "ys")), list(list(x1 = 1, x2 = 2)))

# two fields with list and atomic
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), timeout = 1)
expect_equal(rush$read_hashes(key, c("xs", "timeout")), list(list(x1 = 1, x2 = 2, timeout = 1)))
})

test_that("reading and writing a hash works without flatten", {
skip_on_cran()
skip_on_ci()

config = start_flush_redis()
rush = Rush$new(network_id = "test-rush", config = config)

# one field with list
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)))
expect_equal(rush$read_hashes(key, "xs", flatten = FALSE), list(list(xs = list(x1 = 1, x2 = 2))))

# one field with atomic
key = rush$write_hashes(timeout = 1)
expect_equal(rush$read_hashes(key, "timeout", flatten = FALSE), list(list(timeout = 1)))

# two fields with lists
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), ys = list(list(y = 3)))
expect_equal(rush$read_hashes(key, c("xs", "ys"), flatten = FALSE), list(list(xs = list(x1 = 1, x2 = 2), ys = list(y = 3))))

# two fields with list and empty list
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), ys = list())
expect_equal(rush$read_hashes(key, c("xs", "ys"), flatten = FALSE), list(list(xs = list(x1 = 1, x2 = 2), ys = NULL)))

# two fields with list and atomic
key = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), timeout = 1)
expect_equal(rush$read_hashes(key, c("xs", "timeout"), flatten = FALSE), list(list(xs = list(x1 = 1, x2 = 2), timeout = 1)))
})

test_that("reading and writing hashes works", {
skip_on_cran()
skip_on_ci()

config = start_flush_redis()
rush = RushWorker$new(network_id = "test-rush", config = config, host = "local")

# one field with list
keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)))
expect_equal(rush$read_hashes(keys, "xs"), list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)))

# one field atomic
keys = rush$write_hashes(timeout = c(1, 1))
expect_equal(rush$read_hashes(keys, "timeout"), list(list(timeout = 1), list(timeout = 1)))

# two fields with list and recycled atomic
keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), timeout = 1)
expect_equal(rush$read_hashes(keys, c("xs", "timeout")), list(list(x1 = 1, x2 = 2, timeout = 1), list(x1 = 1, x2 = 3, timeout = 1)))

# two fields
keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), ys = list(list(y = 3), list(y = 4)))
expect_equal(rush$read_hashes(keys, c("xs", "ys")), list(list(x1 = 1, x2 = 2, y = 3), list(x1 = 1, x2 = 3, y = 4)))

# two fields with list and atomic
keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), timeout = c(1, 1))
expect_equal(rush$read_hashes(keys, c("xs", "timeout")), list(list(x1 = 1, x2 = 2, timeout = 1), list(x1 = 1, x2 = 3, timeout = 1)))

# two fields with list and recycled atomic
keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), timeout = 1)
expect_equal(rush$read_hashes(keys, c("xs", "timeout")), list(list(x1 = 1, x2 = 2, timeout = 1), list(x1 = 1, x2 = 3, timeout = 1)))

# two fields, one empty
keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), ys = list())
expect_equal(rush$read_hashes(keys, c("xs", "ys")), list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)))

# recycle
keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), ys = list(list(y = 3)))
expect_equal(rush$read_hashes(keys, c("xs", "ys")), list(list(x1 = 1, x2 = 2, y = 3), list(x1 = 1, x2 = 3, y = 3)))
})

test_that("writing hashes to specific keys works", {
skip_on_cran()
skip_on_ci()

config = start_flush_redis()
rush = RushWorker$new(network_id = "test-rush", config = config, host = "local")

# one element
keys = uuid::UUIDgenerate()
rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), keys = keys)
expect_equal(rush$read_hashes(keys, "xs"), list(list(x1 = 1, x2 = 2)))

# two elements
keys = uuid::UUIDgenerate(n = 2)
rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), keys = keys)
expect_equal(rush$read_hashes(keys, "xs"), list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)))

# wrong number of keys
keys = uuid::UUIDgenerate()
expect_error(rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 1, x2 = 3)), keys = keys), "Assertion on 'keys' failed")
})


test_that("writing list columns works", {
skip_on_cran()
skip_on_ci()

config = start_flush_redis()
rush = RushWorker$new(network_id = "test-rush", config = config, host = "local")

keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), xs_extra = list(list(extra = list("A"))))
rush$connector$command(c("LPUSH", "test-rush:finished_tasks", keys))

expect_list(rush$fetch_finished_tasks()$extra, len = 1)

config = start_flush_redis()
rush = RushWorker$new(network_id = "test-rush", config = config, host = "local")

keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2)), xs_extra = list(list(extra = list(letters[1:3]))))
rush$connector$command(c("LPUSH", "test-rush:finished_tasks", keys))

expect_list(rush$fetch_finished_tasks()$extra, len = 1)

config = start_flush_redis()
rush = RushWorker$new(network_id = "test-rush", config = config, host = "local")

keys = rush$write_hashes(xs = list(list(x1 = 1, x2 = 2), list(x1 = 2, x2 = 2)), xs_extra = list(list(extra = list("A")), list(extra = list("B"))))
rush$connector$command(c("LPUSH", "test-rush:finished_tasks", keys))
rush$read_hashes(keys, c("xs", "xs_extra"))

expect_list(rush$fetch_finished_tasks()$extra, len = 2)
})
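
The list-column tests above exercise the rule documented for `read_hashes`: vectors in list columns must be wrapped in lists. A short sketch of the underlying behavior shows why (plain data.table semantics, not rush code; an unwrapped vector is recycled into extra rows):

data.table::data.table(x1 = 1, extra = c("A", "B", "C"))        # 3 rows, x1 recycled
data.table::data.table(x1 = 1, extra = list(c("A", "B", "C")))  # 1 row, list column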

# task evaluation --------------------------------------------------------------

test_that("evaluating a task works", {