Skip to content

Commit

Permalink
Merge pull request #67 from JuliaAI/featimp
Browse files Browse the repository at this point in the history
Add feature importances support to iterated models
  • Loading branch information
ablaom authored Sep 5, 2024
2 parents 99a5dd2 + 104705b commit a3aa159
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,11 @@ MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) =
# here `fitresult` is a trained atomic machine:
MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult)
MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult)

# Feature importances
function MLJBase.feature_importances(::EitherIteratedModel, fitresult, report)
# fitresult here is the curent state of the iterated machine
# The line below will return `nothing` when the iteration model doesn't
# support feature_importances.
return MLJBase.feature_importances(fitresult)
end
1 change: 1 addition & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ for trait in [:supports_weights,
:is_pure_julia,
:input_scitype,
:output_scitype,
:reports_feature_importances,
:target_scitype]
quote
# needed because traits are not always deducable from
Expand Down
13 changes: 11 additions & 2 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
# if I serialize/deserialized `thing` then `id` below changes:
id = objectid(thing)
fitresult = (thing, id, mean(y))
return fitresult, nothing, NamedTuple()
report = (importances = [ftr => 1.0 for ftr in MLJBase.schema(X).names], )
return fitresult, nothing, report
end

function MLJBase.predict(::EphemeralRegressor, fitresult, X)
thing, id, μ = fitresult
return id == objectid(thing) ? fill(μ, nrows(X)) :
Expand All @@ -290,7 +292,12 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
return (thing, id, μ)
end

@testset "save and restore" begin
MLJBase.reports_feature_importances(::Type{<:EphemeralRegressor}) = true
function MLJBase.feature_importances(::EphemeralRegressor, fitresult, report)
return report.importances
end

@testset "feature importances, save and restore" begin
#https://github.com/JuliaAI/MLJ.jl/issues/1099
X, y = (; x = rand(10)), fill(42.0, 3)
controls = [Step(1), NumberLimit(2)]
Expand All @@ -302,12 +309,14 @@ end
)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
@test MLJBase.feature_importances(mach) == [:x => 1.0];
io = IOBuffer()
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
close(io)
@test MLJBase.predict(mach2, (; x = rand(2))) fill(42.0, 2)

end

end
Expand Down
1 change: 1 addition & 0 deletions test/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ imodel = IteratedModel(model=model, measure=mae)
@test output_scitype(imodel) == output_scitype(model)
@test target_scitype(imodel) == target_scitype(model)
@test constructor(imodel) == IteratedModel
@test reports_feature_importances(imodel) == reports_feature_importances(model)

end

Expand Down

0 comments on commit a3aa159

Please sign in to comment.