Skip to content

Commit

Permalink
feat: add safe methods for dictionary retrieval (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Sep 19, 2023
1 parent 5343fc5 commit 61941d6
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 12 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Collate:
'crate.R'
'cross_join.R'
'dictionary_sugar.R'
'dictionary_sugar_safe.R'
'did_you_mean.R'
'distinct_values.R'
'encapsulate.R'
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ export(deframe)
export(detect)
export(dictionary_sugar)
export(dictionary_sugar_get)
export(dictionary_sugar_get_safe)
export(dictionary_sugar_inc_get)
export(dictionary_sugar_inc_get_safe)
export(dictionary_sugar_inc_mget)
export(dictionary_sugar_inc_mget_safe)
export(dictionary_sugar_mget)
export(dictionary_sugar_mget_safe)
export(did_you_mean)
export(discard)
export(distinct_values)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mlr3misc 0.12.0-9000

* Added safe methods for dictionary retrieval (#83)
* fix: Fixed an important bug that caused serialized objects to be overly large
when installing mlr3 with `--with-keep.source` (#88)

Expand Down
171 changes: 171 additions & 0 deletions R/dictionary_sugar_safe.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#' @title A Quick Way to Initialize Objects from Dictionaries
#'
#' @description
#' Given a [Dictionary], retrieve objects with provided keys.
#' * `dictionary_sugar_get_safe()` to retrieve a single object with key `.key`.
#' * `dictionary_sugar_mget_safe()` to retrieve a list of objects with keys `.keys`.
#' * If `.key` or `.keys` is missing, the dictionary itself is returned.
#' * Dictionary getters without the `_safe` suffix are discouraged as this sometimes caused unintended partial
#' argument matching.
#'
#' Arguments in `...` must be named and are consumed in the following order:
#'
#' 1. All arguments whose names match the name of an argument of the constructor
#' are passed to the `$get()` method of the [Dictionary] for construction.
#' 2. All arguments whose names match the name of a parameter of the [paradox::ParamSet] of the
#' constructed object are set as parameters. If there is no [paradox::ParamSet] in `obj$param_set`, this
#' step is skipped.
#' 3. All remaining arguments are assumed to be regular fields of the constructed R6 instance, and
#' are assigned via [`<-`].
#'
#' @param .dict ([Dictionary])\cr
#' The dictionary from which to retrieve the elements.
#' @param .key (`character(1)`)\cr
#' Key of the object to construct.
#' @param .keys (`character()`)\cr
#' Keys of the objects to construct.
#' @param ... (`any`)\cr
#' See description.
#' @return [R6::R6Class()]
#' @include dictionary_sugar.R
#' @export
#' @examples
#' library(R6)
#' item = R6Class("Item", public = list(x = 0))
#' d = Dictionary$new()
#' d$add("key", item)
#' dictionary_sugar_get_safe(d, "key", x = 2)
dictionary_sugar_get_safe = function(.dict, .key, ...) {
assert_class(.dict, "Dictionary")
if (missing(.key)) {
return(.dict)
}
assert_string(.key)
if (...length() == 0L) {
return(dictionary_get(.dict, .key))
}
dots = assert_list(list(...), .var.name = "additional arguments passed to Dictionary")
assert_list(dots[!is.na(names2(dots))], names = "unique", .var.name = "named arguments passed to Dictionary")

obj = dictionary_retrieve_item(.dict, .key)
if (length(dots) == 0L) {
return(assert_r6(dictionary_initialize_item(.key, obj)))
}

# pass args to constructor and remove them
constructor_args = get_constructor_formals(obj$value)
ii = is.na(names2(dots)) | names2(dots) %in% constructor_args
instance = assert_r6(dictionary_initialize_item(.key, obj, dots[ii]))
dots = dots[!ii]


# set params in ParamSet
if (length(dots) && exists("param_set", envir = instance, inherits = FALSE)) {
param_ids = instance$param_set$ids()
ii = names(dots) %in% param_ids
if (any(ii)) {
instance$param_set$values = insert_named(instance$param_set$values, dots[ii])
dots = dots[!ii]
}
} else {
param_ids = character()
}

# remaining args go into fields
if (length(dots)) {
ndots = names(dots)
for (i in seq_along(dots)) {
nn = ndots[[i]]
if (!exists(nn, envir = instance, inherits = FALSE)) {
stopf("Cannot set argument '%s' for '%s' (not a constructor argument, not a parameter, not a field.%s",
nn, class(instance)[1L], did_you_mean(nn, c(constructor_args, param_ids, fields(obj$value))))
}
instance[[nn]] = dots[[i]]
}
}

return(instance)
}


#' @rdname dictionary_sugar_get_safe
#' @export
dictionary_sugar_mget_safe = function(.dict, .keys, ...) {
if (missing(.keys)) {
return(.dict)
}
objs = lapply(.keys, dictionary_sugar_get_safe, .dict = .dict, ...)
if (!is.null(names(.keys))) {
nn = names2(.keys)
ii = which(!is.na(nn))
for (i in ii) {
objs[[i]]$id = nn[i]
}
}
names(objs) = map_chr(objs, "id")
objs
}
#' @title A Quick Way to Initialize Objects from Dictionaries with Incremented ID
#'
#' @description
#' Covenience wrapper around [dictionary_sugar_get_safe] and [dictionary_sugar_mget_safe] to allow easier avoidance of of ID
#' clashes which is useful when the same object is used multiple times and the ids have to be unique.
#' Let `<key>` be the key of the object to retrieve. When passing the `<key>_<n>` to this
#' function, where `<n>` is any natural numer, the object with key `<key>` is retrieved and the
#' suffix `_<n>` is appended to the id after the object is constructed.
#'
#' @param .dict ([Dictionary])\cr
#' Dictionary from which to retrieve an element.
#' @param .key (`character(1)`)\cr
#' Key of the object to construct - possibly with a suffix of the form `_<n>` which will be appended to the id.
#' @param .keys (`character()`)\cr
#' Keys of the objects to construct - possibly with suffixes of the form `_<n>` which will be appended to the ids.
#' @param ... (any)\cr
#' See description of [mlr3misc::dictionary_sugar_get_safe].
#'
#' @return An element from the dictionary.
#'
#' @examples
#' d = Dictionary$new()
#' d$add("a", R6::R6Class("A", public = list(id = "a")))
#' d$add("b", R6::R6Class("B", public = list(id = "c")))
#' obj1 = dictionary_sugar_inc_get_safe(d, "a_1")
#' obj1$id
#'
#' obj2 = dictionary_sugar_inc_get_safe(d, "b_1")
#' obj2$id
#'
#' objs = dictionary_sugar_inc_mget_safe(d, c("a_10", "b_2"))
#' map(objs, "id")
#'
#' @rdname dictionary_sugar_inc_get_safe
#' @export
dictionary_sugar_inc_get_safe = function(.dict, .key, ...) {
m = regexpr("_\\d+$", .key)
if (attr(m, "match.length") == -1L) {
return(dictionary_sugar_get_safe(.dict = .dict, .key = key, ...))
}
assert_true(!methods::hasArg("id"))
split = regmatches(.key, m, invert = NA)[[1L]]
newkey = split[[1L]]
suffix = split[[2L]]
obj = dictionary_sugar_get_safe(.dict = .dict, .key = newkey, ...)
obj$id = paste0(obj$id, suffix)
obj

}

#' @rdname dictionary_sugar_inc_get_safe
#' @export
dictionary_sugar_inc_mget_safe = function(.dict, .keys, ...) {
objs = lapply(.keys, dictionary_sugar_inc_get_safe, .dict = .dict, ...)
if (!is.null(names(.keys))) {
nn = names2(.keys)
ii = which(!is.na(nn))
for (i in ii) {
objs[[i]]$id = nn[i]
}
}
names(objs) = map_chr(objs, "id")
objs
}
55 changes: 55 additions & 0 deletions man/dictionary_sugar_get_safe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 48 additions & 0 deletions man/dictionary_sugar_inc_get_safe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 21 additions & 12 deletions tests/testthat/test_Dictionary.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,27 @@ test_that("Dictionary throws exception on unnamed args", {
expect_error(x$mget("a", "b"), "names")
})

test_that("dictionary_sugar_get", {
test_that("dictionary_sugar_get_safe", {
Foo = R6::R6Class("Foo", public = list(x = 0, y = 0, key = 0, initialize = function(y, key = -1) {
self$y = y
self$key = key
}), cloneable = TRUE)
d = Dictionary$new()
d$add("f1", Foo)
x = dictionary_sugar_get(d, "f1", y = 99, x = 1)
x = dictionary_sugar_get_safe(d, "f1", y = 99, x = 1)
expect_equal(x$x, 1)
expect_equal(x$y, 99)
expect_equal(x$key, -1)
x2 = dictionary_sugar_get(d, "f1", 99, x = 1)
x2 = dictionary_sugar_get_safe(d, "f1", 99, x = 1)
expect_equal(x, x2)
x2 = dictionary_sugar_get(d, "f1", x = 1, 99)
x2 = dictionary_sugar_get_safe(d, "f1", x = 1, 99)
expect_equal(x, x2)

x = dictionary_sugar_get(d, "f1", 1, 99)
x = dictionary_sugar_get_safe(d, "f1", 1, 99)
expect_equal(x$x, 0)
expect_equal(x$y, 1)
expect_equal(x$key, 99)
x2 = dictionary_sugar_get(d, "f1", key = 99, y = 1)
x2 = dictionary_sugar_get_safe(d, "f1", key = 99, y = 1)
expect_equal(x, x2)
})

Expand All @@ -84,13 +84,13 @@ test_that("mget", {
d = Dictionary$new()
d$add("f1", Foo)
d$add("f2", Foo)
x = dictionary_sugar_mget(d, "f1", y = 99, x = 1)
x = dictionary_sugar_mget_safe(d, "f1", y = 99, x = 1)
expect_list(x, len = 1, types = "Foo")
x = dictionary_sugar_mget(d, c("f1", "f2"), y = 99)
x = dictionary_sugar_mget_safe(d, c("f1", "f2"), y = 99)
expect_list(x, len = 2, types = "Foo")
expect_equal(ids(x), c("foo", "foo"))

x = dictionary_sugar_mget(d, c(a = "f1", b = "f2"), y = 99)
x = dictionary_sugar_mget_safe(d, c(a = "f1", b = "f2"), y = 99)
expect_list(x, len = 2, types = "Foo")
expect_equal(ids(x), c("a", "b"))
})
Expand All @@ -99,15 +99,15 @@ test_that("incrementing ids works", {
d = Dictionary$new()
d$add("a", R6Class("A", public = list(id = "a")))
d$add("b", R6Class("B", public = list(id = "c")))
obj1 = dictionary_sugar_inc_get(d, "a_1")
obj1 = dictionary_sugar_inc_get_safe(d, "a_1")
expect_r6(obj1, "A")
expect_true(obj1$id == "a_1")

obj2 = dictionary_sugar_inc_get(d, "b_1")
obj2 = dictionary_sugar_inc_get_safe(d, "b_1")
expect_r6(obj2, "B")
expect_true(obj2$id == "c_1")

objs = dictionary_sugar_inc_mget(d, c("a_10", "b_2"))
objs = dictionary_sugar_inc_mget_safe(d, c("a_10", "b_2"))
expect_r6(objs$a_10, "A")
expect_true(objs$a_10$id == "a_10")
expect_r6(objs$c_2, "B")
Expand All @@ -119,3 +119,12 @@ test_that("incrementing ids works", {
obj = dictionary_sugar_inc_get(d, "a")
expect_class(obj, "A")
})

test_that("avoid unintended partial argument matching", {
d = Dictionary$new()
A = R6Class("A", public = list(d = NULL))
d$add("a", function() A$new())
a = dictionary_sugar_get_safe(d, "a", d = 1)
expect_r6(a, "A")
expect_equal(a$d, 1)
})

0 comments on commit 61941d6

Please sign in to comment.