Skip to content

Commit

Permalink
feat: use register funtions for loading and unloading (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke authored Aug 20, 2024
1 parent 8f0cd29 commit 560496d
Show file tree
Hide file tree
Showing 31 changed files with 102 additions and 94 deletions.
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.2
Collate:
'LearnerClust.R'
'aaa.R'
'zzz.R'
'LearnerClustAffinityPropagation.R'
'LearnerClustAgnes.R'
'LearnerClustBICO.R'
Expand Down Expand Up @@ -82,4 +82,3 @@ Collate:
'as_task_clust.R'
'bibentries.R'
'helper.R'
'zzz.R'
4 changes: 2 additions & 2 deletions R/LearnerClustAffinityPropagation.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ LearnerClustAP = R6Class("LearnerClustAP",
)
)

#' @include aaa.R
learners[["clust.ap"]] = LearnerClustAP
#' @include zzz.R
register_learner("clust.ap", LearnerClustAP)
4 changes: 2 additions & 2 deletions R/LearnerClustAgnes.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,5 @@ LearnerClustAgnes = R6Class("LearnerClustAgnes",
)
)

#' @include aaa.R
learners[["clust.agnes"]] = LearnerClustAgnes
#' @include zzz.R
register_learner("clust.agnes", LearnerClustAgnes)
4 changes: 2 additions & 2 deletions R/LearnerClustBICO.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ LearnerClustBICO = R6Class("LearnerClustBICO",
)
)

#' @include aaa.R
learners[["clust.bico"]] = LearnerClustBICO
#' @include zzz.R
register_learner("clust.bico", LearnerClustBICO)
4 changes: 2 additions & 2 deletions R/LearnerClustBIRCH.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ LearnerClustBIRCH = R6Class("LearnerClustBIRCH",
)
)

#' @include aaa.R
learners[["clust.birch"]] = LearnerClustBIRCH
#' @include zzz.R
register_learner("clust.birch", LearnerClustBIRCH)
4 changes: 2 additions & 2 deletions R/LearnerClustCMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans",
)
)

#' @include aaa.R
learners[["clust.cmeans"]] = LearnerClustCMeans
#' @include zzz.R
register_learner("clust.cmeans", LearnerClustCMeans)
4 changes: 2 additions & 2 deletions R/LearnerClustCobweb.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,5 @@ LearnerClustCobweb = R6Class("LearnerClustCobweb",
)
)

#' @include aaa.R
learners[["clust.cobweb"]] = LearnerClustCobweb
#' @include zzz.R
register_learner("clust.cobweb", LearnerClustCobweb)
4 changes: 2 additions & 2 deletions R/LearnerClustDBSCAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ LearnerClustDBSCAN = R6Class("LearnerClustDBSCAN",
)
)

#' @include aaa.R
learners[["clust.dbscan"]] = LearnerClustDBSCAN
#' @include zzz.R
register_learner("clust.dbscan", LearnerClustDBSCAN)
4 changes: 2 additions & 2 deletions R/LearnerClustDBSCANfpc.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,5 @@ LearnerClustDBSCANfpc = R6Class("LearnerClustDBSCANfpc",
)
)

#' @include aaa.R
learners[["clust.dbscan_fpc"]] = LearnerClustDBSCANfpc
#' @include zzz.R
register_learner("clust.dbscan_fpc", LearnerClustDBSCANfpc)
4 changes: 2 additions & 2 deletions R/LearnerClustDiana.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,5 @@ LearnerClustDiana = R6Class("LearnerClustDiana",
)
)

#' @include aaa.R
learners[["clust.diana"]] = LearnerClustDiana
#' @include zzz.R
register_learner("clust.diana", LearnerClustDiana)
4 changes: 2 additions & 2 deletions R/LearnerClustEM.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,5 @@ LearnerClustEM = R6Class("LearnerClustEM",
)
)

