Skip to content

Commit

Permalink
overload constructor trait for model types
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jun 3, 2024
1 parent 3843dda commit ec174a9
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
3 changes: 2 additions & 1 deletion src/balanced_bagging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ MMI.metadata_pkg(
MMI.metadata_model(
BalancedBaggingClassifier,
target_scitype = AbstractVector{<:Finite},
load_path = "MLJBalancing." * string(BalancedBaggingClassifier),
load_path = "MLJBalancing.BalancedBaggingClassifier",
)
MMI.constructor(::Type{<:BalancedBaggingClassifier}) = BalancedBaggingClassifier

MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{<:Any,<:Any,P}}) where P =
MLJBase.prepend(:model, iteration_parameter(P))
Expand Down
16 changes: 12 additions & 4 deletions src/balanced_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,20 @@ MMI.package_name(::Type{<:UNION_COMPOSITE_TYPES}) = "MLJBalancing"
MMI.package_license(::Type{<:UNION_COMPOSITE_TYPES}) = "MIT"
MMI.package_uuid(::Type{<:UNION_COMPOSITE_TYPES}) = "45f359ea-796d-4f51-95a5-deb1a414c586"
MMI.is_wrapper(::Type{<:UNION_COMPOSITE_TYPES}) = true
MMI.package_url(::Type{<:UNION_COMPOSITE_TYPES}) ="https://github.com/JuliaAI/MLJBalancing.jl"
MMI.package_url(::Type{<:UNION_COMPOSITE_TYPES}) =
"https://github.com/JuliaAI/MLJBalancing.jl"

# load path should point to constructor:
MMI.load_path(::Type{<:UNION_COMPOSITE_TYPES}) = "MLJBalancing.BalancedModel"
MMI.constructor(::Type{<:UNION_COMPOSITE_TYPES}) = BalancedModel

# All the composite types BalancedModelProbabilistic, BalancedModelDeterministic, etc.
const COMPOSITE_TYPES = values(MODELTYPE_TO_COMPOSITETYPE)
for composite_type in COMPOSITE_TYPES
quote
MMI.iteration_parameter(::Type{<:$composite_type{balancernames, M}}) where {balancernames, M} =
MLJBase.prepend(:model, iteration_parameter(M))
MMI.iteration_parameter(
::Type{<:$composite_type{balancernames, M}},
) where {balancernames, M} = MLJBase.prepend(:model, iteration_parameter(M))
end |> eval
for trait in [
:input_scitype,
Expand All @@ -241,7 +247,9 @@ for composite_type in COMPOSITE_TYPES
:is_supervised,
:prediction_type]
quote
MMI.$trait(::Type{<:$composite_type{balancernames, M}}) where {balancernames, M} = MMI.$trait(M)
MMI.$trait(
::Type{<:$composite_type{balancernames, M}},
) where {balancernames, M} = MMI.$trait(M)
end |> eval
end
end
12 changes: 9 additions & 3 deletions test/balanced_bagging.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

@testset "group_inds and get_majority_minority_inds_counts" begin
y = [0, 0, 0, 0, 1, 1, 1, 0]
@test MLJBalancing.group_inds(y) == Dict(0 => [1, 2, 3, 4, 8], 1 => [5, 6, 7])
Expand Down Expand Up @@ -111,7 +110,11 @@ end
pred_manual = mean([pred1, pred2])

## using BalancedBagging
modelo = BalancedBaggingClassifier(model = model, T = 2, rng = Random.MersenneTwister(42))
modelo = BalancedBaggingClassifier(
model = model,
T = 2,
rng = Random.MersenneTwister(42),
)
mach = machine(modelo, X, y)
fit!(mach)
pred_auto = MLJBase.predict(mach, Xt)
Expand All @@ -123,7 +126,10 @@ end

## traits
@test fit_data_scitype(modelo) == fit_data_scitype(model)
@test is_wrapper(modelo)
@test is_wrapper(modelo)
@test constructor(modelo) == BalancedBaggingClassifier
@test package_name(modelo) == "MLJBalancing"
@test load_path(modelo) == "MLJBalancing.BalancedBaggingClassifier"
end


Expand Down
12 changes: 7 additions & 5 deletions test/balanced_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@
@test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
BalancedModel(b1 = balancer1, b2 = balancer2, b3 = balancer3)
end
@test_throws(
MLJBalancing.ERR_UNSUPPORTED_MODEL(1),
BalancedModel(model = 1, b1 = balancer1, b2 = balancer2, b3 = balancer3),
)
@test_throws(
MLJBalancing.ERR_UNSUPPORTED_MODEL(1),
BalancedModel(model = 1, b1 = balancer1, b2 = balancer2, b3 = balancer3),
)
@test_logs (:warn, MLJBalancing.WRN_BALANCER_UNSPECIFIED) begin
BalancedModel(model = model_prob)
end
balanced_model =
BalancedModel(model = model_prob, b1 = balancer1, b2 = balancer2, b3 = balancer3)
BalancedModel(model = model_prob, b1 = balancer1, b2 = balancer2, b3 = balancer3)
@test constructor(balanced_model) == BalancedModel

mach = machine(balanced_model, X_train, y_train)
fit!(mach)
y_pred2 = MLJBase.predict(mach, X_test)
Expand Down

0 comments on commit ec174a9

Please sign in to comment.