Skip to content
This repository has been archived by the owner on Jul 31, 2021. It is now read-only.

Commit

Permalink
added init_model
Browse files Browse the repository at this point in the history
  • Loading branch information
kapsner committed Apr 4, 2020
1 parent c588db2 commit 2c6683c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 1 deletion.
7 changes: 7 additions & 0 deletions R/LearnerClassifLightGBM.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ LearnerClassifLightGBM = R6::R6Class(
default = 5L,
lower = 3L,
tags = c("config", "train")),
ParamUty$new(id = "init_model",
default = NULL,
tags = c("config", "train")),
#######################################
#######################################
# Classification only
Expand Down Expand Up @@ -557,6 +560,8 @@ LearnerClassifLightGBM = R6::R6Class(
self$param_set$values[["nrounds_by_cv"]] = NULL
nfolds = self$param_set$values[["nfolds"]]
self$param_set$values[["nfolds"]] = NULL
init_model = self$param_set$values[["init_model"]]
self$param_set$values[["init_model"]] = NULL
# get training parameters
pars = self$param_set$get_values(tags = "train")
# train CV model, in case that nrounds_by_cv is true
Expand All @@ -574,6 +579,7 @@ LearnerClassifLightGBM = R6::R6Class(
, nfold = nfolds
, stratified = TRUE
, eval = feval
, init_model = init_model
)
message(
sprintf(
Expand All @@ -593,6 +599,7 @@ LearnerClassifLightGBM = R6::R6Class(
, data = private$dtrain
, params = pars
, eval = feval
, init_model = init_model
) # use the mlr3misc::invoke function (it's similar to do.call())
},
.predict = function(task) {
Expand Down
3 changes: 3 additions & 0 deletions R/LearnerRegrLightGBM.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ LearnerRegrLightGBM = R6::R6Class(
default = 5L,
lower = 3L,
tags = c("config", "train")),
ParamUty$new(id = "init_model",
default = NULL,
tags = c("config", "train")),
#######################################
#######################################
# Regression only
Expand Down
3 changes: 3 additions & 0 deletions R/lgbparams.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ lgbparams = function() {
default = 5L,
lower = 3L,
tags = c("config", "train")),
ParamUty$new(id = "init_model",
default = NULL,
tags = c("config", "train")),
#######################################
#######################################
# Regression only
Expand Down
4 changes: 3 additions & 1 deletion vignettes/mlr3learners_lightgbm_binary.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ learner$param_set$values = mlr3misc::insert_named(
"learning_rate" = 0.1,
"seed" = 17L,
"metric" = "auc",
"num_iterations" = 100
"num_iterations" = 100,
"snapshot_freq" = 10,
"output_model" = "LGB.txt"
)
)
```
Expand Down

1 comment on commit 2c6683c

@lintr-bot
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

R/zzz.R:14:29: warning: no visible binding for global variable ‘LearnerClassifLightGBM’

x$add("classif.lightgbm", LearnerClassifLightGBM)
                            ^~~~~~~~~~~~~~~~~~~~~~

R/zzz.R:15:26: warning: no visible binding for global variable ‘LearnerRegrLightGBM’

x$add("regr.lightgbm", LearnerRegrLightGBM)
                         ^~~~~~~~~~~~~~~~~~~

Please sign in to comment.