diff --git a/DESCRIPTION b/DESCRIPTION index d3f912d8..332a98ec 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -39,6 +39,8 @@ Suggests: RWeka, stream, testthat (>= 3.0.0) +Remotes: + mlr-org/mlr3 Config/testthat/edition: 3 Encoding: UTF-8 Roxygen: list(markdown = TRUE, r6 = TRUE) diff --git a/NAMESPACE b/NAMESPACE index a0e4519c..8e8e87e1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -10,6 +10,7 @@ S3method(as_task_clust,data.frame) S3method(as_task_clust,formula) S3method(c,PredictionDataClust) S3method(check_prediction_data,PredictionDataClust) +S3method(create_empty_prediction_data,TaskClust) S3method(filter_prediction_data,PredictionDataClust) S3method(is_missing_prediction_data,PredictionDataClust) export(LearnerClust) diff --git a/R/PredictionDataClust.R b/R/PredictionDataClust.R index 5047eaba..0b339163 100644 --- a/R/PredictionDataClust.R +++ b/R/PredictionDataClust.R @@ -89,3 +89,19 @@ filter_prediction_data.PredictionDataClust = function(pdata, row_ids, ...) { pdata } + +#' @export +create_empty_prediction_data.TaskClust = function(task, learner) { + predict_types = mlr_reflections$learner_predict_types[["clust"]][[learner$predict_type]] + + pdata = list( + row_ids = integer(), + partition = integer() + ) + + if ("prob" %in% predict_types) { + pdata$prob = matrix(integer()) + } + + set_class(pdata, "PredictionDataClust") +} diff --git a/tests/testthat/test_PredictionClust.R b/tests/testthat/test_PredictionClust.R index b694ab6e..c005987b 100644 --- a/tests/testthat/test_PredictionClust.R +++ b/tests/testthat/test_PredictionClust.R @@ -25,3 +25,27 @@ test_that("filter works", { expect_set_equal(pdata$row_ids, 1:3) expect_integer(pdata$partition, len = 3) }) + +test_that("construction of empty PredictionDataClust", { + task = tsk("usarrests") + + learner = lrn("clust.featureless", predict_type = "partition") + learner$train(task) + pred = learner$predict(task, row_ids = integer()) + expect_prediction(pred) + expect_set_equal(pred$predict_types, "partition") + expect_integer(pred$row_ids, len = 0L) + expect_numeric(pred$partition, len = 0L) + expect_null(pred$prob) + expect_data_table(as.data.table(pred), nrows = 0L, ncols = 2L) + + learner = lrn("clust.featureless", predict_type = "prob") + learner$train(task) + pred = learner$predict(task, row_ids = integer()) + expect_prediction(pred) + expect_set_equal(pred$predict_types, c("partition", "prob")) + expect_integer(pred$row_ids, len = 0L) + expect_numeric(pred$partition, len = 0L) + expect_numeric(pred$prob, len = 0L) + expect_data_table(as.data.table(pred), nrows = 0L, ncols = 3L) +})