diff --git a/src/core.jl b/src/core.jl index df5f88b..6708ab0 100644 --- a/src/core.jl +++ b/src/core.jl @@ -151,3 +151,7 @@ MLJBase.predict(::EitherIteratedModel, fitresult, Xnew) = MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) = transform(fitresult, Xnew) + +# here `fitresult` is a trained atomic machine: +MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult) +MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult) diff --git a/test/core.jl b/test/core.jl index 566b5ac..f7dccde 100644 --- a/test/core.jl +++ b/test/core.jl @@ -285,7 +285,7 @@ end seekstart(io) mach2 = machine(io) close(io) - @test_broken MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) + @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) end end