Skip to content

Commit

Permalink
Merge pull request #27 from JuliaAI/constructor
Browse files Browse the repository at this point in the history
Overload `constructor` trait for all model types
  • Loading branch information
EssamWisam authored Jun 3, 2024
2 parents 1c971a5 + ec174a9 commit c2c40b1
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 98 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBalancing"
uuid = "45f359ea-796d-4f51-95a5-deb1a414c586"
authors = ["Essam Wisam <[email protected]>", "Anthony Blaom <[email protected]> and contributors"]
version = "0.1.4"
version = "0.1.5"

[deps]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand All @@ -12,9 +12,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
MLJBase = "1"
MLJBase = "1.4"
OrderedCollections = "1.6"
MLJModelInterface = "1.9"
MLJModelInterface = "1.10"
MLUtils = "0.4"
StatsBase = "0.34"
julia = "1.6"
Expand Down
3 changes: 2 additions & 1 deletion src/balanced_bagging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ MMI.metadata_pkg(
MMI.metadata_model(
BalancedBaggingClassifier,
target_scitype = AbstractVector{<:Finite},
load_path = "MLJBalancing." * string(BalancedBaggingClassifier),
load_path = "MLJBalancing.BalancedBaggingClassifier",
)
MMI.constructor(::Type{<:BalancedBaggingClassifier}) = BalancedBaggingClassifier

MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{<:Any,<:Any,P}}) where P =
MLJBase.prepend(:model, iteration_parameter(P))
Expand Down
26 changes: 17 additions & 9 deletions src/balanced_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ struct BalancedModel <:ProbabilisticNetworkComposite
model::Probabilistic # get rid of abstract types
end
BalancedModel(;model=nothing, balancer=nothing) = BalancedModel(model, balancer)
BalancedModel(model; kwargs...) = BalancedModel(; model, kwargs...)
BalancedModel(;model=nothing, balancer=nothing) = BalancedModel(model, balancer)
BalancedModel(model; kwargs...) = BalancedModel(; model, kwargs...)
In the following, we use macros to automate code generation of these for all model
types
Expand Down Expand Up @@ -66,15 +66,15 @@ const ERR_UNSUPPORTED_MODEL(model) = ErrorException(
"$PRETTY_SUPPORTED_MODEL_TYPES.\n"*
"Model provided has type `$(typeof(model))`. "
)
const ERR_NUM_ARGS_BM = "`BalancedModel` can at most have one non-keyword argument where the model is passed."
const ERR_NUM_ARGS_BM = "`BalancedModel` can at most have one non-keyword argument where the model is passed."


