diff --git a/R/Graph.R b/R/Graph.R
index d05f5b83d..96362b250 100644
--- a/R/Graph.R
+++ b/R/Graph.R
@@ -26,7 +26,7 @@
 #' @section Fields:
 #' * `pipeops` :: named `list` of [`PipeOp`] \cr
 #'   Contains all [`PipeOp`]s in the [`Graph`], named by the [`PipeOp`]'s `$id`s.
-#' * `edges` :: [`data.table`] with columns `src_id` (`character`), `src_channel` (`character`), `dst_id` (`character`), `dst_channel` (`character`)\cr
+#' * `edges` :: [`data.table`] with columns `src_id` (`character`), `src_channel` (`character`), `dst_id` (`character`), `dst_channel` (`character`)\cr
 #'   Table of connections between the [`PipeOp`]s. A [`data.table`]. `src_id` and `dst_id` are `$id`s of [`PipeOp`]s that must be present in
 #'   the `$pipeops` list. `src_channel` and `dst_channel` must respectively be `$output` and `$input` channel names of the
 #'   respective [`PipeOp`]s.
@@ -72,6 +72,10 @@
 #'   be supplied, e.g. a [`Learner`][mlr3::Learner] or a [`Filter`][mlr3filters::Filter]; see [`as_pipeop()`].
 #'   The argument given as `op` is always cloned; to access a `Graph`'s [`PipeOp`]s by-reference, use `$pipeops`.\cr
 #'   Note that `$add_pipeop()` is a relatively low-level operation, it is recommended to build graphs using [`%>>%`].
+#' * `remove_pipeop(id)` \cr
+#'   (`character(1)`) -> `self` \cr
+#'   Mutates [`Graph`] by removing the [`PipeOp`] with the matching id from the [`Graph`].
+#'   Corresponding edges are also removed as well as the corresponding [`ParamSet`][paradox::ParamSet].
 #' * `add_edge(src_id, dst_id, src_channel = NULL, dst_channel = NULL)` \cr
 #'   (`character(1)`, `character(1)`,
 #'   `character(1)` | `numeric(1)` | `NULL`,
@@ -81,6 +85,10 @@
 #'   channel `dst_channel` (identified by its name or number as listed in the [`PipeOp`]'s `$input`).
 #'   If source or destination [`PipeOp`] have only one input / output channel and `src_channel` / `dst_channel`
 #'   are therefore unambiguous, they can be omitted (i.e. left as `NULL`).
+#' * `replace_subgraph(ids, substitute)` \cr
+#'   (`character()`, [`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `...`) -> `self` \cr
+#'   Mutates [`Graph`] by replacing a subgraph specified via ids with the supplied substitute subgraph.
+#'   Note that the supplied ids are always reordered in topological order with respect to the [`Graph`].
 #' * `plot(html)` \cr
 #'   (`logical(1)`) -> `NULL` \cr
 #'   Plot the [`Graph`], using either the \pkg{igraph} package (for `html = FALSE`, default) or
@@ -248,6 +256,124 @@ Graph = R6Class("Graph",
       invisible(self)
     },
 
