From 1b631a2d3d6dd3f6f9f6fd26fa1ec97d6e3e8a4b Mon Sep 17 00:00:00 2001 From: sumny Date: Tue, 20 Apr 2021 14:20:59 +0200 Subject: [PATCH 1/6] cs to ps wrapper --- R/cs_ps.R | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 R/cs_ps.R diff --git a/R/cs_ps.R b/R/cs_ps.R new file mode 100644 index 00000000..8e65ba0e --- /dev/null +++ b/R/cs_ps.R @@ -0,0 +1,89 @@ +get_type = function(x) { + clx = c("ConfigSpace.hyperparameters.UniformIntegerHyperparameter", + "ConfigSpace.hyperparameters.NormalIntegerHyperparameter", + "ConfigSpace.hyperparameters.UniformFloatHyperparameter", + "ConfigSpace.hyperparameters.NormalFloatHyperparameter", + "ConfigSpace.hyperparameters.CategoricalHyperparameter", + "ConfigSpace.hyperparameters.OrdinalHyperparameter", + "ConfigSpace.hyperparameters.Constant") + # FIXME: we cannot represent ordered so we use the same as for categorical + tlx = c("uniform_int", "normal_int", "uniform_float", "normal_float", "categorical", "categorical", "constant") + tlx[match(class(x)[1L], clx)] +} + + +### cs (ConfigurationSpace initialized via reticulate in python) +cs_to_ps = function(cs) { + # params + # FIXME: if q or weights is set for any parameter, we cannot respect this + ps = ParamSet$new(map(cs$get_hyperparameters(), .f = function(x) { + switch(get_type(x), + "uniform_int" = + if (x$log) { + ParamDbl$new(make.names(x$name), lower = log(x$lower), upper = log(x$upper), default = log(x$default_value), tags = c("log", "int")) + } else { + ParamInt$new(make.names(x$name), lower = x$lower, upper = x$upper, default = x$default_value) + }, + # FIXME: we do not differentiate between uniform and normal ints + "normal_int" = + if (x$log) { + ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = log(x$default_value), tags = c("log", "int")) + } else { + ParamInt$new(make.names(x$name), lower = -Inf, upper = Inf, default = x$default_value) + }, + "uniform_float" = + if (x$log) { + ParamDbl$new(make.names(x$name), lower = log(x$lower), upper = log(x$upper), default = log(x$default_value), tags = "log") + } else { + ParamDbl$new(make.names(x$name), lower = x$lower, upper = x$upper, default = x$default_value) + }, + # FIXME: we do not differentiate between uniform and normal floats + "normal_float" = + if (x$log) { + ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = log(x$default_value), tags = "log") + } else { + ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = x$default_value) + }, + "categorical" = + if (length(x$choices) == 1L) { + ParamUty$new(make.names(x$name), default = unlist(x$choices), tags = "constant") + } else { + if (every(x$choices, is.logical)) { + ParamLgl$new(make.names(x$name), default = x$default_value) + } else { + ParamFct$new(make.names(x$name), levels = unlist(x$choices), default = x$default_value) + } + }, + "constant" = + ParamUty$new(make.names(x$name), default = x$default_value, tags = "constant") + ) + })) + + # trafo + ps$trafo = function(x, param_set) { + for(i in names(which(map_lgl(param_set$tags, .f = function(tags) "log" %in% tags)))) { + x[[i]] = if ("int" %in% ps$params[[i]]$tags) as.integer(round(exp(x[[i]]))) else exp(x[[i]]) + } + x + } + + # deps + cnds = cs$get_conditions() + for(cnd in cnds) { + if ("ConfigSpace.conditions.InCondition" %in% class(cnd)) { + ps$add_dep(make.names(cnd$child$name), on = make.names(cnd$parent$name), CondAnyOf$new(cnd$values)) + } else if ("ConfigSpace.conditions.EqualsCondition" %in% class(cdn)) { + ps$add_dep(make.names(cnd$child$name), on = make.names(cnd$parent$name), CondEqual$new(cnd$value)) + } else { + stop("Not implemented.") + } + } + + if (length(cs$get_forbiddens())) { + stop("Not implemented.") + } + + ps + +} + From 6505aaea306fc2617af4089508f7ab1d989e05ef Mon Sep 17 00:00:00 2001 From: sumny Date: Wed, 21 Apr 2021 13:41:16 +0200 Subject: [PATCH 2/6] first version of ps <-> cs --- DESCRIPTION | 6 ++- NAMESPACE | 2 + R/cs_ps.R | 93 ++++++++++++++++++++-------------- R/ps_cs.R | 132 ++++++++++++++++++++++++++++++++++++++++++++++++ R/zzz.R | 3 ++ man/cs_to_ps.Rd | 25 +++++++++ man/ps_to_cs.Rd | 73 ++++++++++++++++++++++++++ 7 files changed, 295 insertions(+), 39 deletions(-) create mode 100644 R/ps_cs.R create mode 100644 man/cs_to_ps.Rd create mode 100644 man/ps_to_cs.Rd diff --git a/DESCRIPTION b/DESCRIPTION index bea05e9a..482dc862 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -51,7 +51,9 @@ Imports: Suggests: knitr, lhs, - testthat + testthat, + reticulate, + jsonlite Encoding: UTF-8 Config/testthat/edition: 3 Config/testthat/parallel: false @@ -76,12 +78,14 @@ Collate: 'SamplerJointIndep.R' 'SamplerUnif.R' 'asserts.R' + 'cs_ps.R' 'helper.R' 'domain.R' 'generate_design_grid.R' 'generate_design_lhs.R' 'generate_design_random.R' 'ps.R' + 'ps_cs.R' 'reexports.R' 'to_tune.R' 'zzz.R' diff --git a/NAMESPACE b/NAMESPACE index 6153cbcd..f87f8006 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -33,6 +33,7 @@ export(SamplerUnif) export(as.data.table) export(assert_param) export(assert_param_set) +export(cs_to_ps) export(generate_design_grid) export(generate_design_lhs) export(generate_design_random) @@ -42,6 +43,7 @@ export(p_int) export(p_lgl) export(p_uty) export(ps) +export(ps_to_cs) export(to_tune) import(checkmate) import(data.table) diff --git a/R/cs_ps.R b/R/cs_ps.R index 8e65ba0e..568f4b5d 100644 --- a/R/cs_ps.R +++ b/R/cs_ps.R @@ -1,58 +1,80 @@ -get_type = function(x) { - clx = c("ConfigSpace.hyperparameters.UniformIntegerHyperparameter", - "ConfigSpace.hyperparameters.NormalIntegerHyperparameter", - "ConfigSpace.hyperparameters.UniformFloatHyperparameter", - "ConfigSpace.hyperparameters.NormalFloatHyperparameter", - "ConfigSpace.hyperparameters.CategoricalHyperparameter", - "ConfigSpace.hyperparameters.OrdinalHyperparameter", - "ConfigSpace.hyperparameters.Constant") - # FIXME: we cannot represent ordered so we use the same as for categorical +get_type_ps = function(x) { + clx = c("UniformIntegerHyperparameter", + "NormalIntegerHyperparameter", + "UniformFloatHyperparameter", + "NormalFloatHyperparameter", + "CategoricalHyperparameter", + "OrdinalHyperparameter", + "Constant") + # we cannot represent ordered so we use the same as for categorical tlx = c("uniform_int", "normal_int", "uniform_float", "normal_float", "categorical", "categorical", "constant") - tlx[match(class(x)[1L], clx)] + tlx[match(x$`__class__`$`__name__`, clx)] } -### cs (ConfigurationSpace initialized via reticulate in python) + +#' @title Map a ConfigSpace to a ParamSet +#' +#' @description +#' Maps a ConfigSpace (loaded via \CRANpkg{reticulate}) to a [ParamSet]. +#' NormalIntegerHyperparameters are treated as UniformIntegerHyperparameter and the same holds for floats. +#' OrdinalHyperparameters are treated as CategoricalHyperparameters. +#' Constants are mapped to [ParamUty]s with tag `"constant"`. +#' Names are subject to changes via [base::make.names()]. +#' q, weights and meta fields are ignored. +#' +#' @param cs (ConfigSpace). +#' @return [ParamSet] +#' +#' @export +#' @examples +#' # see ps_to_cs cs_to_ps = function(cs) { + # FIXME: we could do some additional safety checks here + assert_true(cs$`__class__`$`__name__` == "ConfigurationSpace") + if (length(cs$get_forbiddens())) { + stop("Forbiddens are not implemented.") + } + # params - # FIXME: if q or weights is set for any parameter, we cannot respect this - ps = ParamSet$new(map(cs$get_hyperparameters(), .f = function(x) { - switch(get_type(x), + # if q or weights is set for any parameter, we cannot respect this + ps = ParamSet$new(mlr3misc::map(cs$get_hyperparameters(), .f = function(x) { + switch(get_type_ps(x), "uniform_int" = if (x$log) { - ParamDbl$new(make.names(x$name), lower = log(x$lower), upper = log(x$upper), default = log(x$default_value), tags = c("log", "int")) + ParamDbl$new(make.names(x$name), lower = log(x$lower), upper = log(x$upper), default = log(x$default_value), tags = c("int", "log")) } else { ParamInt$new(make.names(x$name), lower = x$lower, upper = x$upper, default = x$default_value) }, - # FIXME: we do not differentiate between uniform and normal ints - "normal_int" = + # we do not differentiate between uniform and normal ints + "normal_int" = { + warning("Normal ints are treated as uniform ints.") if (x$log) { - ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = log(x$default_value), tags = c("log", "int")) + ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = log(x$default_value), tags = c("int", "log")) } else { ParamInt$new(make.names(x$name), lower = -Inf, upper = Inf, default = x$default_value) - }, + } + }, "uniform_float" = if (x$log) { ParamDbl$new(make.names(x$name), lower = log(x$lower), upper = log(x$upper), default = log(x$default_value), tags = "log") } else { ParamDbl$new(make.names(x$name), lower = x$lower, upper = x$upper, default = x$default_value) }, - # FIXME: we do not differentiate between uniform and normal floats - "normal_float" = + # we do not differentiate between uniform and normal floats + "normal_float" = { + warning("Normal floats are treated as uniform floats.") if (x$log) { ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = log(x$default_value), tags = "log") } else { ParamDbl$new(make.names(x$name), lower = -Inf, upper = Inf, default = x$default_value) - }, + } + }, "categorical" = - if (length(x$choices) == 1L) { - ParamUty$new(make.names(x$name), default = unlist(x$choices), tags = "constant") + if (every(x$choices, is.logical)) { + ParamLgl$new(make.names(x$name), default = x$default_value) } else { - if (every(x$choices, is.logical)) { - ParamLgl$new(make.names(x$name), default = x$default_value) - } else { - ParamFct$new(make.names(x$name), levels = unlist(x$choices), default = x$default_value) - } + ParamFct$new(make.names(x$name), levels = unlist(x$choices), default = x$default_value) }, "constant" = ParamUty$new(make.names(x$name), default = x$default_value, tags = "constant") @@ -61,7 +83,7 @@ cs_to_ps = function(cs) { # trafo ps$trafo = function(x, param_set) { - for(i in names(which(map_lgl(param_set$tags, .f = function(tags) "log" %in% tags)))) { + for (i in names(which(mlr3misc::map_lgl(param_set$tags, .f = function(tags) "log" %in% tags)))) { x[[i]] = if ("int" %in% ps$params[[i]]$tags) as.integer(round(exp(x[[i]]))) else exp(x[[i]]) } x @@ -69,21 +91,16 @@ cs_to_ps = function(cs) { # deps cnds = cs$get_conditions() - for(cnd in cnds) { - if ("ConfigSpace.conditions.InCondition" %in% class(cnd)) { + for (cnd in cnds) { + if ("InCondition" %in% cnd$`__class__`$`__name__`) { ps$add_dep(make.names(cnd$child$name), on = make.names(cnd$parent$name), CondAnyOf$new(cnd$values)) - } else if ("ConfigSpace.conditions.EqualsCondition" %in% class(cdn)) { + } else if ("EqualsCondition" %in% cnd$`__class__`$`__name__`) { ps$add_dep(make.names(cnd$child$name), on = make.names(cnd$parent$name), CondEqual$new(cnd$value)) } else { stop("Not implemented.") } } - if (length(cs$get_forbiddens())) { - stop("Not implemented.") - } - ps - } diff --git a/R/ps_cs.R b/R/ps_cs.R new file mode 100644 index 00000000..e3844387 --- /dev/null +++ b/R/ps_cs.R @@ -0,0 +1,132 @@ +get_type_cs = function(x) { + # FIXME: we cannot represent ParamUty + # FIXME: we may want to map ParamUty or ParamFct with a single possible value to ConfigSpace.hyperparameters.Constant + clx = c("ParamInt", "ParamDbl", "ParamLgl", "ParamFct") + tlx = c("uniform_int", "uniform_float", "categorical", "categorical") + tlx[match(class(x)[1L], clx)] +} + +#' @title Map a ParamSet to a ConfigSpace +#' +#' @description +#' Maps a ParamSet to a ConfigSpace. +#' [ParamUty]s cannot be represented. +#' Transformation functions except for log transformations are NOT automatically handled (and ConfigSpace in general cannot do this). +#' To automatically handle a log transformation, set a `"log"` tag for the [Param], +#' if the [Param] would be a [ParamInt] after transformation, additionally set an `"int"` tag, see examples below. +#' Only [Condition]s of class [CondEqual] and [CondAnyOf] are supported. +#' +#' Requires \CRANpkg{reticulate} and \CRANpkg{jsonlite} (if saving in json format is desired) to be installed. +#' +#' @param ps [ParamSet]. +#' @param json_file (`character(1)`). \cr +#' Optional filename ending with `".json"`. +#' If specified, the returned ConfigSpace is additionally saved in the json format. +#' Useful for using the ConfigSpace in a pure python session. +#' +#' @return ConfigSpace +#' +#' @export +#' @examples +#'ps = ParamSet$new(list( +#' ParamDbl$new("x1", lower = log(10), upper = log(20), default = log(15), tags = c("int", "log")), +#' ParamInt$new("x2", lower = 10, upper = 20, default = 15), +#' ParamDbl$new("x3", lower = log(10), upper = log(20), default = log(15), tags = "log"), +#' ParamDbl$new("x4", lower = 10, upper = 20, default = 15), +#' ParamLgl$new("x5", default = TRUE), +#' ParamFct$new("x6", levels = c("a", "b", "c"), default = "c")) +#') +#' +#'ps$trafo = function(x, param_set) { +#' for (i in names(which(mlr3misc::map_lgl(param_set$tags, .f = function(tags) "log" %in% tags)))) { +#' x[[i]] = if ("int" %in% ps$params[[i]]$tags) as.integer(round(exp(x[[i]]))) else exp(x[[i]]) +#' } +#' x +#'} +#' +#'ps$add_dep("x6", on = "x5", cond = CondEqual$new(TRUE)) +#'ps$add_dep("x4", on = "x6", cond = CondAnyOf$new(c("a", "b"))) +#' +#'cs = ps_to_cs(ps) +#' +#'dt_ps = data.table::rbindlist(generate_design_random(ps, n = 1000L)$transpose(filter_na = FALSE)) +#'dt_cs = data.table::rbindlist(mlr3misc::map(cs$sample_configuration(1000L), function(x) { +#' x$get_dictionary() +#'}), fill = TRUE) +#'summary(dt_ps) +#'summary(dt_cs) +#'all(is.na(dt_ps[x5 == FALSE][["x6"]])) # first dependency +#'all(is.na(dt_cs[x5 == FALSE][["x6"]])) # first dependency +#'all(is.na(dt_ps[x6 == "c"][["x4"]])) # second dependency +#'all(is.na(dt_cs[x6 == "c"][["x4"]])) # second dependency +#' +#'ps_ = cs_to_ps(cs) +#'psparams = ps$params +#'ps_params = ps_$params +#'all.equal(psparams, ps_params[names(psparams)]) +#'all.equal(ps$deps, ps_$deps) +#'# ps$trafo, ps_$trafo +#'dt_ps_ = data.table::rbindlist(generate_design_random(ps, n = 1000L)$transpose(filter_na = FALSE)) +#'summary(dt_ps_) +#'all(is.na(dt_ps_[x5 == FALSE][["x6"]])) # first dependency +#'all(is.na(dt_ps_[x6 == "c"][["x4"]])) # second dependency +ps_to_cs = function(ps, json_file = NULL) { + # FIXME: we could do some additional safety checks here + assert_param_set(ps) + if (!is.null(json_file)) { + assert_path_for_output(json_file) + assert_true(endsWith(json_file, suffix = ".json")) + } + + requireNamespace("reticulate") + requireNamespace("jsonlite") + + CS = reticulate::import("ConfigSpace", as = "CS") + CSH = reticulate::import("ConfigSpace.hyperparameters", as = "CSH") + json = reticulate::import("ConfigSpace.read_and_write.json") + + cs = CS$ConfigurationSpace() + + # params + for (i in seq_along(ps$params)) { + param = ps$params[[i]] + tmp = switch(get_type_cs(param), + "uniform_int" = + CSH$UniformIntegerHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = param$default), + "uniform_float" = if (all(c("int", "log") %in% param$tags)) { + CSH$UniformIntegerHyperparameter(name = param$id, lower = as.integer(round(exp(param$lower))), upper = as.integer(round(exp(param$upper))), default_value = as.integer(round(exp(param$default))), log = TRUE) + } else if ("log" %in% param$tags) { + CSH$UniformFloatHyperparameter(name = param$id, lower = exp(param$lower), upper = exp(param$upper), default_value = exp(param$default), log = TRUE) + } else { + CSH$UniformFloatHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = param$default) + }, + "categorical" = + CSH$CategoricalHyperparameter(name = param$id, choices = param$levels, default_value = param$default), + ) + cs$add_hyperparameter(tmp) + } + + # trafo + if (ps$has_trafo) { + warning("Only log trafos can be respected automatically. Please check your trafos.") + } + + # deps + for (i in seq_len(NROW(ps$deps))) { + child = cs$get_hyperparameter(ps$deps[i, id]) + parent = cs$get_hyperparameter(ps$deps[i, on]) + cond = ps$deps[i, cond][[1L]] + cnd = if (checkmate::test_r6(cond, classes = "CondAnyOf")) { + CS$InCondition(child = child, parent = parent, values = cond$rhs) + } else if (checkmate::test_r6(cond, classes = "CondEqual")) { + CS$EqualsCondition(child = child, parent = parent, value = cond$rhs) + } else { + stop("Not implemented.") + } + cs$add_condition(cnd) + } + + if (!is.null(json_file)) write(json$write(cs), json_file) + cs +} + diff --git a/R/zzz.R b/R/zzz.R index 4771bc1c..aae97ccf 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -11,4 +11,7 @@ backports::import(pkgname) } # nocov end +# static code checks should not complain about commonly used data.table columns +utils::globalVariables(c("id", "on")) + leanify_package() diff --git a/man/cs_to_ps.Rd b/man/cs_to_ps.Rd new file mode 100644 index 00000000..22bb27ae --- /dev/null +++ b/man/cs_to_ps.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cs_ps.R +\name{cs_to_ps} +\alias{cs_to_ps} +\title{Map a ConfigSpace to a ParamSet} +\usage{ +cs_to_ps(cs) +} +\arguments{ +\item{cs}{(ConfigSpace).} +} +\value{ +\link{ParamSet} +} +\description{ +Maps a ConfigSpace (loaded via \CRANpkg{reticulate}) to a \link{ParamSet}. +NormalIntegerHyperparameters are treated as UniformIntegerHyperparameter and the same holds for floats. +OrdinalHyperparameters are treated as CategoricalHyperparameters. +Constants are mapped to \link{ParamUty}s with tag \code{"constant"}. +Names are subject to changes via \code{\link[base:make.names]{base::make.names()}}. +q, weights and meta fields are ignored. +} +\examples{ +# see ps_to_cs +} diff --git a/man/ps_to_cs.Rd b/man/ps_to_cs.Rd new file mode 100644 index 00000000..8e708755 --- /dev/null +++ b/man/ps_to_cs.Rd @@ -0,0 +1,73 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ps_cs.R +\name{ps_to_cs} +\alias{ps_to_cs} +\title{Map a ParamSet to a ConfigSpace} +\usage{ +ps_to_cs(ps, json_file = NULL) +} +\arguments{ +\item{ps}{\link{ParamSet}.} + +\item{json_file}{(\code{character(1)}). \cr +Optional filename ending with \code{".json"}. +If specified, the returned ConfigSpace is additionally saved in the json format. +Useful for using the ConfigSpace in a pure python session.} +} +\value{ +ConfigSpace +} +\description{ +Maps a ParamSet to a ConfigSpace. +\link{ParamUty}s cannot be represented. +Transformation functions except for log transformations are NOT automatically handled (and ConfigSpace in general cannot do this). +To automatically handle a log transformation, set a \code{"log"} tag for the \link{Param}, +if the \link{Param} would be a \link{ParamInt} after transformation, additionally set an \code{"int"} tag, see examples below. +Only \link{Condition}s of class \link{CondEqual} and \link{CondAnyOf} are supported. + +Requires \CRANpkg{reticulate} and \CRANpkg{jsonlite} (if saving in json format is desired) to be installed. +} +\examples{ +ps = ParamSet$new(list( + ParamDbl$new("x1", lower = log(10), upper = log(20), default = log(15), tags = c("int", "log")), + ParamInt$new("x2", lower = 10, upper = 20, default = 15), + ParamDbl$new("x3", lower = log(10), upper = log(20), default = log(15), tags = "log"), + ParamDbl$new("x4", lower = 10, upper = 20, default = 15), + ParamLgl$new("x5", default = TRUE), + ParamFct$new("x6", levels = c("a", "b", "c"), default = "c")) +) + +ps$trafo = function(x, param_set) { + for (i in names(which(mlr3misc::map_lgl(param_set$tags, .f = function(tags) "log" \%in\% tags)))) { + x[[i]] = if ("int" \%in\% ps$params[[i]]$tags) as.integer(round(exp(x[[i]]))) else exp(x[[i]]) + } + x +} + +ps$add_dep("x6", on = "x5", cond = CondEqual$new(TRUE)) +ps$add_dep("x4", on = "x6", cond = CondAnyOf$new(c("a", "b"))) + +cs = ps_to_cs(ps) + +dt_ps = data.table::rbindlist(generate_design_random(ps, n = 1000L)$transpose(filter_na = FALSE)) +dt_cs = data.table::rbindlist(mlr3misc::map(cs$sample_configuration(1000L), function(x) { + x$get_dictionary() +}), fill = TRUE) +summary(dt_ps) +summary(dt_cs) +all(is.na(dt_ps[x5 == FALSE][["x6"]])) # first dependency +all(is.na(dt_cs[x5 == FALSE][["x6"]])) # first dependency +all(is.na(dt_ps[x6 == "c"][["x4"]])) # second dependency +all(is.na(dt_cs[x6 == "c"][["x4"]])) # second dependency + +ps_ = cs_to_ps(cs) +psparams = ps$params +ps_params = ps_$params +all.equal(psparams, ps_params[names(psparams)]) +all.equal(ps$deps, ps_$deps) +# ps$trafo, ps_$trafo +dt_ps_ = data.table::rbindlist(generate_design_random(ps, n = 1000L)$transpose(filter_na = FALSE)) +summary(dt_ps_) +all(is.na(dt_ps_[x5 == FALSE][["x6"]])) # first dependency +all(is.na(dt_ps_[x6 == "c"][["x4"]])) # second dependency +} From 6494faf150c3df62ba17df0e2e412929ddaefec6 Mon Sep 17 00:00:00 2001 From: sumny Date: Wed, 21 Apr 2021 13:55:04 +0200 Subject: [PATCH 3/6] don't run example --- R/ps_cs.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/ps_cs.R b/R/ps_cs.R index e3844387..272e0311 100644 --- a/R/ps_cs.R +++ b/R/ps_cs.R @@ -28,6 +28,7 @@ get_type_cs = function(x) { #' #' @export #' @examples +#'\dontrun{ #'ps = ParamSet$new(list( #' ParamDbl$new("x1", lower = log(10), upper = log(20), default = log(15), tags = c("int", "log")), #' ParamInt$new("x2", lower = 10, upper = 20, default = 15), @@ -70,6 +71,7 @@ get_type_cs = function(x) { #'summary(dt_ps_) #'all(is.na(dt_ps_[x5 == FALSE][["x6"]])) # first dependency #'all(is.na(dt_ps_[x6 == "c"][["x4"]])) # second dependency +#'} ps_to_cs = function(ps, json_file = NULL) { # FIXME: we could do some additional safety checks here assert_param_set(ps) From aab88fcb1c302aefaf60307d7588223130e93194 Mon Sep 17 00:00:00 2001 From: sumny Date: Mon, 10 May 2021 09:44:52 +0200 Subject: [PATCH 4/6] default wrapper --- R/ps_cs.R | 19 ++++++++++++++----- man/ps_to_cs.Rd | 2 ++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/R/ps_cs.R b/R/ps_cs.R index 272e0311..9830ed2f 100644 --- a/R/ps_cs.R +++ b/R/ps_cs.R @@ -6,6 +6,15 @@ get_type_cs = function(x) { tlx[match(class(x)[1L], clx)] } +wrap_default = function(default, trafo = identity) { + checkmate::assert_function(trafo) + if (checkmate::test_r6(default, classes = "NoDefault")) { + NULL + } else { + trafo(default) + } +} + #' @title Map a ParamSet to a ConfigSpace #' #' @description @@ -94,16 +103,16 @@ ps_to_cs = function(ps, json_file = NULL) { param = ps$params[[i]] tmp = switch(get_type_cs(param), "uniform_int" = - CSH$UniformIntegerHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = param$default), + CSH$UniformIntegerHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = wrap_default(param$default)), "uniform_float" = if (all(c("int", "log") %in% param$tags)) { - CSH$UniformIntegerHyperparameter(name = param$id, lower = as.integer(round(exp(param$lower))), upper = as.integer(round(exp(param$upper))), default_value = as.integer(round(exp(param$default))), log = TRUE) + CSH$UniformIntegerHyperparameter(name = param$id, lower = as.integer(round(exp(param$lower))), upper = as.integer(round(exp(param$upper))), default_value = wrap(param$default, trafo = function(x) as.integer(round(exp(x)))), log = TRUE) } else if ("log" %in% param$tags) { - CSH$UniformFloatHyperparameter(name = param$id, lower = exp(param$lower), upper = exp(param$upper), default_value = exp(param$default), log = TRUE) + CSH$UniformFloatHyperparameter(name = param$id, lower = exp(param$lower), upper = exp(param$upper), default_value = wrap_default(param$default, trafo = exp), log = TRUE) } else { - CSH$UniformFloatHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = param$default) + CSH$UniformFloatHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = wrap_default(param$default)) }, "categorical" = - CSH$CategoricalHyperparameter(name = param$id, choices = param$levels, default_value = param$default), + CSH$CategoricalHyperparameter(name = param$id, choices = param$levels, default_value = wrap_default(param$default)), ) cs$add_hyperparameter(tmp) } diff --git a/man/ps_to_cs.Rd b/man/ps_to_cs.Rd index 8e708755..222d8470 100644 --- a/man/ps_to_cs.Rd +++ b/man/ps_to_cs.Rd @@ -28,6 +28,7 @@ Only \link{Condition}s of class \link{CondEqual} and \link{CondAnyOf} are suppor Requires \CRANpkg{reticulate} and \CRANpkg{jsonlite} (if saving in json format is desired) to be installed. } \examples{ +\dontrun{ ps = ParamSet$new(list( ParamDbl$new("x1", lower = log(10), upper = log(20), default = log(15), tags = c("int", "log")), ParamInt$new("x2", lower = 10, upper = 20, default = 15), @@ -71,3 +72,4 @@ summary(dt_ps_) all(is.na(dt_ps_[x5 == FALSE][["x6"]])) # first dependency all(is.na(dt_ps_[x6 == "c"][["x4"]])) # second dependency } +} From e92ee97d34563af0227c87a7017bd04e72610738 Mon Sep 17 00:00:00 2001 From: sumny Date: Mon, 10 May 2021 10:04:08 +0200 Subject: [PATCH 5/6] .. --- R/ps_cs.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/ps_cs.R b/R/ps_cs.R index 9830ed2f..5d890bfa 100644 --- a/R/ps_cs.R +++ b/R/ps_cs.R @@ -82,6 +82,7 @@ wrap_default = function(default, trafo = identity) { #'all(is.na(dt_ps_[x6 == "c"][["x4"]])) # second dependency #'} ps_to_cs = function(ps, json_file = NULL) { + # FIXME: could add an argument to ignore budget params (because most python optimizers do not use budget params in the cs # FIXME: we could do some additional safety checks here assert_param_set(ps) if (!is.null(json_file)) { @@ -105,7 +106,7 @@ ps_to_cs = function(ps, json_file = NULL) { "uniform_int" = CSH$UniformIntegerHyperparameter(name = param$id, lower = param$lower, upper = param$upper, default_value = wrap_default(param$default)), "uniform_float" = if (all(c("int", "log") %in% param$tags)) { - CSH$UniformIntegerHyperparameter(name = param$id, lower = as.integer(round(exp(param$lower))), upper = as.integer(round(exp(param$upper))), default_value = wrap(param$default, trafo = function(x) as.integer(round(exp(x)))), log = TRUE) + CSH$UniformIntegerHyperparameter(name = param$id, lower = as.integer(round(exp(param$lower))), upper = as.integer(round(exp(param$upper))), default_value = wrap_default(param$default, trafo = function(x) as.integer(round(exp(x)))), log = TRUE) } else if ("log" %in% param$tags) { CSH$UniformFloatHyperparameter(name = param$id, lower = exp(param$lower), upper = exp(param$upper), default_value = wrap_default(param$default, trafo = exp), log = TRUE) } else { From 231cd2d91a100d8f47db9df0314f6372f31a79cd Mon Sep 17 00:00:00 2001 From: sumny Date: Wed, 29 Sep 2021 10:15:23 +0200 Subject: [PATCH 6/6] fix for deps --- R/ps_cs.R | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/R/ps_cs.R b/R/ps_cs.R index 5d890bfa..13761181 100644 --- a/R/ps_cs.R +++ b/R/ps_cs.R @@ -123,11 +123,37 @@ ps_to_cs = function(ps, json_file = NULL) { warning("Only log trafos can be respected automatically. Please check your trafos.") } + # FIXME: check + # deps, and_conditions treated separately + conditions = ps$deps + and_conditions = which(table(conditions$id) > 1L) + and_conditions = conditions[id %in% names(and_conditions)] + rest_conditions = conditions[id %nin% and_conditions$id] + + for (id in unique(and_conditions$id)) { + ids = which(and_conditions$id == id) + conds = map(ids, function(x) { + child = cs$get_hyperparameter(and_conditions[x, id]) + parent = cs$get_hyperparameter(and_conditions[x, on]) + cond = and_conditions[x, cond][[1L]] + cnd = if (checkmate::test_r6(cond, classes = "CondAnyOf")) { + CS$InCondition(child = child, parent = parent, values = cond$rhs) + } else if (checkmate::test_r6(cond, classes = "CondEqual")) { + CS$EqualsCondition(child = child, parent = parent, value = cond$rhs) + } else { + stop("Not implemented.") + } + cnd + }) + cnd = mlr3misc::invoke(CS$AndConjunction, .args = conds) + cs$add_condition(cnd) + } + # deps - for (i in seq_len(NROW(ps$deps))) { - child = cs$get_hyperparameter(ps$deps[i, id]) - parent = cs$get_hyperparameter(ps$deps[i, on]) - cond = ps$deps[i, cond][[1L]] + for (i in seq_len(NROW(rest_conditions))) { + child = cs$get_hyperparameter(rest_conditions[i, id]) + parent = cs$get_hyperparameter(rest_conditions[i, on]) + cond = rest_conditions[i, cond][[1L]] cnd = if (checkmate::test_r6(cond, classes = "CondAnyOf")) { CS$InCondition(child = child, parent = parent, values = cond$rhs) } else if (checkmate::test_r6(cond, classes = "CondEqual")) {