Skip to content

Commit

Permalink
feat: add data_format option to fetch methods
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 28, 2023
1 parent 6deb8bf commit 3fee22c
Show file tree
Hide file tree
Showing 12 changed files with 371 additions and 171 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import(checkmate)
import(data.table)
import(mlr3misc)
import(redux)
importFrom(jsonlite,fromJSON)
importFrom(processx,process)
importFrom(utils,str)
importFrom(uuid,UUIDgenerate)
219 changes: 132 additions & 87 deletions R/Rush.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
#' @template param_heartbeat_expire
#' @template param_lgr_thresholds
#' @template param_lgr_buffer_size
#' @template param_data_format
#'
#' @export
Rush = R6::R6Class("Rush",
Expand Down Expand Up @@ -394,11 +395,16 @@ Rush = R6::R6Class("Rush",

#' @description
#' Stop workers and delete data stored in redis.
reset = function() {
#' @param type (`character(1)`)\cr
#' Type of stopping.
#' Either `"terminate"` or `"kill"`.
#' If `"terminate"` the workers evaluate the currently running task and then terminate.
#' If `"kill"` the workers are stopped immediately.
reset = function(type = "kill") {
r = self$connector

# stop workers
self$stop_workers(type = "kill")
self$stop_workers(type = type)

# reset fields set by starting workers
self$processes = NULL
Expand Down Expand Up @@ -438,7 +444,7 @@ Rush = R6::R6Class("Rush",

# reset counters and caches
private$.cached_results = data.table()
private$.cached_data = data.table()
private$.cached_tasks_dt = data.table()
private$.cached_worker_info = data.table()
private$.n_seen_results = 0

Expand All @@ -457,7 +463,7 @@ Rush = R6::R6Class("Rush",
cmds = map(worker_ids, function(worker_id) c("LRANGE", private$.get_worker_key("events", worker_id), 0, -1))
worker_logs = set_names(r$pipeline(.commands = cmds), worker_ids)
tab = rbindlist(set_names(map(worker_logs, function(logs) {
rbindlist(map(logs, jsonlite::fromJSON))
rbindlist(map(logs, fromJSON))
}), worker_ids), idcol = "worker_id")
if (nrow(tab)) setkeyv(tab, "timestamp")
tab[]
Expand Down Expand Up @@ -542,17 +548,25 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Latest results in the requested `data_format`.
fetch_latest_results = function(fields = "ys", data_format = "data.table") {
  assert_character(fields)
  assert_choice(data_format, c("data.table", "list"))
  r = self$connector

  # return empty data.table or list if all results are fetched
  if (self$n_finished_tasks == private$.n_seen_results) {
    data = if (data_format == "list") list() else data.table()
    return(data)
  }

  # seen-results counter doubles as the read offset into the finished list
  keys = r$command(c("LRANGE", private$.get_key("finished_tasks"), private$.n_seen_results, -1))

  # increase seen results counter
  private$.n_seen_results = private$.n_seen_results + length(keys)

  # read results from hashes
  # pass `fields` through instead of the hardcoded "ys" so the argument takes effect
  data = self$read_hashes(keys, fields)
  if (data_format == "list") return(set_names(data, keys))
  rbindlist(data, use.names = TRUE, fill = TRUE)
},

#' @description
Expand All @@ -566,11 +580,11 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Latest results; empty when `timeout` expires before a result arrives.
wait_for_latest_results = function(fields = "ys", timeout = Inf, data_format = "data.table") {
  start_time = Sys.time()
  # initialize so a non-positive timeout returns an empty result instead of
  # erroring on an undefined variable
  latest_results = if (data_format == "list") list() else data.table()
  while (start_time + timeout > Sys.time()) {
    latest_results = self$fetch_latest_results(fields, data_format = data_format)
    if (length(latest_results)) break
    Sys.sleep(0.01)
  }
  latest_results
},
Expand All @@ -588,24 +602,17 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Results in the requested `data_format`.
fetch_results = function(fields = "ys", reset_cache = FALSE, data_format = "data.table") {
  assert_character(fields)
  assert_flag(reset_cache)
  assert_choice(data_format, c("data.table", "list"))
  r = self$connector

  if (data_format == "data.table") {
    # cache name must match the declared private field `.cached_results`
    # (".cached_results_dt" does not exist, so nrow(NULL) would error)
    private$.fetch_cache_dt(fields, ".cached_results", reset_cache)
  } else {
    private$.fetch_cache_list(fields, ".cached_results_list", reset_cache)
  }
},

#' @description
Expand All @@ -617,16 +624,9 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Table of queued tasks.
fetch_queued_tasks = function(fields = c("xs", "xs_extra", "state"), data_format = "data.table") {
  # `.fetch_default()` validates arguments and handles NULL keys
  keys = self$queued_tasks
  private$.fetch_default(keys, fields, data_format)
},

#' @description
Expand All @@ -638,19 +638,19 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Table of queued priority tasks.
fetch_priority_tasks = function(fields = c("xs", "xs_extra", "state"), data_format = "data.table") {
  assert_character(fields)
  assert_choice(data_format, c("data.table", "list"))
  r = self$connector

  # one LRANGE per worker over its private queue
  cmds = map(self$worker_ids, function(worker_id) c("LRANGE", private$.get_worker_key("queued_tasks", worker_id), "0", "-1"))
  if (!length(cmds)) {
    data = if (data_format == "list") list() else data.table()
    return(data)
  }

  # NULL keys (no queued priority tasks) are handled by `.fetch_default()`
  keys = unlist(r$pipeline(.commands = cmds))
  private$.fetch_default(keys, fields, data_format)
},

#' @description
Expand All @@ -662,16 +662,9 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Table of running tasks.
fetch_running_tasks = function(fields = c("xs", "xs_extra", "worker_extra", "state"), data_format = "data.table") {
  # `.fetch_default()` validates arguments and handles NULL keys
  keys = self$running_tasks
  private$.fetch_default(keys, fields, data_format)
},

#' @description
Expand All @@ -686,26 +679,18 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Table of finished tasks.
fetch_finished_tasks = function(fields = c("xs", "xs_extra", "worker_extra", "ys", "ys_extra", "state"), reset_cache = FALSE, data_format = "data.table") {
  assert_character(fields)
  assert_flag(reset_cache)
  assert_choice(data_format, c("data.table", "list"))

  # the connector is acquired inside the cache helpers; no local needed here
  if (data_format == "data.table") {
    private$.fetch_cache_dt(fields, cache = ".cached_tasks_dt", reset_cache)
  } else {
    private$.fetch_cache_list(fields, cache = ".cached_tasks_list", reset_cache)
  }
},

#' @description
#' Block process until a new finished task is available.
#' Returns all finished tasks or `NULL` if no new task is available after `timeout` seconds.
Expand All @@ -718,10 +703,15 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()` | `NULL`\cr
#' Table of finished tasks or `NULL` if `timeout` expired first.
wait_for_finished_tasks = function(fields = c("xs", "xs_extra", "worker_extra", "ys", "ys_extra", "state"), timeout = Inf, data_format = "data.table") {
  start_time = Sys.time()

  while (start_time + timeout > Sys.time()) {
    # size of the cache that matches the requested format tells us how many
    # finished tasks were already fetched
    n_cached = if (data_format == "list") length(private$.cached_tasks_list) else nrow(private$.cached_tasks_dt)
    if (self$n_finished_tasks > n_cached) {
      return(self$fetch_finished_tasks(fields, data_format = data_format))
    }
    Sys.sleep(0.01)
  }
  NULL
},
Expand All @@ -736,16 +726,9 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Table of failed tasks.
fetch_failed_tasks = function(fields = c("xs", "worker_extra", "condition", "state"), data_format = "data.table") {
  # `.fetch_default()` validates arguments and handles NULL keys
  keys = self$failed_tasks
  private$.fetch_default(keys, fields, data_format)
},

#' @description
Expand All @@ -757,16 +740,9 @@ Rush = R6::R6Class("Rush",
#'
#' @return `data.table()` | named `list()`\cr
#' Table of all tasks.
fetch_tasks = function(fields = c("xs", "xs_extra", "worker_extra", "ys", "ys_extra", "condition", "state"), data_format = "data.table") {
  # `.fetch_default()` validates arguments and handles NULL keys
  keys = self$tasks
  private$.fetch_default(keys, fields, data_format)
},

#' @description
Expand All @@ -778,7 +754,7 @@ Rush = R6::R6Class("Rush",
#' @param detect_lost_tasks (`logical(1)`)\cr
#' Whether to detect failed tasks.
#' Comes with an overhead.
await_tasks = function(keys, detect_lost_tasks = FALSE) {
wait_for_tasks = function(keys, detect_lost_tasks = FALSE) {
assert_character(keys, min.len = 1)
assert_flag(detect_lost_tasks)

Expand Down Expand Up @@ -1024,7 +1000,7 @@ Rush = R6::R6Class("Rush",
data = function(rhs) {
  # read-only binding: refresh the data.table cache, then return it
  # (the pre-commit `.cached_data` field no longer exists)
  assert_ro_binding(rhs)
  self$fetch_finished_tasks()
  private$.cached_tasks_dt
},

#' @field worker_info ([data.table::data.table()])\cr
Expand Down Expand Up @@ -1092,18 +1068,28 @@ Rush = R6::R6Class("Rush",

private = list(

# cache of the finished tasks and results
# we split the data.table and list caches so that a key can be set on the data.table
.cached_results = data.table(),

.cached_data = data.table(),
.cached_results_list = list(),

.cached_tasks_dt = data.table(),

.cached_tasks_list = list(),

# cache of the worker info which usually does not change after starting the workers
.cached_worker_info = data.table(),

# counter of the seen results for the latest results methods
.n_seen_results = 0,

# cached pid_exists function
.pid_exists = NULL,

.snapshot_schedule = NULL,

#
.hostname = NULL,

# prefix key with instance id
Expand Down Expand Up @@ -1156,6 +1142,65 @@ Rush = R6::R6Class("Rush",

# serialize and push arguments to redis
r$command(list("SET", private$.get_key("start_args"), redux::object_to_bin(start_args)))
},

# Fetch the tasks stored under `keys` and return them in the requested format:
# a named list (one element per key) or a single data.table with a `keys` column.
.fetch_default = function(keys, fields, data_format = "data.table") {
  r = self$connector
  assert_character(fields)
  assert_choice(data_format, c("data.table", "list"))

  # nothing stored yet -> empty container of the requested format
  if (is.null(keys)) {
    if (data_format == "list") return(list())
    return(data.table())
  }

  hashes = self$read_hashes(keys, fields)
  if (data_format == "list") {
    set_names(hashes, keys)
  } else {
    tab = rbindlist(hashes, use.names = TRUE, fill = TRUE)
    tab[, keys := unlist(keys)]
    tab[]
  }
},

# Fetch finished tasks and cache them as a data.table.
# `cache` names a private field holding a data.table; only tasks finished since
# the last call are read from redis and appended, then the full cache is returned.
# NOTE(review): assumes `fields` is the same across calls for a given cache --
# differing fields would mix columns across appends; confirm with callers.
.fetch_cache_dt = function(fields, cache, reset_cache = FALSE) {
  r = self$connector
  # drop the cache and re-read everything when requested
  if (reset_cache) private[[cache]] = data.table()

  # only hit redis when tasks finished since the cache was last filled
  if (self$n_finished_tasks > nrow(private[[cache]])) {

    # get keys of new results (cache size doubles as the LRANGE offset)
    keys = r$command(c("LRANGE", private$.get_key("finished_tasks"), nrow(private[[cache]]), -1))

    lg$debug("Caching %i finished task(s)", length(keys))

    # bind new results to cached results
    data = rbindlist(self$read_hashes(keys, fields), use.names = TRUE, fill = TRUE)
    data[, keys := unlist(keys)]
    private[[cache]] = rbindlist(list(private[[cache]], data), use.names = TRUE, fill = TRUE)
  }

  private[[cache]]
},

# Fetch finished tasks and cache them as a named list.
# `cache` names a private field holding a list keyed by task key; only tasks
# finished since the last call are read from redis and appended, then the full
# cache is returned. List counterpart of `.fetch_cache_dt()`.
.fetch_cache_list = function(fields, cache, reset_cache = FALSE) {
  r = self$connector
  # drop the cache and re-read everything when requested
  if (reset_cache) private[[cache]] = list()

  # only hit redis when tasks finished since the cache was last filled
  if (self$n_finished_tasks > length(private[[cache]])) {

    # get keys of new results (cache length doubles as the LRANGE offset)
    keys = r$command(c("LRANGE", private$.get_key("finished_tasks"), length(private[[cache]]), -1))

    lg$debug("Caching %i finished task(s)", length(keys))

    # bind new results to cached results
    data = set_names(self$read_hashes(keys, fields), keys)
    private[[cache]] = c(private[[cache]], data)
  }

  private[[cache]]
}
)
)
Expand Down
Loading

0 comments on commit 3fee22c

Please sign in to comment.