From 8322851c8a18e8541c2e124f5330fcdc4e847a56 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 22 Apr 2021 15:26:08 +1200 Subject: [PATCH 1/3] bump compat IterationControl = "0.4" --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index aff48e6..e64bba8 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -IterationControl = "0.3.3" +IterationControl = "0.4" MLJBase = "0.17.7, 0.18" julia = "1" From 50415360c66f6d955f469e00815b2e53aa55bceb Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 22 Apr 2021 15:26:27 +1200 Subject: [PATCH 2/3] address new API for IterationControl controls --- src/controls.jl | 9 +++++++-- test/controls.jl | 32 ++++++++++++++++---------------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/controls.jl b/src/controls.jl index e41f357..aee5a9a 100644 --- a/src/controls.jl +++ b/src/controls.jl @@ -29,7 +29,9 @@ IterationControl.@create_docs( function IterationControl.update!(c::WithIterationsDo, ic_model, - verbosity, state...) + verbosity, + n, + state...) i = ic_model.n_iterations r = c.f(i) done = (c.stop_if_true && r isa Bool && r) ? true : false @@ -83,7 +85,9 @@ IterationControl.@create_docs( function IterationControl.update!(c::WithEvaluationDo, ic_model, - verbosity, state...) + verbosity, + n, + state...) e = ic_model.evaluation r = c.f(e) done = (c.stop_if_true && r isa Bool && r) ? true : false @@ -156,6 +160,7 @@ end function IterationControl.update!(control::CycleLearningRate, wrapper, verbosity, + ncycles, state = (n = 0,)) n = state.n rates = n == 0 ? one_cycle(control) : state.learning_rates diff --git a/test/controls.jl b/test/controls.jl index 765b4b3..2bc3852 100644 --- a/test/controls.jl +++ b/test/controls.jl @@ -15,11 +15,11 @@ X, y = make_dummy(N=8); c = WithIterationsDo(f) m = MLJIteration.ICModel(machine(DummyIterativeModel(n=0), X, y), :n, 0) IC.train!(m, 2) - state = IC.update!(c, m, 0) + state = IC.update!(c, m, 0, 1) @test !state.done @test v == [2, ] IC.train!(m, 2) - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 2, state) @test !state.done @test v == [2, 4] @test IC.takedown(c, 0, state) == (done = false, log="") @@ -29,11 +29,11 @@ X, y = make_dummy(N=8); c = WithIterationsDo(f2, stop_if_true=true) m = MLJIteration.ICModel(machine(DummyIterativeModel(n=0), X, y), :n, 0) IC.train!(m, 2) - state = IC.update!(c, m, 0) + state = IC.update!(c, m, 0, 1) @test !state.done @test v == [2, ] IC.train!(m, 2) - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 2, state) @test state.done @test v == [2, 4] @test IC.takedown(c, 0, state) == @@ -45,11 +45,11 @@ X, y = make_dummy(N=8); c = WithIterationsDo(f3, stop_if_true=true, stop_message="foo") m = MLJIteration.ICModel(machine(DummyIterativeModel(n=0), X, y), :n, 0) IC.train!(m, 2) - state = IC.update!(c, m, 0) + state = IC.update!(c, m, 0, 1) @test !state.done @test v == [2, ] IC.train!(m, 2) - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 2, state) @test state.done @test v == [2, 4] @test IC.takedown(c, 0, state) == @@ -68,11 +68,11 @@ resampler = MLJBase.Resampler(model=DummyIterativeModel(n=0), resampling_machine = machine(deepcopy(resampler), X, y) m = MLJIteration.ICModel(resampling_machine, :n, 0) IC.train!(m, 2) - state = IC.update!(c, m, 0) + state = IC.update!(c, m, 0, 1) @test !state.done @test length(v) == 1 IC.train!(m, 2) - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 2, state) @test !state.done @test length(v) == 2 @test IC.takedown(c, 0, state) == (done = false, log="") @@ -83,11 +83,11 @@ resampler = MLJBase.Resampler(model=DummyIterativeModel(n=0), resampling_machine = machine(deepcopy(resampler), X, y) m = MLJIteration.ICModel(resampling_machine, :n, 0) IC.train!(m, 2) - state = IC.update!(c, m, 0) + state = IC.update!(c, m, 0, 1) @test !state.done @test length(v) == 1 IC.train!(m, 2) - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 2, state) @test state.done @test length(v) == 2 @test IC.takedown(c, 0, state) == @@ -100,11 +100,11 @@ resampler = MLJBase.Resampler(model=DummyIterativeModel(n=0), resampling_machine = machine(deepcopy(resampler), X, y) m = MLJIteration.ICModel(resampling_machine, :n, 0) IC.train!(m, 2) - state = IC.update!(c, m, 0) + state = IC.update!(c, m, 0, 1) @test !state.done @test length(v) == 1 IC.train!(m, 2) - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 2, state) @test state.done @test length(v) == 2 @test IC.takedown(c, 0, state) == @@ -125,13 +125,13 @@ end upper = 1.5) model = DummyIterativeModel(n=0, learning_rate=42) m = MLJIteration.ICModel(machine(model, X, y), :n, 0) - state = @test_logs (:info, r"learning rate") IC.update!(c, m, 2) - state = @test_logs IC.update!(c, m, 1) + state = @test_logs (:info, r"learning rate") IC.update!(c, m, 2, 1) + state = @test_logs IC.update!(c, m, 1, 1) @test state == (n = 1, learning_rates = [0.5, 1.5]) @test model.learning_rate == 0.5 - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 2, state) @test model.learning_rate == 1.5 - state = IC.update!(c, m, 0, state) + state = IC.update!(c, m, 0, 3, state) @test model.learning_rate == 0.5 end From d062effb34aff79f56c2c3944b42ac40a44c4044 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 23 Apr 2021 12:41:11 +1200 Subject: [PATCH 3/3] bump 0.3.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e64bba8..21e728d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.2.3" +version = "0.3.0" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"