Skip to content

Commit

Permalink
Merge pull request #21 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.3.0 release
  • Loading branch information
ablaom authored Apr 23, 2021
2 parents c488ecd + af9c484 commit 5943532
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.2.3"
version = "0.3.0"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
IterationControl = "0.3.3"
IterationControl = "0.4"
MLJBase = "0.17.7, 0.18"
julia = "1"

Expand Down
9 changes: 7 additions & 2 deletions src/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ IterationControl.@create_docs(

function IterationControl.update!(c::WithIterationsDo,
ic_model,
verbosity, state...)
verbosity,
n,
state...)
i = ic_model.n_iterations
r = c.f(i)
done = (c.stop_if_true && r isa Bool && r) ? true : false
Expand Down Expand Up @@ -83,7 +85,9 @@ IterationControl.@create_docs(

function IterationControl.update!(c::WithEvaluationDo,
ic_model,
verbosity, state...)
verbosity,
n,
state...)
e = ic_model.evaluation
r = c.f(e)
done = (c.stop_if_true && r isa Bool && r) ? true : false
Expand Down Expand Up @@ -156,6 +160,7 @@ end
function IterationControl.update!(control::CycleLearningRate,
wrapper,
verbosity,
ncycles,
state = (n = 0,))
n = state.n
rates = n == 0 ? one_cycle(control) : state.learning_rates
Expand Down
32 changes: 16 additions & 16 deletions test/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ X, y = make_dummy(N=8);
c = WithIterationsDo(f)
m = MLJIteration.ICModel(machine(DummyIterativeModel(n=0), X, y), :n, 0)
IC.train!(m, 2)
state = IC.update!(c, m, 0)
state = IC.update!(c, m, 0, 1)
@test !state.done
@test v == [2, ]
IC.train!(m, 2)
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 2, state)
@test !state.done
@test v == [2, 4]
@test IC.takedown(c, 0, state) == (done = false, log="")
Expand All @@ -29,11 +29,11 @@ X, y = make_dummy(N=8);
c = WithIterationsDo(f2, stop_if_true=true)
m = MLJIteration.ICModel(machine(DummyIterativeModel(n=0), X, y), :n, 0)
IC.train!(m, 2)
state = IC.update!(c, m, 0)
state = IC.update!(c, m, 0, 1)
@test !state.done
@test v == [2, ]
IC.train!(m, 2)
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 2, state)
@test state.done
@test v == [2, 4]
@test IC.takedown(c, 0, state) ==
Expand All @@ -45,11 +45,11 @@ X, y = make_dummy(N=8);
c = WithIterationsDo(f3, stop_if_true=true, stop_message="foo")
m = MLJIteration.ICModel(machine(DummyIterativeModel(n=0), X, y), :n, 0)
IC.train!(m, 2)
state = IC.update!(c, m, 0)
state = IC.update!(c, m, 0, 1)
@test !state.done
@test v == [2, ]
IC.train!(m, 2)
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 2, state)
@test state.done
@test v == [2, 4]
@test IC.takedown(c, 0, state) ==
Expand All @@ -68,11 +68,11 @@ resampler = MLJBase.Resampler(model=DummyIterativeModel(n=0),
resampling_machine = machine(deepcopy(resampler), X, y)
m = MLJIteration.ICModel(resampling_machine, :n, 0)
IC.train!(m, 2)
state = IC.update!(c, m, 0)
state = IC.update!(c, m, 0, 1)
@test !state.done
@test length(v) == 1
IC.train!(m, 2)
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 2, state)
@test !state.done
@test length(v) == 2
@test IC.takedown(c, 0, state) == (done = false, log="")
Expand All @@ -83,11 +83,11 @@ resampler = MLJBase.Resampler(model=DummyIterativeModel(n=0),
resampling_machine = machine(deepcopy(resampler), X, y)
m = MLJIteration.ICModel(resampling_machine, :n, 0)
IC.train!(m, 2)
state = IC.update!(c, m, 0)
state = IC.update!(c, m, 0, 1)
@test !state.done
@test length(v) == 1
IC.train!(m, 2)
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 2, state)
@test state.done
@test length(v) == 2
@test IC.takedown(c, 0, state) ==
Expand All @@ -100,11 +100,11 @@ resampler = MLJBase.Resampler(model=DummyIterativeModel(n=0),
resampling_machine = machine(deepcopy(resampler), X, y)
m = MLJIteration.ICModel(resampling_machine, :n, 0)
IC.train!(m, 2)
state = IC.update!(c, m, 0)
state = IC.update!(c, m, 0, 1)
@test !state.done
@test length(v) == 1
IC.train!(m, 2)
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 2, state)
@test state.done
@test length(v) == 2
@test IC.takedown(c, 0, state) ==
Expand All @@ -125,13 +125,13 @@ end
upper = 1.5)
model = DummyIterativeModel(n=0, learning_rate=42)
m = MLJIteration.ICModel(machine(model, X, y), :n, 0)
state = @test_logs (:info, r"learning rate") IC.update!(c, m, 2)
state = @test_logs IC.update!(c, m, 1)
state = @test_logs (:info, r"learning rate") IC.update!(c, m, 2, 1)
state = @test_logs IC.update!(c, m, 1, 1)
@test state == (n = 1, learning_rates = [0.5, 1.5])
@test model.learning_rate == 0.5
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 2, state)
@test model.learning_rate == 1.5
state = IC.update!(c, m, 0, state)
state = IC.update!(c, m, 0, 3, state)
@test model.learning_rate == 0.5
end

Expand Down

0 comments on commit 5943532

Please sign in to comment.