#' @include aaa.R
learners[["clust.em"]] = LearnerClustEM
#' @include zzz.R
register_learner("clust.em", LearnerClustEM)
4 changes: 2 additions & 2 deletions R/LearnerClustFanny.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ LearnerClustFanny = R6Class("LearnerClustFanny",
)
)

#' @include aaa.R
learners[["clust.fanny"]] = LearnerClustFanny
#' @include zzz.R
register_learner("clust.fanny", LearnerClustFanny)
4 changes: 2 additions & 2 deletions R/LearnerClustFarthestFirst.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ LearnerClustFarthestFirst = R6Class("LearnerClustFF",
)
)

#' @include aaa.R
learners[["clust.ff"]] = LearnerClustFarthestFirst
#' @include zzz.R
register_learner("clust.ff", LearnerClustFarthestFirst)
4 changes: 2 additions & 2 deletions R/LearnerClustFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ LearnerClustFeatureless = R6Class("LearnerClustFeatureless",
)
)

#' @include aaa.R
learners[["clust.featureless"]] = LearnerClustFeatureless
#' @include zzz.R
register_learner("clust.featureless", LearnerClustFeatureless)
4 changes: 2 additions & 2 deletions R/LearnerClustHDBSCAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ LearnerClustHDBSCAN = R6Class("LearnerClustHDBSCAN",
)
)

#' @include aaa.R
learners[["clust.hdbscan"]] = LearnerClustHDBSCAN
#' @include zzz.R
register_learner("clust.hdbscan", LearnerClustHDBSCAN)
4 changes: 2 additions & 2 deletions R/LearnerClustHclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,5 @@ LearnerClustHclust = R6Class("LearnerClustHclust",
)
)

#' @include aaa.R
learners[["clust.hclust"]] = LearnerClustHclust
#' @include zzz.R
register_learner("clust.hclust", LearnerClustHclust)
4 changes: 2 additions & 2 deletions R/LearnerClustKKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,5 @@ LearnerClustKKMeans = R6Class("LearnerClustKKMeans",
)
)

#' @include aaa.R
learners[["clust.kkmeans"]] = LearnerClustKKMeans
#' @include zzz.R
register_learner("clust.kkmeans", LearnerClustKKMeans)
4 changes: 2 additions & 2 deletions R/LearnerClustKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",
)
)

#' @include aaa.R
learners[["clust.kmeans"]] = LearnerClustKMeans
#' @include zzz.R
register_learner("clust.kmeans", LearnerClustKMeans)
4 changes: 2 additions & 2 deletions R/LearnerClustMclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ LearnerClustMclust = R6Class("LearnerClustMclust",
)
)

#' @include aaa.R
learners[["clust.mclust"]] = LearnerClustMclust
#' @include zzz.R
register_learner("clust.mclust", LearnerClustMclust)
4 changes: 2 additions & 2 deletions R/LearnerClustMeanShift.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ LearnerClustMeanShift = R6Class("LearnerClustMeanShift",
)
)

#' @include aaa.R
learners[["clust.meanshift"]] = LearnerClustMeanShift
#' @include zzz.R
register_learner("clust.meanshift", LearnerClustMeanShift)
4 changes: 2 additions & 2 deletions R/LearnerClustMiniBatchKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",
)
)

#' @include aaa.R
learners[["clust.MBatchKMeans"]] = LearnerClustMiniBatchKMeans
#' @include zzz.R
register_learner("clust.MBatchKMeans", LearnerClustMiniBatchKMeans)
4 changes: 2 additions & 2 deletions R/LearnerClustOPTICS.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ LearnerClustOPTICS = R6Class("LearnerClustOPTICS",
)
)

#' @include aaa.R
learners[["clust.optics"]] = LearnerClustOPTICS
#' @include zzz.R
register_learner("clust.optics", LearnerClustOPTICS)
4 changes: 2 additions & 2 deletions R/LearnerClustPAM.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ LearnerClustPAM = R6Class("LearnerClustPAM",
)
)

