Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ConfigSpace <-> ParamSet conversion #348

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ Imports:
Suggests:
knitr,
lhs,
testthat
testthat,
reticulate,
jsonlite
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
Expand All @@ -76,12 +78,14 @@ Collate:
'SamplerJointIndep.R'
'SamplerUnif.R'
'asserts.R'
'cs_ps.R'
'helper.R'
'domain.R'
'generate_design_grid.R'
'generate_design_lhs.R'
'generate_design_random.R'
'ps.R'
'ps_cs.R'
'reexports.R'
'to_tune.R'
'zzz.R'
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export(SamplerUnif)
export(as.data.table)
export(assert_param)
export(assert_param_set)
export(cs_to_ps)
export(generate_design_grid)
export(generate_design_lhs)
export(generate_design_random)
Expand All @@ -42,6 +43,7 @@ export(p_int)
export(p_lgl)
export(p_uty)
export(ps)
export(ps_to_cs)
export(to_tune)
import(checkmate)
import(data.table)
Expand Down
106 changes: 106 additions & 0 deletions R/cs_ps.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
get_type_ps = function(x) {
clx = c("UniformIntegerHyperparameter",
"NormalIntegerHyperparameter",
"UniformFloatHyperparameter",
"NormalFloatHyperparameter",
"CategoricalHyperparameter",
"OrdinalHyperparameter",
"Constant")
# we cannot represent ordered so we use the same as for categorical
tlx = c("uniform_int", "normal_int", "uniform_float", "normal_float", "categorical", "categorical", "constant")
tlx[match(x$`__class__`$`__name__`, clx)]
}



#' @title Map a ConfigSpace to a ParamSet
#'
#' @description
#' Maps a ConfigSpace (loaded via \CRANpkg{reticulate}) to a [ParamSet].
#' NormalIntegerHyperparameters are treated as UniformIntegerHyperparameter and the same holds for floats.
#' OrdinalHyperparameters are treated as CategoricalHyperparameters.
#' Constants are mapped to [ParamUty]s with tag `"constant"`.
#' Names are subject to changes via [base::make.names()].
#' q, weights and meta fields are ignored.
#'
#' @param cs (ConfigSpace).
#' @return [ParamSet]
#'
#' @export
#' @examples
#' # see ps_to_cs
cs_to_ps = function(cs) {
# FIXME: we could do some additional safety checks here
assert_true(cs$`__class__`$`__name__` == "ConfigurationSpace")
if (length(cs$get_forbiddens())) {
stop("Forbiddens are not implemented.")
}

# params
# if q or weights is set for any parameter, we cannot respect this
ps = ParamSet$new(mlr3misc::map(cs$get_hyperparameters(), .f = function(x) {
switch(get_type_ps(x),
"uniform_int" =
if (x$log) {
ParamDbl$new(make.names(x$name), lower = log(x$lower), upper = log(x$upper), default = log(x$default_value), tags = c("int", "log"))
} else {
ParamInt$new(make.names(x$name), lower = x$lower, upper = x$upper, default = x$default_value)
},
# we do not differentiate between uniform and normal ints
"normal_int" = {
warning("Normal ints are treated as uniform ints.")
if (x$log) {
ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = log(x$default_value), tags = c("int", "log"))
} else {
ParamInt$new(make.names(x$name), lower = -Inf, upper = Inf, default = x$default_value)
}
},
"uniform_float" =
if (x$log) {
ParamDbl$new(make.names(x$name), lower = log(x$lower), upper = log(x$upper), default = log(x$default_value), tags = "log")
} else {
ParamDbl$new(make.names(x$name), lower = x$lower, upper = x$upper, default = x$default_value)
},
# we do not differentiate between uniform and normal floats
"normal_float" = {
warning("Normal floats are treated as uniform floats.")
if (x$log) {
ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = log(x$default_value), tags = "log")
} else {
ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = x$default_value)
}
},
"categorical" =
if (every(x$choices, is.logical)) {
ParamLgl$new(make.names(x$name), default = x$default_value)
} else {
ParamFct$new(make.names(x$name), levels = unlist(x$choices), default = x$default_value)
},
"constant" =
ParamUty$new(make.names(x$name), default = x$default_value, tags = "constant")
)
}))

# trafo
ps$trafo = function(x, param_set) {
for (i in names(which(mlr3misc::map_lgl(param_set$tags, .f = function(tags) "log" %in% tags)))) {
x[[i]] = if ("int" %in% ps$params[[i]]$tags) as.integer(round(exp(x[[i]]))) else exp(x[[i]])
}
x
}