"""
BalancedModel(; model=nothing, balancer1=balancer_model1, balancer2=balancer_model2, ...)
BalancedModel(model; balancer1=balancer_model1, balancer2=balancer_model2, ...)
Given a classification model, and one or more balancer models that all implement the `MLJModelInterface`,
`BalancedModel` allows constructing a sequential pipeline that wraps an arbitrary number of balancing models
`BalancedModel` allows constructing a sequential pipeline that wraps an arbitrary number of balancing models
and a classifier together in a sequential pipeline.
# Operation
Expand All @@ -83,7 +83,7 @@ Given a classification model, and one or more balancer models that all implement
- During prediction, the balancers have no effect.
# Arguments
- `model::Supervised`: A classification model that implements the `MLJModelInterface`.
- `model::Supervised`: A classification model that implements the `MLJModelInterface`.
- `balancer1::Static=...`: The first balancer model to pass the data to. This keyword argument can have any name.
- `balancer2::Static=...`: The second balancer model to pass the data to. This keyword argument can have any name.
- and so on for an arbitrary number of balancers.
Expand Down Expand Up @@ -216,14 +216,20 @@ MMI.package_name(::Type{<:UNION_COMPOSITE_TYPES}) = "MLJBalancing"
MMI.package_license(::Type{<:UNION_COMPOSITE_TYPES}) = "MIT"
MMI.package_uuid(::Type{<:UNION_COMPOSITE_TYPES}) = "45f359ea-796d-4f51-95a5-deb1a414c586"
MMI.is_wrapper(::Type{<:UNION_COMPOSITE_TYPES}) = true
MMI.package_url(::Type{<:UNION_COMPOSITE_TYPES}) ="https://github.com/JuliaAI/MLJBalancing.jl"
MMI.package_url(::Type{<:UNION_COMPOSITE_TYPES}) =
"https://github.com/JuliaAI/MLJBalancing.jl"

# load path should point to constructor:
MMI.load_path(::Type{<:UNION_COMPOSITE_TYPES}) = "MLJBalancing.BalancedModel"
MMI.constructor(::Type{<:UNION_COMPOSITE_TYPES}) = BalancedModel

# All the composite types BalancedModelProbabilistic, BalancedModelDeterministic, etc.
const COMPOSITE_TYPES = values(MODELTYPE_TO_COMPOSITETYPE)
for composite_type in COMPOSITE_TYPES
quote
MMI.iteration_parameter(::Type{<:$composite_type{balancernames, M}}) where {balancernames, M} =
MLJBase.prepend(:model, iteration_parameter(M))
MMI.iteration_parameter(
::Type{<:$composite_type{balancernames, M}},
) where {balancernames, M} = MLJBase.prepend(:model, iteration_parameter(M))
end |> eval
for trait in [
:input_scitype,
Expand All @@ -241,7 +247,9 @@ for composite_type in COMPOSITE_TYPES
:is_supervised,
:prediction_type]
quote
MMI.$trait(::Type{<:$composite_type{balancernames, M}}) where {balancernames, M} = MMI.$trait(M)
MMI.$trait(
::Type{<:$composite_type{balancernames, M}},
) where {balancernames, M} = MMI.$trait(M)
end |> eval
end
end
12 changes: 9 additions & 3 deletions test/balanced_bagging.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

@testset "group_inds and get_majority_minority_inds_counts" begin
y = [0, 0, 0, 0, 1, 1, 1, 0]
@test MLJBalancing.group_inds(y) == Dict(0 => [1, 2, 3, 4, 8], 1 => [5, 6, 7])
Expand Down Expand Up @@ -111,7 +110,11 @@ end
pred_manual = mean([pred1, pred2])

## using BalancedBagging
modelo = BalancedBaggingClassifier(model = model, T = 2, rng = Random.MersenneTwister(42))
modelo = BalancedBaggingClassifier(
model = model,
T = 2,
rng = Random.MersenneTwister(42),
)
mach = machine(modelo, X, y)
fit!(mach)
pred_auto = MLJBase.predict(mach, Xt)
Expand All @@ -123,7 +126,10 @@ end

## traits
@test fit_data_scitype(modelo) == fit_data_scitype(model)
@test is_wrapper(modelo)
@test is_wrapper(modelo)
@test constructor(modelo) == BalancedBaggingClassifier
@test package_name(modelo) == "MLJBalancing"
@test load_path(modelo) == "MLJBalancing.BalancedBaggingClassifier"
end


Expand Down
166 changes: 84 additions & 82 deletions test/balanced_model.jl
Original file line number Diff line number Diff line change
@@ -1,93 +1,95 @@
@testset "BalancedModel" begin
### end-to-end test
# Create and split data
X, y = generate_imbalanced_data(100, 5; class_probs = [0.2, 0.3, 0.5], rng=Random.MersenneTwister(42))
X = DataFrame(X)
train_inds, test_inds =
partition(eachindex(y), 0.8, shuffle = true, stratify = y, rng = Random.MersenneTwister(42))
X_train, X_test = X[train_inds, :], X[test_inds, :]
y_train, y_test = y[train_inds], y[test_inds]

# Load models and balancers
DeterministicConstantClassifier = @load DeterministicConstantClassifier pkg=MLJModels
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels

# Here are a probabilistic and a deterministic model
model_prob = LogisticClassifier()
model_det = DeterministicConstantClassifier()
# And here are three resamplers from Imbalance.
# The package should actually work with any `Static` transformer of the form `(X, y) -> (Xout, yout)`
# provided that it implements the MLJ interface. Here, the balancer is the transformer
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = Random.MersenneTwister(42))
balancer2 = Imbalance.MLJ.SMOTENC(k = 10, ratios = 1.2, rng = Random.MersenneTwister(42))
balancer3 = Imbalance.MLJ.ROSE(ratios = 1.3, rng = Random.MersenneTwister(42))

### 1. Make a pipeline of the three balancers and a probablistic model
## ordinary way
mach = machine(balancer1)
Xover, yover = MLJBase.transform(mach, X_train, y_train)
mach = machine(balancer2)
Xover, yover = MLJBase.transform(mach, Xover, yover)
mach = machine(balancer3)
Xover, yover = MLJBase.transform(mach, Xover, yover)

mach = machine(model_prob, Xover, yover)
fit!(mach)
y_pred = MLJBase.predict(mach, X_test)

# with MLJ balancing
@test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
BalancedModel(b1 = balancer1, b2 = balancer2, b3 = balancer3)
end
@test_throws(
MLJBalancing.ERR_UNSUPPORTED_MODEL(1),
BalancedModel(model = 1, b1 = balancer1, b2 = balancer2, b3 = balancer3),
)
@test_logs (:warn, MLJBalancing.WRN_BALANCER_UNSPECIFIED) begin
BalancedModel(model = model_prob)
end
balanced_model =
BalancedModel(model = model_prob, b1 = balancer1, b2 = balancer2, b3 = balancer3)
mach = machine(balanced_model, X_train, y_train)
fit!(mach)
y_pred2 = MLJBase.predict(mach, X_test)
### end-to-end test
# Create and split data
X, y = generate_imbalanced_data(100, 5; class_probs = [0.2, 0.3, 0.5], rng=Random.MersenneTwister(42))
X = DataFrame(X)
train_inds, test_inds =
partition(eachindex(y), 0.8, shuffle = true, stratify = y, rng = Random.MersenneTwister(42))
X_train, X_test = X[train_inds, :], X[test_inds, :]
y_train, y_test = y[train_inds], y[test_inds]

# Load models and balancers
DeterministicConstantClassifier = @load DeterministicConstantClassifier pkg=MLJModels
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels

# Here are a probabilistic and a deterministic model
model_prob = LogisticClassifier()
model_det = DeterministicConstantClassifier()
# And here are three resamplers from Imbalance.
# The package should actually work with any `Static` transformer of the form `(X, y) -> (Xout, yout)`
# provided that it implements the MLJ interface. Here, the balancer is the transformer
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = Random.MersenneTwister(42))
balancer2 = Imbalance.MLJ.SMOTENC(k = 10, ratios = 1.2, rng = Random.MersenneTwister(42))
balancer3 = Imbalance.MLJ.ROSE(ratios = 1.3, rng = Random.MersenneTwister(42))

### 1. Make a pipeline of the three balancers and a probablistic model
## ordinary way
mach = machine(balancer1)
Xover, yover = MLJBase.transform(mach, X_train, y_train)
mach = machine(balancer2)
Xover, yover = MLJBase.transform(mach, Xover, yover)
mach = machine(balancer3)
Xover, yover = MLJBase.transform(mach, Xover, yover)

mach = machine(model_prob, Xover, yover)
fit!(mach)
y_pred = MLJBase.predict(mach, X_test)

# with MLJ balancing
@test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
BalancedModel(b1 = balancer1, b2 = balancer2, b3 = balancer3)
end
@test_throws(
MLJBalancing.ERR_UNSUPPORTED_MODEL(1),
BalancedModel(model = 1, b1 = balancer1, b2 = balancer2, b3 = balancer3),
)
@test_logs (:warn, MLJBalancing.WRN_BALANCER_UNSPECIFIED) begin
BalancedModel(model = model_prob)
end
balanced_model =
BalancedModel(model = model_prob, b1 = balancer1, b2 = balancer2, b3 = balancer3)
@test constructor(balanced_model) == BalancedModel

mach = machine(balanced_model, X_train, y_train)
fit!(mach)
y_pred2 = MLJBase.predict(mach, X_test)

@test y_pred y_pred2

# traits:
@test fit_data_scitype(balanced_model) == fit_data_scitype(model_prob)
@test is_wrapper(balanced_model)

### 2. Make a pipeline of the three balancers and a deterministic model
## ordinary way
mach = machine(balancer1)
Xover, yover = MLJBase.transform(mach, X_train, y_train)
mach = machine(balancer2)
Xover, yover = MLJBase.transform(mach, Xover, yover)
mach = machine(balancer3)
Xover, yover = MLJBase.transform(mach, Xover, yover)

mach = machine(model_det, Xover, yover)
fit!(mach)
y_pred = MLJBase.predict(mach, X_test)

# with MLJ balancing
balanced_model =
BalancedModel(model = model_det, b1 = balancer1, b2 = balancer2, b3 = balancer3)
mach = machine(balanced_model, X_train, y_train)
fit!(mach)
y_pred2 = MLJBase.predict(mach, X_test)

@test y_pred == y_pred2

### check that setpropertyname and getpropertyname work
Base.getproperty(balanced_model, :b1) == balancer1
Base.setproperty!(balanced_model, :b1, balancer2)
Base.getproperty(balanced_model, :b1) == balancer2


### 2. Make a pipeline of the three balancers and a deterministic model
## ordinary way
mach = machine(balancer1)
Xover, yover = MLJBase.transform(mach, X_train, y_train)
mach = machine(balancer2)
Xover, yover = MLJBase.transform(mach, Xover, yover)
mach = machine(balancer3)
Xover, yover = MLJBase.transform(mach, Xover, yover)

mach = machine(model_det, Xover, yover)
fit!(mach)
y_pred = MLJBase.predict(mach, X_test)

# with MLJ balancing
balanced_model =
BalancedModel(model = model_det, b1 = balancer1, b2 = balancer2, b3 = balancer3)
mach = machine(balanced_model, X_train, y_train)
fit!(mach)
y_pred2 = MLJBase.predict(mach, X_test)

@test y_pred == y_pred2

### check that setpropertyname and getpropertyname work
Base.getproperty(balanced_model, :b1) == balancer1
Base.setproperty!(balanced_model, :b1, balancer2)
Base.getproperty(balanced_model, :b1) == balancer2
@test_throws(
MLJBalancing.ERR_NO_PROP,
Base.setproperty!(balanced_model, :name11, balancer2),
Base.setproperty!(balanced_model, :name11, balancer2),
)
end

Expand All @@ -96,7 +98,7 @@ end
## setup parameters
R = Random.MersenneTwister(42)
LogisticClassifier = @load LogisticClassifier pkg = MLJLinearModels verbosity = 0
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = Random.MersenneTwister(42))
balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = Random.MersenneTwister(42))
model = LogisticClassifier()
BalancedModel(model=model, balancer1=balancer1) == BalancedModel(model; balancer1=balancer1)

Expand Down

0 comments on commit c2c40b1

Please sign in to comment.