#' @include aaa.R
learners[["clust.pam"]] = LearnerClustPAM
#' @include zzz.R
register_learner("clust.pam", LearnerClustPAM)
4 changes: 2 additions & 2 deletions R/LearnerClustSimpleKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ LearnerClustSimpleKMeans = R6Class("LearnerClustSimpleKMeans",
)
)

#' @include aaa.R
learners[["clust.SimpleKMeans"]] = LearnerClustSimpleKMeans
#' @include zzz.R
register_learner("clust.SimpleKMeans", LearnerClustSimpleKMeans)
4 changes: 2 additions & 2 deletions R/LearnerClustXMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ LearnerClustXMeans = R6Class("LearnerClustXMeans",
)
)

#' @include aaa.R
learners[["clust.xmeans"]] = LearnerClustXMeans
#' @include zzz.R
register_learner("clust.xmeans", LearnerClustXMeans)
4 changes: 2 additions & 2 deletions R/TaskClust_ruspini.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ load_task_ruspini = function(id = "ruspini") {
task
}

#' @include aaa.R
tasks[["ruspini"]] = load_task_ruspini
#' @include zzz.R
register_task("ruspini", load_task_ruspini)
4 changes: 2 additions & 2 deletions R/TaskClust_usarrest.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ load_task_usarrests = function(id = "usarrests") {
task
}

#' @include aaa.R
tasks[["usarrests"]] = load_task_usarrests
#' @include zzz.R
register_task("usarrests", load_task_usarrests)
2 changes: 0 additions & 2 deletions R/aaa.R

This file was deleted.