# deps
cnds = cs$get_conditions()
for (cnd in cnds) {
if ("InCondition" %in% cnd$`__class__`$`__name__`) {
ps$add_dep(make.names(cnd$child$name), on = make.names(cnd$parent$name), CondAnyOf$new(cnd$values))
} else if ("EqualsCondition" %in% cnd$`__class__`$`__name__`) {
ps$add_dep(make.names(cnd$child$name), on = make.names(cnd$parent$name), CondEqual$new(cnd$value))
} else {
stop("Not implemented.")
}
}

ps
}

170 changes: 170 additions & 0 deletions R/ps_cs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
get_type_cs = function(x) {
# FIXME: we cannot represent ParamUty
# FIXME: we may want to map ParamUty or ParamFct with a single possible value to ConfigSpace.hyperparameters.Constant
clx = c("ParamInt", "ParamDbl", "ParamLgl", "ParamFct")
tlx = c("uniform_int", "uniform_float", "categorical", "categorical")
tlx[match(class(x)[1L], clx)]
}

wrap_default = function(default, trafo = identity) {
checkmate::assert_function(trafo)
if (checkmate::test_r6(default, classes = "NoDefault")) {
NULL
} else {
trafo(default)
}
}

#' @title Map a ParamSet to a ConfigSpace
#'
#' @description
#' Maps a ParamSet to a ConfigSpace.
#' [ParamUty]s cannot be represented.
#' Transformation functions except for log transformations are NOT automatically handled (and ConfigSpace in general cannot do this).
#' To automatically handle a log transformation, set a `"log"` tag for the [Param],
#' if the [Param] would be a [ParamInt] after transformation, additionally set an `"int"` tag, see examples below.
#' Only [Condition]s of class [CondEqual] and [CondAnyOf] are supported.
#'
#' Requires \CRANpkg{reticulate} and \CRANpkg{jsonlite} (if saving in json format is desired) to be installed.
#'
#' @param ps [ParamSet].
#' @param json_file (`character(1)`). \cr
#' Optional filename ending with `".json"`.
#' If specified, the returned ConfigSpace is additionally saved in the json format.
#' Useful for using the ConfigSpace in a pure python session.
#'
#' @return ConfigSpace
#'
#' @export
#' @examples
#'\dontrun{
#'ps = ParamSet$new(list(
#' ParamDbl$new("x1", lower = log(10), upper = log(20), default = log(15), tags = c("int", "log")),
#' ParamInt$new("x2", lower = 10, upper = 20, default = 15),
#' ParamDbl$new("x3", lower = log(10), upper = log(20), default = log(15), tags = "log"),
#' ParamDbl$new("x4", lower = 10, upper = 20, default = 15),
#' ParamLgl$new("x5", default = TRUE),
#' ParamFct$new("x6", levels = c("a", "b", "c"), default = "c"))
#')
#'
#'ps$trafo = function(x, param_set) {
#' for (i in names(which(mlr3misc::map_lgl(param_set$tags, .f = function(tags) "log" %in% tags)))) {
#' x[[i]] = if ("int" %in% ps$params[[i]]$tags) as.integer(round(exp(x[[i]]))) else exp(x[[i]])
#' }
#' x
#'}
#'
#'ps$add_dep("x6", on = "x5", cond = CondEqual$new(TRUE))
#'ps$add_dep("x4", on = "x6", cond = CondAnyOf$new(c("a", "b")))
#'
#'cs = ps_to_cs(ps)
#'
#'dt_ps = data.table::rbindlist(generate_design_random(ps, n = 1000L)$transpose(filter_na = FALSE))
#'dt_cs = data.table::rbindlist(mlr3misc::map(cs$sample_configuration(1000L), function(x) {
#' x$get_dictionary()
#'}), fill = TRUE)
#'summary(dt_ps)
#'summary(dt_cs)
#'all(is.na(dt_ps[x5 == FALSE][["x6"]])) # first dependency
#'all(is.na(dt_cs[x5 == FALSE][["x6"]])) # first dependency
#'all(is.na(dt_ps[x6 == "c"][["x4"]])) # second dependency
#'all(is.na(dt_cs[x6 == "c"][["x4"]])) # second dependency
#'
#'ps_ = cs_to_ps(cs)
#'psparams = ps$params
#'ps_params = ps_$params
#'all.equal(psparams, ps_params[names(psparams)])
#'all.equal(ps$deps, ps_$deps)
#'# ps$trafo, ps_$trafo
#'dt_ps_ = data.table::rbindlist(generate_design_random(ps, n = 1000L)$transpose(filter_na = FALSE))
#'summary(dt_ps_)
#'all(is.na(dt_ps_[x5 == FALSE][["x6"]])) # first dependency
#'all(is.na(dt_ps_[x6 == "c"][["x4"]])) # second dependency
#'}
ps_to_cs = function(ps, json_file = NULL) {
# FIXME: could add an argument to ignore budget params (because most python optimizers do not use budget params in the cs
# FIXME: we could do some additional safety checks here
assert_param_set(ps)
if (!is.null(json_file)) {
assert_path_for_output(json_file)
assert_true(endsWith(json_file, suffix = ".json"))
}

requireNamespace("reticulate")
requireNamespace("jsonlite")

CS = reticulate::import("ConfigSpace", as = "CS")
CSH = reticulate::import("ConfigSpace.hyperparameters", as = "CSH")
json = reticulate::import("ConfigSpace.read_and_write.json")

cs = CS$ConfigurationSpace()

# params
for (i in seq_along(ps$params)) {
param = ps$params[[i]]
tmp = switch(get_type_cs(param),
"uniform_int" =
CSH$UniformIntegerHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = wrap_default(param$default)),
"uniform_float" = if (all(c("int", "log") %in% param$tags)) {
CSH$UniformIntegerHyperparameter(name = param$id, lower = as.integer(round(exp(param$lower))), upper = as.integer(round(exp(param$upper))), default_value = wrap_default(param$default, trafo = function(x) as.integer(round(exp(x)))), log = TRUE)
} else if ("log" %in% param$tags) {
CSH$UniformFloatHyperparameter(name = param$id, lower = exp(param$lower), upper = exp(param$upper), default_value = wrap_default(param$default, trafo = exp), log = TRUE)
} else {
CSH$UniformFloatHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = wrap_default(param$default))
},
"categorical" =
CSH$CategoricalHyperparameter(name = param$id, choices = param$levels, default_value = wrap_default(param$default)),
)
cs$add_hyperparameter(tmp)
}

