From e901f4a33fef286caa810eed1384b70752011fd4 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 12:31:02 +1300 Subject: [PATCH] fix serialization, resolving part of MLJ.jl issue 1099 --- src/core.jl | 4 ++++ test/core.jl | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/core.jl b/src/core.jl index df5f88b..6708ab0 100644 --- a/src/core.jl +++ b/src/core.jl @@ -151,3 +151,7 @@ MLJBase.predict(::EitherIteratedModel, fitresult, Xnew) = MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) = transform(fitresult, Xnew) + +# here `fitresult` is a trained atomic machine: +MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult) +MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult) diff --git a/test/core.jl b/test/core.jl index 566b5ac..f7dccde 100644 --- a/test/core.jl +++ b/test/core.jl @@ -285,7 +285,7 @@ end seekstart(io) mach2 = machine(io) close(io) - @test_broken MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) + @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) end end