Skip to content

Commit

Permalink
Merge pull request #18 from JuliaAI/training-losses-bug-fix
Browse files Browse the repository at this point in the history
Training losses bug fix
  • Loading branch information
ablaom authored Apr 22, 2021
2 parents d2e2bd1 + 28801bb commit d892b94
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.2.2"
version = "0.2.3"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
Expand Down
8 changes: 2 additions & 6 deletions src/ic_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@
# (cf. MLJBase/src/resampling.jl). It is for the wrapped type that we
# will be overloading the methods of `IterationControl`.

const ERR_TRAINING_LOSSES =
ArgumentError("Attempt to inspect training losses for "*
"a model that doesn't report them. ")
const ERR_EVALUATION =
ArgumentError("There are no evaluation objects if `resampling=nothing`. ")


mlj_model(mach::Machine) = mach.model
mlj_model(mach::Machine{<:Resampler}) = mach.model.model

Expand Down Expand Up @@ -63,7 +59,7 @@ IterationControl.expose(ic_model::ICModel{<:Machine{<:Resampler}}) =
# overloading `loss` - for `resampling === nothing`:
function IterationControl.loss(m::ICModel)
losses = training_losses(IterationControl.expose(m))
losses === nothing && throw(ERR_TRAINING_LOSSES)
losses isa Nothing && return nothing
return last(losses)
end

Expand All @@ -75,7 +71,7 @@ IterationControl.loss(m::ICModel{<:Machine{<:Resampler}}) =
function IterationControl.training_losses(m::ICModel)
mach = IterationControl.expose(m)
losses = training_losses(mach)
losses === nothing && throw(ERR_TRAINING_LOSSES)
losses isa Nothing && return nothing
s = length(losses)
return view(losses, (s - m.Δi[] + 1):s)
end
Expand Down
9 changes: 3 additions & 6 deletions test/ic_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,13 @@ MLJBase.predict(::FooBar, ::Any, Xnew) = ones(length(Xnew))
y)
fit!(resampling_machine, verbosity=0)
ic_model = MLJIteration.ICModel(resampling_machine, :n, 0)
@test_throws(MLJIteration.ERR_TRAINING_LOSSES,
IterationControl.training_losses(ic_model))
@test IterationControl.training_losses(ic_model) === nothing

mach = machine(FooBar(10), X, y)
fit!(mach, verbosity=0)
ic_model = MLJIteration.ICModel(mach, :n, 0)
@test_throws(MLJIteration.ERR_TRAINING_LOSSES,
IterationControl.training_losses(ic_model))
@test_throws(MLJIteration.ERR_TRAINING_LOSSES,
IterationControl.loss(ic_model))
@test IterationControl.training_losses(ic_model) === nothing
@test IterationControl.loss(ic_model) === nothing
end

@testset "ICModel interface for users" begin
Expand Down

0 comments on commit d892b94

Please sign in to comment.