diff --git a/src/core.jl b/src/core.jl index 6708ab0..24ae536 100644 --- a/src/core.jl +++ b/src/core.jl @@ -155,3 +155,11 @@ MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) = # here `fitresult` is a trained atomic machine: MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult) MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult) + +# Feature importances +function MLJBase.feature_importances(::EitherIteratedModel, fitresult, report) + # fitresult here is the curent state of the iterated machine + # The line below will return `nothing` when the iteration model doesn't + # support feature_importances. + return MLJBase.feature_importances(fitresult) +end \ No newline at end of file diff --git a/src/traits.jl b/src/traits.jl index 6380d41..216a88a 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -15,6 +15,7 @@ for trait in [:supports_weights, :is_pure_julia, :input_scitype, :output_scitype, + :reports_feature_importances, :target_scitype] quote # needed because traits are not always deducable from diff --git a/test/core.jl b/test/core.jl index 1f54945..00311b4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -272,8 +272,10 @@ function MLJBase.fit(::EphemeralRegressor, verbosity, X, y) # if I serialize/deserialized `thing` then `id` below changes: id = objectid(thing) fitresult = (thing, id, mean(y)) - return fitresult, nothing, NamedTuple() + report = (importances = [ftr => 1.0 for ftr in MLJBase.schema(X).names], ) + return fitresult, nothing, report end + function MLJBase.predict(::EphemeralRegressor, fitresult, X) thing, id, μ = fitresult return id == objectid(thing) ? fill(μ, nrows(X)) : @@ -290,7 +292,12 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) return (thing, id, μ) end -@testset "save and restore" begin +MLJBase.reports_feature_importances(::Type{<:EphemeralRegressor}) = true +function MLJBase.feature_importances(::EphemeralRegressor, fitresult, report) + return report.importances +end + +@testset "feature importances, save and restore" begin #https://github.com/JuliaAI/MLJ.jl/issues/1099 X, y = (; x = rand(10)), fill(42.0, 3) controls = [Step(1), NumberLimit(2)] @@ -302,12 +309,14 @@ end ) mach = machine(imodel, X, y) fit!(mach, verbosity=0) + @test MLJBase.feature_importances(mach) == [:x => 1.0]; io = IOBuffer() MLJBase.save(io, mach) seekstart(io) mach2 = machine(io) close(io) @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) + end end diff --git a/test/traits.jl b/test/traits.jl index 341628b..b132f98 100644 --- a/test/traits.jl +++ b/test/traits.jl @@ -23,6 +23,7 @@ imodel = IteratedModel(model=model, measure=mae) @test output_scitype(imodel) == output_scitype(model) @test target_scitype(imodel) == target_scitype(model) @test constructor(imodel) == IteratedModel +@test reports_feature_importances(imodel) == reports_feature_importances(model) end