Skip to content

Commit

Permalink
refactor: always use param_set$get_values() to access param_set values
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Apr 25, 2024
1 parent 9cd8434 commit 262ad22
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 29 deletions.
3 changes: 2 additions & 1 deletion R/LearnerClustAffinityPropagation.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ LearnerClustAP = R6Class("LearnerClustAP",
},

.predict = function(task) {
sim_func = self$param_set$values$s
pv = self$param_set$get_values()
sim_func = pv$s
exemplar_data = attributes(self$model)$exemplar_data

d = task$data()
Expand Down
13 changes: 9 additions & 4 deletions R/LearnerClustAgnes.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,22 @@ LearnerClustAgnes = R6Class("LearnerClustAgnes",
),
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(cluster::agnes, x = task$data(), diss = FALSE, .args = pv)
pv = self$param_set$get_values()
m = invoke(cluster::agnes,
x = task$data(),
diss = FALSE,
.args = remove_named(pv, "k")
)
if (self$save_assignments) {
self$assignments = stats::cutree(m, self$param_set$values$k)
self$assignments = stats::cutree(m, pv$k)
}

return(m)
},

.predict = function(task) {
if (self$param_set$values$k > task$nrow) {
pv = self$param_set$get_values(tags = "predict")
if (pv$k > task$nrow) {
stopf("`k` needs to be between 1 and %i", task$nrow)
}

Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustCMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans",
),
private = list(
.train = function(task) {
assert_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
assert_centers_param(pv$centers, task, test_data_frame, "centers")

m = invoke(e1071::cmeans, x = task$data(), .args = pv, .opts = allow_partial_matching)
if (self$save_assignments) {
self$assignments = m$cluster
Expand Down
13 changes: 9 additions & 4 deletions R/LearnerClustDiana.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,22 @@ LearnerClustDiana = R6Class("LearnerClustDiana",
),
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(cluster::diana, x = task$data(), diss = FALSE, .args = pv)
pv = self$param_set$get_values()
m = invoke(cluster::diana,
x = task$data(),
diss = FALSE,
.args = remove_named(pv, "k")
)
if (self$save_assignments) {
self$assignments = stats::cutree(m, self$param_set$values$k)
self$assignments = stats::cutree(m, pv$k)
}

return(m)
},

.predict = function(task) {
if (self$param_set$values$k > task$nrow) {
pv = self$param_set$get_values(tags = "predict")
if (pv$k > task$nrow) {
stopf("`k` needs to be between 1 and %s", task$nrow)
}

Expand Down
17 changes: 10 additions & 7 deletions R/LearnerClustHclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,26 @@ LearnerClustHclust = R6Class("LearnerClustHclust",
),
private = list(
.train = function(task) {
d = self$param_set$values$distmethod
dist_arg = self$param_set$get_values(tags = c("train", "dist"))
pv = self$param_set$get_values()
dist = invoke(stats::dist,
x = task$data(),
method = ifelse(is.null(d), "euclidean", d), .args = dist_arg
method = pv$d %??% "euclidean",
.args = self$param_set$get_values(tags = c("train", "dist"))
)
m = invoke(stats::hclust,
d = dist,
.args = self$param_set$get_values(tags = c("train", "hclust"))
)
pv = self$param_set$get_values(tags = c("train", "hclust"))
m = invoke(stats::hclust, d = dist, .args = pv)
if (self$save_assignments) {
self$assignments = stats::cutree(m, self$param_set$values$k)
self$assignments = stats::cutree(m, pv$k)
}

return(m)
},

.predict = function(task) {
if (self$param_set$values$k > task$nrow) {
pv = self$param_set$get_values(tags = "predict")
if (pv$k > task$nrow) {
stopf("`k` needs to be between 1 and %i", task$nrow)
}

Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustKKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ LearnerClustKKMeans = R6Class("LearnerClustKKMeans",
),
private = list(
.train = function(task) {
assert_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
assert_centers_param(pv$centers, task, test_data_frame, "centers")

m = invoke(kernlab::kkmeans, x = as.matrix(task$data()), .args = pv)
if (self$save_assignments) {
self$assignments = m[seq_along(m)]
Expand Down
6 changes: 3 additions & 3 deletions R/LearnerClustKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",

private = list(
.train = function(task) {
if ("nstart" %in% names(self$param_set$values) && !test_int(self$param_set$values$centers)) {
pv = self$param_set$get_values(tags = "train")
if (!is.null(pv$nstart) && !test_int(pv$centers)) {
warningf("`nstart` parameter is only relevant when `centers` is integer.")
}

assert_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")
assert_centers_param(pv$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
m = invoke(stats::kmeans, x = task$data(), .args = pv)
if (self$save_assignments) {
self$assignments = m$cluster
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustMeanShift.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ LearnerClustMeanShift = R6Class("LearnerClustMeanShift",
),
private = list(
.train = function(task) {
if (!is.null(self$param_set$values$subset) && length(self$param_set$values$subset) > task$nrow) {
pv = self$param_set$get_values(tags = "train")
if (!is.null(pv$subset) && length(pv$subset) > task$nrow) {
stopf("`subset` length must be less than or equal to number of observations in task")
}

pv = self$param_set$get_values(tags = "train")
m = invoke(LPCM::ms, X = task$data(), .args = pv)
if (self$save_assignments) {
self$assignments = m$cluster.label
Expand Down
7 changes: 3 additions & 4 deletions R/LearnerClustMiniBatchKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,12 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",
),
private = list(
.train = function(task) {
assert_centers_param(self$param_set$values$CENTROIDS, task, test_matrix, "CENTROIDS")
if (test_matrix(self$param_set$values$CENTROIDS) &&
nrow(self$param_set$values$CENTROIDS) != self$param_set$values$clusters) {
pv = self$param_set$get_values(tags = "train")
assert_centers_param(pv$CENTROIDS, task, test_matrix, "CENTROIDS")
if (test_matrix(pv$CENTROIDS) && nrow(pv$CENTROIDS) != pv$clusters) {
stopf("`CENTROIDS` must have same number of rows as `clusters`")
}

pv = self$param_set$get_values(tags = "train")
m = invoke(ClusterR::MiniBatchKmeans, data = task$data(), .args = pv)
if (self$save_assignments) {
self$assignments = unclass(invoke(ClusterR::predict_MBatchKMeans,
Expand Down

0 comments on commit 262ad22

Please sign in to comment.