Skip to content

Commit

Permalink
feat: add 3 params for abess in R
Browse files Browse the repository at this point in the history
include: fit.intercept, beta.low, beta.high
  • Loading branch information
bbayukari committed Sep 3, 2023
1 parent d039462 commit 96ed869
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
19 changes: 18 additions & 1 deletion R-package/R/abess.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ abess <- function(x, ...) UseMethod("abess")
#' \code{normalize = 3} for scaling the columns of \code{x} to have \eqn{\sqrt n} norm.
#' If \code{normalize = NULL}, \code{normalize} will be set \code{1} for \code{"gaussian"} and \code{"mgaussian"},
#' \code{3} for \code{"cox"}. Default is \code{normalize = NULL}.
#' @param fit.intercept A boolean value indicating whether to fit an intercept.
#' We assume the data has been centered if \code{fit.intercept = FALSE}.
#' Default: \code{fit.intercept = FALSE}.
#' @param beta.low A single value specifying the lower bound of \eqn{\beta}. Default is \code{-.Machine$double.xmax}。
#' @param beta.high A single value specifying the upper bound of \eqn{\beta}. Default is \code{.Machine$double.xmax}。
#' @param c.max an integer splicing size. Default is: \code{c.max = 2}.
#' @param weight Observation weights. When \code{weight = NULL},
#' we set \code{weight = 1} for each observation as default.
Expand Down Expand Up @@ -303,6 +308,9 @@ abess.default <- function(x,
tune.type = c("gic", "ebic", "bic", "aic", "cv"),
weight = NULL,
normalize = NULL,
fit.intercept = TRUE,
beta.low = -.Machine$double.xmax,
beta.high = .Machine$double.xmax,
c.max = 2,
support.size = NULL,
gs.range = NULL,
Expand Down Expand Up @@ -352,10 +360,13 @@ abess.default <- function(x,
tune.path=tune.path,
max.newton.iter=max.newton.iter,
lambda=lambda,
beta.low=beta.low,
beta.high=beta.high,
family=family,
screening.num=screening.num,
gs.range=gs.range,
early.stop=early.stop,
fit.intercept=fit.intercept,
weight=weight,
cov.update=cov.update,
normalize=normalize,
Expand All @@ -373,6 +384,8 @@ abess.default <- function(x,
x <- data$x
tune.path <- para$tune.path
lambda <- para$lambda
beta_low <- para$beta_low
beta_high <- para$beta_high
family <- para$family
gs.range <- para$gs.range
weight <- para$weight
Expand Down Expand Up @@ -410,6 +423,7 @@ abess.default <- function(x,
y_dim <- para$y_dim
multi_y <- para$multi_y
early_stop <- para$early_stop
fit_intercept <- para$fit_intercept

result <- abessGLM_API(
x = x,
Expand Down Expand Up @@ -447,7 +461,10 @@ abess.default <- function(x,
splicing_type = splicing_type,
sub_search = important_search,
cv_fold_id = cv_fold_id,
A_init = as.integer(init.active.set)
A_init = as.integer(init.active.set),
fit_intercept = fit_intercept,
beta_low = beta_low,
beta_high = beta_high
)

## process result
Expand Down
35 changes: 35 additions & 0 deletions R-package/R/initialization.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ Initialization_GLM <- function(c.max,
tune.path,
max.newton.iter,
lambda,
beta.low,
beta.high,
family,
screening.num,
gs.range,
early.stop,
fit.intercept,
weight,
cov.update,
normalize,
Expand All @@ -72,10 +75,13 @@ Initialization_GLM <- function(c.max,
para$tune.path <- tune.path
para$max.newton.iter <- max.newton.iter
para$lambda <- lambda
para$beta.low <- beta.low
para$beta.high <- beta.high
para$family <- family
para$screening.num <- screening.num
para$gs.range <- gs.range
para$early.stop <- early.stop
para$fit.intercept <- fit.intercept
para$weight <- weight
para$cov.update <- cov.update
para$normalize <- normalize
Expand Down Expand Up @@ -268,6 +274,24 @@ lambda.rpca <- lambda_private

lambda.glm <- lambda_private

beta_range <- function(para)
UseMethod("beta_range")

beta_range_private <- function(para) {
stopifnot(length(para$beta.low) == 1)
stopifnot(length(para$beta.high) == 1)
stopifnot(!anyNA(para$beta.low))
stopifnot(!anyNA(para$beta.high))
stopifnot(para$beta.low < para$beta.high)

para$beta_low <- as.double(para$beta.low)
para$beta_high <- as.double(para$beta.high)

para
}

beta_range.glm <- beta_range_private


warm_start <- function(para)
UseMethod("warm_start")
Expand Down Expand Up @@ -973,6 +997,15 @@ early_stop.glm <- function(para) {
para
}

fit_intercept <- function(para)
UseMethod("fit_intercept")

fit_intercept.glm <- function(para) {
stopifnot(is.logical(para$fit.intercept))
para$fit_intercept <- para$fit.intercept

para
}

model_type <- function(para)
UseMethod("model_type")
Expand Down Expand Up @@ -1122,8 +1155,10 @@ initializate <- function(para, data)

initializate.glm <- function(para, data) {
para <- lambda(para)
para <- beta_range(para)
para <- number_of_thread(para)
para <- early_stop(para)
para <- fit_intercept(para)
para <- warm_start(para)
para <- splicing_type(para)
para <- max_splicing_iter(para)
Expand Down

0 comments on commit 96ed869

Please sign in to comment.