+    remove_pipeop = function(id) {
+      # `id` must be the id of exactly one PipeOp in this Graph: the `[[<-` removal and the
+      # edge filter below are only well-defined for a single id.
+      assert_choice(id, self$ids(TRUE))
+      self$pipeops[[id]] = NULL
+      # drop every edge that starts or ends at the removed PipeOp
+      self$edges = self$edges[src_id != id & dst_id != id]
+
+      if (!is.null(private$.param_set)) {
+        # param_set is built on-demand; if it has not been requested before, its value may be NULL
+        # and we don't need to remove anything.
+        private$.param_set$remove_sets(id)
+      }
+      invisible(self)
+    },
+
+    replace_subgraph = function(ids, substitute) {
+      # If anything below fails, the Graph is rolled back: pipeops and edges are restored from
+      # the saved values. The ParamSetCollection is modified in place by remove_pipeop(), so a
+      # saved reference would already contain the partial changes; it is dropped instead and
+      # rebuilt on demand from the restored pipeops.
+      old_pipeops = self$pipeops
+      old_edges = self$edges
+      on.exit({
+        self$pipeops = old_pipeops
+        self$edges = old_edges
+        private$.param_set = NULL
+      })
+
+      sorted_ids = self$ids(TRUE)
+      assert_subset(ids, choices = sorted_ids, empty.ok = FALSE)
+      ids = intersect(sorted_ids, ids)  # always reorder ids topologically (also drops duplicates)
+      substitute = as_graph(substitute, clone = TRUE)
+
+      # FIXME: check that ids are actually a valid subgraph of graph
+
+      # FIXME:
+      # check whether the input of the substitute is a vararg channel
+      #if (any(strip_multiplicity_type(substitute$input$channel.name) == "...")) {
+      #  stopf("Using a substitute with a vararg input channel is not supported (yet).")
+      #}
+
+      # check whether the last id that is to be replaced connects to a vararg channel
+      #if (nrow(self$edges)) {  # this can be a data table with zero rows
+      #  type = self$edges[src_id == range(ids)[2L], dst_channel]
+      #  if (length(type)) {  # can be of length 0 if this is the end of the graph
+      #    if (strip_multiplicity_type(type) == "...") {
+      #      stopf("Replacing a Subgraph that is connected to a vararg channel is not supported (yet).")
+      #    }
+      #  }
+      #}
+
+      input_orig = self$input
+      output_orig = self$output
+
+      for (id in ids) {
+        self$remove_pipeop(id)  # also handles param_set
+      }
+
+      # Channels that became free by removing `ids`. `%in%` instead of `!=`: the tables before
+      # and after removal usually differ in length, so `!=` would compare recycled values.
+      input = self$input[!(name %in% input_orig$name)]
+      output = self$output[!(name %in% output_orig$name)]
+
+      for (pipeop in substitute$pipeops) {
+        self$add_pipeop(pipeop)  # also handles param_set
+      }
+      if (nrow(substitute$edges)) {
+        self$edges = rbind(self$edges, substitute$edges)
+      }
+
+      # FIXME: this reuses a lot of `%>>%`, we could write a general helper
+      # build edges from free output channels of substitute and free input channels of self
+      n_input = nrow(input)
+      if (n_input) {
+        # FIXME: check number of inputs / outputs
+        for (row in seq_len(n_input)) {
+          if (!are_types_compatible(strip_multiplicity_type(substitute$output$train[row]), strip_multiplicity_type(input$train[row]))) {
+            stopf("Output type of PipeOp %s during training (%s) incompatible with input type of PipeOp %s (%s)",
+              substitute$output$op.id[row], substitute$output$train[row], input$op.id[row], input$train[row])
+          }
+          if (!are_types_compatible(strip_multiplicity_type(substitute$output$predict[row]), strip_multiplicity_type(input$predict[row]))) {
+            stopf("Output type of PipeOp %s during prediction (%s) incompatible with input type of PipeOp %s (%s)",
+              substitute$output$op.id[row], substitute$output$predict[row], input$op.id[row], input$predict[row])
+          }
+        }
+        new_edges = cbind(substitute$output[, list(src_id = get("op.id"), src_channel = get("channel.name"))], input[, list(dst_id = get("op.id"), dst_channel = get("channel.name"))])
+        self$edges = rbind(self$edges, new_edges)
+      }
+
+      # build edges from free output channels of self and free input channels of substitute
+      n_output = nrow(output)
+      if (n_output) {
+        # FIXME: check number of inputs / outputs
+        for (row in seq_len(n_output)) {
+          if (!are_types_compatible(strip_multiplicity_type(output$train[row]), strip_multiplicity_type(substitute$input$train[row]))) {
+            stopf("Output type of PipeOp %s during training (%s) incompatible with input type of PipeOp %s (%s)",
+              output$op.id[row], output$train[row], substitute$input$op.id[row], substitute$input$train[row])
+          }
+          if (!are_types_compatible(strip_multiplicity_type(output$predict[row]), strip_multiplicity_type(substitute$input$predict[row]))) {
+            stopf("Output type of PipeOp %s during prediction (%s) incompatible with input type of PipeOp %s (%s)",
+              output$op.id[row], output$predict[row], substitute$input$op.id[row], substitute$input$predict[row])
+          }
+        }
+        new_edges = cbind(output[, list(src_id = get("op.id"), src_channel = get("channel.name"))], substitute$input[, list(dst_id = get("op.id"), dst_channel = get("channel.name"))])
+        self$edges = rbind(self$edges, new_edges)
+      }
+
+      # check if valid DAG
+      invisible(tryCatch(self$ids(TRUE), error = function(error_condition) {
+        stopf("Failed to infer new Graph structure. Resetting.")
+      }))
+
+      # success: disarm the rollback handler
+      on.exit({})
+      invisible(self)
+    },
+
     plot = function(html = FALSE) {
       assert_flag(html)
       if (!length(self$pipeops)) {
diff --git a/R/zzz.R b/R/zzz.R
index 0573770ba..01fc0ebf4 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -37,6 +37,6 @@ register_mlr3 = function() {
 } # nocov end
 
 # static code checks should not complain about commonly used data.table columns
-utils::globalVariables(c("src_id", "dst_id", "name", "op.id", "response", "truth"))
+utils::globalVariables(c("src_id", "dst_id", "src_channel", "dst_channel", "name", "op.id", "response", "truth"))
 
 leanify_package()
diff --git a/man/Graph.Rd b/man/Graph.Rd index 97aab6cd4..ccac33593 100644 --- a/man/Graph.Rd +++ b/man/Graph.Rd @@ -34,7 +34,7 @@ the \code{\link{PipeOp}} results along the edges as input to other \code{\link{P \itemize{ \item \code{pipeops} :: named \code{list} of \code{\link{PipeOp}} \cr Contains all \code{\link{PipeOp}}s in the \code{\link{Graph}}, named by the \code{\link{PipeOp}}'s \verb{$id}s. -\item \code{edges} :: \code{\link{data.table}} with columns \code{src_id} (\code{character}), \code{src_channel} (\code{character}), \code{dst_id} (\code{character}), \code{dst_channel} (\code{character})\cr +\item \code{edges} :: \code{\link{data.table}} with columns \code{src_id} (\code{character}), \code{src_channel} (\code{character}), \code{dst_id} (\code{character}), \code{dst_channel} (\code{character})\cr Table of connections between the \code{\link{PipeOp}}s. A \code{\link{data.table}}. \code{src_id} and \code{dst_id} are \verb{$id}s of \code{\link{PipeOp}}s that must be present in the \verb{$pipeops} list. \code{src_channel} and \code{dst_channel} must respectively be \verb{$output} and \verb{$input} channel names of the respective \code{\link{PipeOp}}s. @@ -84,6 +84,10 @@ Instead of supplying a \code{\link{PipeOp}} directly, an object that can natural be supplied, e.g.
a \code{\link[mlr3:Learner]{Learner}} or a \code{\link[mlr3filters:Filter]{Filter}}; see \code{\link[=as_pipeop]{as_pipeop()}}. The argument given as \code{op} is always cloned; to access a \code{Graph}'s \code{\link{PipeOp}}s by-reference, use \verb{$pipeops}.\cr Note that \verb{$add_pipeop()} is a relatively low-level operation, it is recommended to build graphs using \code{\link{\%>>\%}}. +\item \code{remove_pipeop(id)} \cr +(\code{character(1)}) -> \code{self} \cr +Mutates \code{\link{Graph}} by removing the \code{\link{PipeOp}} with the matching id from the \code{\link{Graph}}. +Corresponding edges are also removed as well as the corresponding \code{\link[paradox:ParamSet]{ParamSet}}. \item \code{add_edge(src_id, dst_id, src_channel = NULL, dst_channel = NULL)} \cr (\code{character(1)}, \code{character(1)}, \code{character(1)} | \code{numeric(1)} | \code{NULL}, @@ -93,6 +97,10 @@ Add an edge from \code{\link{PipeOp}} \code{src_id}, and its channel \code{src_c channel \code{dst_channel} (identified by its name or number as listed in the \code{\link{PipeOp}}'s \verb{$input}). If source or destination \code{\link{PipeOp}} have only one input / output channel and \code{src_channel} / \code{dst_channel} are therefore unambiguous, they can be omitted (i.e. left as \code{NULL}). +\item \code{replace_subgraph(ids, substitute)} \cr +(\code{character()}, \code{\link{Graph}} | \code{\link{PipeOp}} | \code{\link[mlr3:Learner]{Learner}} | \code{\link[mlr3filters:Filter]{Filter}} | \code{...}) -> \code{self} \cr +Mutates \code{\link{Graph}} by replacing a subgraph specified via ids with the supplied substitute subgraph. +Note that the supplied ids are always reordered in topological order with respect to the \code{\link{Graph}}. 
\item \code{plot(html)} \cr (\code{logical(1)}) -> \code{NULL} \cr Plot the \code{\link{Graph}}, using either the \pkg{igraph} package (for \code{html = FALSE}, default) or
diff --git a/tests/testthat/test_Graph.R b/tests/testthat/test_Graph.R
index 0dadd4d07..8d4ebf6b5 100644
--- a/tests/testthat/test_Graph.R
+++ b/tests/testthat/test_Graph.R
@@ -4,9 +4,9 @@ test_that("linear graph", {
   g = Graph$new()
   expect_equal(g$ids(sorted = TRUE), character(0))
 
-  # FIXME: we should "dummy" ops, so we can change properties of the ops at will
+  # FIXME: we should use "dummy" ops, so we can change properties of the ops at will
   # we should NOT use PipeOpNOP, because we want to check that $train/$predict actually does something.
-  # FIXME: we should packages of the graph
+  # FIXME: we should check packages of the graph
   op_ds = PipeOpSubsample$new()
   op_pca = PipeOpPCA$new()
   op_lrn = PipeOpLearner$new(mlr_learners$get("classif.rpart"))
@@ -434,3 +434,90 @@ test_that("dot output", {
     "6 [label=\"OUTPUT",
     "nop_output\",fontsize=24]"), out[-c(1L, 15L)])
 })
+
+test_that("replace_subgraph", {
+  task = tsk("iris")
+
+  # Basics
+  gr = Graph$new()$add_pipeop(PipeOpDebugMulti$new(2, 2))
+  address_old = address(gr)
+  gr_old = gr$clone(deep = TRUE)
+  expect_error(gr$replace_subgraph("id_not_present", PipeOpDebugMulti$new(2, 2)),
+    regexp = "Assertion on 'ids' failed")
+  expect_error(gr$replace_subgraph("debug.multi", NULL),
+    regexp = "op can not be converted to PipeOp")
+  expect_equal(gr, gr_old)  # error results in a clean reset
+  expect_identical(address(gr), address_old)
+  expect_deep_clone(gr_old, gr)
+
+  gr$replace_subgraph("debug.multi", substitute = PipeOpDebugMulti$new(2, 2))
+  expect_equal(gr_old, gr)
+  expect_identical(address(gr), address_old)  # in place modification
+  expect_deep_clone(gr_old, gr)  # replacing with exactly the same pipeop is the same as a deep clone
+
+  # Linear Graph
+  gr = po("scale") %>>% po("pca") %>>% lrn("classif.rpart")
+  gr_old = gr$clone(deep = TRUE)
+  gr$replace_subgraph("scale", substitute = po("scalemaxabs"))  # replace beginning
+  expect_set_equal(gr$ids(), c("scalemaxabs", "pca", "classif.rpart"))
+  expect_equal(gr$input$op.id, "scalemaxabs")
+  expect_equal(gr$output$op.id, "classif.rpart")
+  expect_null(gr$train(task)[[1L]])
+  expect_prediction_classif(gr$predict(task)[[1L]])
+
+  gr = gr_old$clone(deep = TRUE)
+  gr$replace_subgraph("classif.rpart", substitute = lrn("classif.featureless"))  # replace end
+  expect_set_equal(gr$ids(), c("scale", "pca", "classif.featureless"))
+  expect_equal(gr$input$op.id, "scale")
+  expect_equal(gr$output$op.id, "classif.featureless")
+  expect_null(gr$train(task)[[1L]])
+  expect_prediction_classif(gr$predict(task)[[1L]])
+
+  gr = gr_old$clone(deep = TRUE)
+  gr$replace_subgraph(c("scale", "pca", "classif.rpart"), substitute = po("scalemaxabs") %>>% po("ica") %>>% lrn("classif.featureless"))  # replace whole graph
+  expect_set_equal(gr$ids(), c("scalemaxabs", "ica", "classif.featureless"))
+  expect_equal(gr$input$op.id, "scalemaxabs")
+  expect_equal(gr$output$op.id, "classif.featureless")
+  expect_null(gr$train(task)[[1L]])
+  expect_prediction_classif(gr$predict(task)[[1L]])
+
+  gr = gr_old$clone(deep = TRUE)
+  gr$replace_subgraph(c("pca", "scale"), substitute = po("scalemaxabs") %>>% po("ica"))  # replace linear subgraph
+  expect_set_equal(gr$ids(), c("scalemaxabs", "ica", "classif.rpart"))
+  expect_equal(gr$input$op.id, "scalemaxabs")
+  expect_equal(gr$output$op.id, "classif.rpart")
+  expect_null(gr$train(task)[[1L]])
+  expect_prediction_classif(gr$predict(task)[[1L]])
+
+  # Non linear Graph
+  gr = po("scale") %>>% po("branch", c("pca", "nop")) %>>% gunion(list(po("pca"), po("nop"))) %>>% po("unbranch") %>>% lrn("classif.rpart")
+  gr_old = gr$clone(deep = TRUE)
+  #expect_error(gr$replace_subgraph(c("nop"), substitute = po("ica")), regexp = "connected to a vararg channel is not supported")  # FIXME:
+  expect_error(gr$replace_subgraph(c("branch", "pca", "nop", "unbranch"), substitute = lrn("classif.featureless")),
+    regexp = "Output type of PipeOp classif.featureless during training")
+  gr$replace_subgraph(c("branch", "pca", "nop", "unbranch"), substitute = po("branch", c("pca", "ica")) %>>% gunion(list(po("pca"), po("ica"))) %>>% po("unbranch"))
+  expect_set_equal(gr$ids(TRUE), c("scale", "branch", "pca", "ica", "unbranch", "classif.rpart"))
+  expect_equal(gr$input$op.id, "scale")
+  expect_equal(gr$output$op.id, "classif.rpart")
+  expect_null(gr$train(task)[[1L]])
+  state1 = gr$state
+  gr$param_set$values$branch.selection = "ica"
+  expect_null(gr$train(task)[[1L]])
+  state2 = gr$state
+  expect_true(test_r6(state1$ica, classes = "NO_OP"))
+  expect_true(test_r6(state2$pca, classes = "NO_OP"))
+  expect_prediction_classif(gr$predict(task)[[1L]])
+})
+
+test_that("remove_pipeop", {
+  gr = po("scale") %>>% po("pca") %>>% lrn("classif.rpart")
+  address_old = address(gr)
+  expect_error(gr$remove_pipeop("id_not_present"), regexp = "Assertion on 'id' failed")
+
+  gr$param_set  # make sure the on-demand ParamSet exists, so its update is exercised as well
+  gr$remove_pipeop("pca")
+  expect_identical(address(gr), address_old)  # in place modification
+  expect_set_equal(gr$ids(), c("scale", "classif.rpart"))
+  expect_equal(nrow(gr$edges), 0L)  # both edges touched the removed PipeOp
+  expect_false(any(startsWith(gr$param_set$ids(), "pca.")))
+})