Skip to content

Commit

Permalink
Merge pull request #52 from JuliaAI/stabilize-rng-in-tests
Browse files Browse the repository at this point in the history
Stabilize RNG in tests
  • Loading branch information
ablaom authored Jul 1, 2024
2 parents 822bdd9 + a0b492a commit 402861a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
20 changes: 14 additions & 6 deletions src/MLJXGBoostInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,21 @@ function kwargs(model, verbosity, obj)
excluded = [:importance_type]
fn = filter(∉(excluded), fieldnames(typeof(model)))
out = NamedTuple(n=>getfield(model, n) for n ∈ fn if !isnothing(getfield(model, n)))
out = merge(out, (silent=(verbosity ≤ 0),))
# watchlist is for log output, so override if it's default and verbosity ≤ 0
wl = (verbosity ≤ 0 && isnothing(model.watchlist)) ? (;) : model.watchlist
if !isnothing(wl)
out = merge(out, (watchlist=wl,))

# `watchlist` needs to be consistent with `verbosity`. If you don't pass
# `watchlist=(;)` in the case of unspecified `watchlist`, then logging will happen no
# matter what the value of `verbosity`!
watchlist = (verbosity ≤ 0 && isnothing(model.watchlist)) ? (;) : model.watchlist
if !isnothing(watchlist)
out = merge(out, (; watchlist))
end
out = merge(out, (objective=_fix_objective(obj),))

# need `0 ≤ verbosity ≤ 3`:
verbosity = min(max(verbosity, 0), 3)

objective=_fix_objective(obj)
out = merge(out, (; verbosity, objective))

return out
end

Expand Down
18 changes: 13 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import XGBoost
using MLJXGBoostInterface
using MLJTestInterface
using Distributions
import StableRNGs
const rng = StableRNGs.StableRNG(123)
using StableRNGs
const rng = StableRNG(123)

@test_logs (:warn, r"Constraint ") XGBoostClassifier(objective="wrong")
@test_logs (:warn, r"Constraint ") XGBoostCount(objective="wrong")
Expand Down Expand Up @@ -54,18 +54,26 @@ end

# test regressor for early stopping rounds
# add some noise to create more differentiator in the evaluation metric to test if it chose the correct ntree_limit
mod_labels = labels + rand(Float64, 1000) * 10
mod_labels = labels + rand(StableRNG(123), Float64, 1000) * 10
es_regressor = XGBoostRegressor(num_round = 250, early_stopping_rounds = 20, eta = 0.5, max_depth = 20,
eval_metric = ["mae"], watchlist = Dict("train" => XGBoost.DMatrix(features, mod_labels)))
(fitresultR, cacheR, reportR) = MLJBase.fit(es_regressor, 0, features, mod_labels)
(fitresultR, cacheR, reportR) = @test_logs(
(:info,),
match_mode=:any,
MLJBase.fit(es_regressor, 0, features, mod_labels),
)
rpred = predict(es_regressor, fitresultR, features);
@test abs(mean(abs.(rpred-mod_labels)) - fitresultR[1].best_score) < 1e-8
@test !ismissing(fitresultR[1].best_iteration)

# try without early stopping (should be worse given the generated dataset) - to make sure it's a fair comparison - set early_stopping_rounds = num_round
nes_regressor = XGBoostRegressor(num_round = 250, early_stopping_rounds = 250, eta = 0.5, max_depth = 20,
eval_metric = ["mae"], watchlist = Dict("train" => XGBoost.DMatrix(features, mod_labels)))
(fitresultR, cacheR, reportR) = MLJBase.fit(nes_regressor, 0, features, mod_labels)
(fitresultR, cacheR, reportR) = @test_logs(
(:info,),
match_mode=:any,
MLJBase.fit(nes_regressor, 0, features, mod_labels),
)
rpred_noES = predict(es_regressor, fitresultR, features);
@test abs(mean(abs.(rpred-mod_labels))) < abs(mean(abs.(rpred_noES-mod_labels)))
@test ismissing(fitresultR[1].best_iteration)
Expand Down

0 comments on commit 402861a

Please sign in to comment.