Skip to content

Commit

Permalink
Merge pull request #59 from JuliaAI/save-restore
Browse files Browse the repository at this point in the history
Overload `save` and  `restore` to fix a serialization issue
  • Loading branch information
ablaom authored Mar 7, 2024
2 parents b12c0bc + 4db6d59 commit 39f58e4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.6.0"
version = "0.6.1"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
Expand Down
4 changes: 4 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,7 @@ MLJBase.predict(::EitherIteratedModel, fitresult, Xnew) =

MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) =
transform(fitresult, Xnew)

# here `fitresult` is a trained atomic machine:
MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult)
MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult)
49 changes: 49 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,54 @@ end
@test iteration_parameter(model) == :n
end

# define a supervised model with ephemeral `fitresult`, but which overcomes this by
# overloading `save`/`restore`:
thing = []
mutable struct EphemeralRegressor <: Deterministic
n::Int # dummy iteration parameter
end
EphemeralRegressor(; n=1) = EphemeralRegressor(n)
function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
# if I serialize/deserialized `thing` then `id` below changes:
id = objectid(thing)
fitresult = (thing, id, mean(y))
return fitresult, nothing, NamedTuple()
end
function MLJBase.predict(::EphemeralRegressor, fitresult, X)
thing, id, μ = fitresult
return id == objectid(thing) ? fill(μ, nrows(X)) :
throw(ErrorException("dead fitresult"))
end
MLJBase.iteration_parameter(::EphemeralRegressor) = :n
function MLJBase.save(::EphemeralRegressor, fitresult)
thing, _, μ = fitresult
return (thing, μ)
end
function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
thing, μ = serialized_fitresult
id = objectid(thing)
return (thing, id, μ)
end

@testset "save and restore" begin
#https://github.com/alan-turing-institute/MLJ.jl/issues/1099
X, y = (; x = rand(10)), fill(42.0, 3)
controls = [Step(1), NumberLimit(2)]
imodel = IteratedModel(
EphemeralRegressor(42);
measure=l2,
resampling=Holdout(),
controls,
)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
io = IOBuffer()
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
close(io)
@test MLJBase.predict(mach2, (; x = rand(2))) fill(42.0, 2)
end

end
true

0 comments on commit 39f58e4

Please sign in to comment.