# trafo
if (ps$has_trafo) {
warning("Only log trafos can be respected automatically. Please check your trafos.")
}

# FIXME: check
# deps, and_conditions treated separately
conditions = ps$deps
and_conditions = which(table(conditions$id) > 1L)
and_conditions = conditions[id %in% names(and_conditions)]
rest_conditions = conditions[id %nin% and_conditions$id]

for (id in unique(and_conditions$id)) {
ids = which(and_conditions$id == id)
conds = map(ids, function(x) {
child = cs$get_hyperparameter(and_conditions[x, id])
parent = cs$get_hyperparameter(and_conditions[x, on])
cond = and_conditions[x, cond][[1L]]
cnd = if (checkmate::test_r6(cond, classes = "CondAnyOf")) {
CS$InCondition(child = child, parent = parent, values = cond$rhs)
} else if (checkmate::test_r6(cond, classes = "CondEqual")) {
CS$EqualsCondition(child = child, parent = parent, value = cond$rhs)
} else {
stop("Not implemented.")
}
cnd
})
cnd = mlr3misc::invoke(CS$AndConjunction, .args = conds)
cs$add_condition(cnd)
}

# deps
for (i in seq_len(NROW(rest_conditions))) {
child = cs$get_hyperparameter(rest_conditions[i, id])
parent = cs$get_hyperparameter(rest_conditions[i, on])
cond = rest_conditions[i, cond][[1L]]
cnd = if (checkmate::test_r6(cond, classes = "CondAnyOf")) {
CS$InCondition(child = child, parent = parent, values = cond$rhs)
} else if (checkmate::test_r6(cond, classes = "CondEqual")) {
CS$EqualsCondition(child = child, parent = parent, value = cond$rhs)
} else {
stop("Not implemented.")
}
cs$add_condition(cnd)
}

if (!is.null(json_file)) write(json$write(cs), json_file)
cs
}

3 changes: 3 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@
backports::import(pkgname)
} # nocov end

# static code checks should not complain about commonly used data.table columns
utils::globalVariables(c("id", "on"))

leanify_package()
25 changes: 25 additions & 0 deletions man/cs_to_ps.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading