diff --git a/NAMESPACE b/NAMESPACE index d0299bd4..4d87b050 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -15,6 +15,7 @@ S3method(autoplot,LearnerClustHierarchical) S3method(autoplot,LearnerRegrCVGlmnet) S3method(autoplot,LearnerRegrGlmnet) S3method(autoplot,LearnerRegrRpart) +S3method(autoplot,LearnerSurvRpart) S3method(autoplot,OptimInstanceSingleCrit) S3method(autoplot,PredictionClassif) S3method(autoplot,PredictionClust) @@ -41,6 +42,7 @@ S3method(plot,LearnerClassifRpart) S3method(plot,LearnerRegrCVGlmnet) S3method(plot,LearnerRegrGlmnet) S3method(plot,LearnerRegrRpart) +S3method(plot,LearnerSurvRpart) S3method(plot,PredictionClassif) S3method(plot,PredictionRegr) S3method(plot,ResampleResult) diff --git a/R/LearnerSurvRpart.R b/R/LearnerSurvRpart.R new file mode 100644 index 00000000..94d0ad88 --- /dev/null +++ b/R/LearnerSurvRpart.R @@ -0,0 +1,23 @@ +#' @export +#' @rdname autoplot.LearnerClassifRpart +autoplot.LearnerSurvRpart = function(object, ...) { # nolint + if (is.null(object$model)) { + stopf("Learner '%s' must be trained first", object$id) + } + require_namespaces(c("partykit", "ggparty")) + + target = all.vars(object$model$terms)[1L] + autoplot(partykit::as.party(object$model), ...) + + ggparty::geom_node_plot(gglist = list( + geom_boxplot(aes_string(target)), + coord_flip(), + theme(axis.ticks.x = element_blank(), axis.text.x = element_blank()) + )) + + ggparty::geom_node_label(aes(label = paste0("n=", .data[["nodesize"]])), + nudge_y = 0.03, ids = "terminal") +} + +#' @export +plot.LearnerSurvRpart = function(x, ...) { + print(autoplot(x, ...)) +} diff --git a/man/autoplot.LearnerClassifRpart.Rd b/man/autoplot.LearnerClassifRpart.Rd index 60d0f34f..221faf9f 100644 --- a/man/autoplot.LearnerClassifRpart.Rd +++ b/man/autoplot.LearnerClassifRpart.Rd @@ -1,13 +1,17 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/LearnerClassifRpart.R, R/LearnerRegrRpart.R +% Please edit documentation in R/LearnerClassifRpart.R, R/LearnerRegrRpart.R, +% R/LearnerSurvRpart.R \name{autoplot.LearnerClassifRpart} \alias{autoplot.LearnerClassifRpart} \alias{autoplot.LearnerRegrRpart} +\alias{autoplot.LearnerSurvRpart} \title{Plot for LearnerClassifRpart / LearnerRegrRpart} \usage{ \method{autoplot}{LearnerClassifRpart}(object, ...) \method{autoplot}{LearnerRegrRpart}(object, ...) + +\method{autoplot}{LearnerSurvRpart}(object, ...) } \arguments{ \item{object}{(\link[mlr3:LearnerClassifRpart]{mlr3::LearnerClassifRpart} | \link[mlr3:LearnerRegrRpart]{mlr3::LearnerRegrRpart}).} diff --git a/tests/testthat/test_LearnerSurvRpart.R b/tests/testthat/test_LearnerSurvRpart.R new file mode 100644 index 00000000..64ea64aa --- /dev/null +++ b/tests/testthat/test_LearnerSurvRpart.R @@ -0,0 +1,12 @@ +skip_if_not_installed("survival") +skip_if_not_installed("mlr3proba") +skip_if_not_installed("rpart") +skip_if_not_installed("partykit") +skip_if_not_installed("ggparty") + +test_that("autoplot.LearnerSurvRpart", { + learner = mlr3::lrn("surv.rpart")$train(mlr3::tsk("rats")) + p = autoplot(learner) + expect_true(is.ggplot(p)) + vdiffr::expect_doppelganger("learner_regr.rpart", p) +})