2 changes: 1 addition & 1 deletion R/as_prediction_clust.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ as_prediction_clust.data.frame = function(x, ...) { # nolint

if (length(prob_cols)) {
if (!all(startsWith(prob_cols, "prob."))) {
stopf("Table may only contain columns 'row_ids', 'partition' as well as columns prefixed with 'prob.' for class probabilities")
stopf("Table may only contain columns 'row_ids', 'partition' as well as columns prefixed with 'prob.' for class probabilities") # nolint
}
prob = as.matrix(x[, prob_cols, with = FALSE])
colnames(prob) = substr(colnames(prob), 6L, nchar(colnames(prob)))
Expand Down
81 changes: 46 additions & 35 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,70 +1,81 @@
#' @import checkmate
#' @import data.table
#' @import mlr3
#' @import mlr3misc
#' @import paradox
#' @import mlr3
#' @import checkmate
#' @importFrom R6 R6Class
#' @importFrom clue cl_predict
#' @importFrom fpc cluster.stats
#' @importFrom cluster silhouette
#' @importFrom fpc cluster.stats
#' @importFrom stats model.frame terms predict runif dist
"_PACKAGE"

utils::globalVariables("type")

mlr3cluster_tasks = new.env()
mlr3cluster_learners = new.env()

register_task = function(name, constructor) {
if (name %in% names(mlr3cluster_tasks)) stopf("task %s registered twice", name)
mlr3cluster_tasks[[name]] = constructor
}

register_learner = function(name, constructor) {
if (name %in% names(mlr3cluster_learners)) stopf("learner %s registered twice", name)
mlr3cluster_learners[[name]] = constructor
}

register_mlr3 = function() {
# reflections
x = utils::getFromNamespace("mlr_reflections", ns = "mlr3")

# task
x$task_types = x$task_types[!"clust"]
x$task_types = setkeyv(rbind(x$task_types, rowwise_table(
mlr_reflections = utils::getFromNamespace("mlr_reflections", ns = "mlr3")
mlr_reflections$task_types = mlr_reflections$task_types[type != "clust"]
mlr_reflections$task_types = setkeyv(rbind(mlr_reflections$task_types, rowwise_table(
~type, ~package, ~task, ~learner, ~prediction, ~prediction_data, ~measure,
"clust", "mlr3cluster", "TaskClust", "LearnerClust", "PredictionClust", "PredictionDataClust", "MeasureClust"
), fill = TRUE), "type")

x$task_col_roles$clust = x$task_col_roles$regr
x$task_properties$clust = x$task_properties$regr
x$learner_properties$clust = c(
mlr_reflections$task_col_roles$clust = mlr_reflections$task_col_roles$regr
mlr_reflections$task_properties$clust = mlr_reflections$task_properties$regr
mlr_reflections$learner_properties$clust = c(
"missings", "partitional", "hierarchical", "exclusive", "overlapping", "fuzzy", "complete", "partial", "density"
)

# measure
x$measure_properties$clust = x$measure_properties$regr

# learner
x$learner_predict_types$clust = list(partition = "partition", prob = c("partition", "prob"))
x$default_measures$clust = "clust.dunn"
mlr_reflections$learner_predict_types$clust = list(partition = "partition", prob = c("partition", "prob"))
mlr_reflections$measure_properties$clust = mlr_reflections$measure_properties$regr
mlr_reflections$default_measures$clust = "clust.dunn"

# tasks
x = utils::getFromNamespace("mlr_tasks", ns = "mlr3")
x$add("usarrests", load_task_usarrests)
x$add("ruspini", load_task_ruspini)
mlr_tasks = utils::getFromNamespace("mlr_tasks", ns = "mlr3")
iwalk(as.list(mlr3cluster_tasks), function(task, id) mlr_tasks$add(id, task))

# learners
x = utils::getFromNamespace("mlr_learners", ns = "mlr3")
iwalk(learners, function(obj, nm) x$add(nm, obj))
mlr_learners = utils::getFromNamespace("mlr_learners", ns = "mlr3")
iwalk(as.list(mlr3cluster_learners), function(learner, id) mlr_learners$add(id, learner))

# measures
x = utils::getFromNamespace("mlr_measures", ns = "mlr3")
x$add("clust.silhouette", MeasureClustSil, name = "silhouette", label = "Silhouette")
x$add("clust.dunn", MeasureClustFPC, name = "dunn", label = "Dunn")
x$add("clust.ch", MeasureClustFPC, name = "ch", label = "Calinski Harabasz")
x$add("clust.wss", MeasureClustFPC, name = "wss", label = "Within Sum of Squares")
mlr_measures = utils::getFromNamespace("mlr_measures", ns = "mlr3")
mlr_measures$add("clust.silhouette", MeasureClustSil, name = "silhouette", label = "Silhouette")
mlr_measures$add("clust.dunn", MeasureClustFPC, name = "dunn", label = "Dunn")
mlr_measures$add("clust.ch", MeasureClustFPC, name = "ch", label = "Calinski Harabasz")
mlr_measures$add("clust.wss", MeasureClustFPC, name = "wss", label = "Within Sum of Squares")
}

.onLoad = function(libname, pkgname) {
backports::import(pkgname)

register_mlr3()
register_namespace_callback(pkgname, "mlr3", register_mlr3)
}

.onUnload = function(libpaths) { # nolint
mlr_learners = mlr3::mlr_learners
mlr_measures = mlr3::mlr_measures
mlr_tasks = mlr3::mlr_tasks

walk(names(learners), function(id) mlr_learners$remove(id))
walk(names(mlr3cluster_tasks), function(id) mlr_tasks$remove(id))
walk(names(mlr3cluster_learners), function(id) mlr_learners$remove(id))
walk(names(measures), function(id) mlr_measures$remove(paste("clust", id, sep = ".")))
walk(names(tasks), function(id) mlr_tasks$remove(id))

mlr_reflections$task_types = mlr_reflections$task_types[type != "clust"]
reflections = c(
"measure_properties", "default_measures", "learner_properties",
"learner_predict_types", "task_properties", "task_col_roles"
)
walk(reflections, function(x) mlr_reflections[[x]] = remove_named(mlr_reflections[[x]], "clust"))
}

leanify_package()
4 changes: 2 additions & 2 deletions tests/testthat/helper.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
library(testthat)
library(checkmate)
library(mlr3)
library(mlr3cluster)
library(checkmate)
library(testthat)

0 comments on commit 560496d

Please sign in to comment.