Skip to content

Commit

Permalink
Merge pull request #15 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.2.2 release
  • Loading branch information
ablaom authored Apr 8, 2021
2 parents 8631981 + fdfe798 commit 452f63a
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 8 deletions.
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.2.1"
version = "0.2.2"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
Expand All @@ -14,11 +14,10 @@ MLJBase = "0.17.7, 0.18"
julia = "1"

[extras]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CategoricalArrays", "MLJModelInterface", "StableRNGs", "Statistics", "Test"]
test = ["MLJModelInterface", "StableRNGs", "Statistics", "Test"]
9 changes: 8 additions & 1 deletion src/MLJIteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ const CONTROLS = vcat(IterationControl.CONTROLS,
:WithEvaluationDo,
:CycleLearningRate])

const TRAINING_CONTROLS = [:Step, ]

# export all control types:
for control in CONTROLS
eval(:(export $control))
Expand All @@ -24,11 +26,16 @@ const CONTROLS_DEFAULT = [Step(10),
NotANumber()]

include("utilities.jl")
include("controls.jl")

const Control = Union{[@eval($c) for c in CONTROLS]...}
const TrainingControl = Union{[@eval($c) for c in TRAINING_CONTROLS]...}

include("constructors.jl")
include("traits.jl")
include("ic_model.jl")
include("controls.jl")
include("core.jl")



end # module
23 changes: 21 additions & 2 deletions src/constructors.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
const ERR_MISSING_TRAINING_CONTROL =
ArgumentError("At least one control must be a training control "*
"(have type `$TrainingControl`) or be a "*
"custom control that calls IterationControl.train!. ")


## TYPES AND CONSTRUCTOR

mutable struct DeterministicIteratedModel{M<:Deterministic} <: MLJBase.Deterministic
Expand Down Expand Up @@ -176,8 +182,6 @@ function IteratedModel(; model=nothing,

model == nothing && throw(ERR_NO_MODEL)



if model isa Deterministic
iterated_model = DeterministicIteratedModel(model,
controls,
Expand Down Expand Up @@ -229,5 +233,20 @@ function MLJBase.clean!(iterated_model::EitherIteratedModel)
iteration_parameter(iterated_model.model) === nothing &&
throw(ERR_NEED_PARAMETER)

if iterated_model.resampling isa Holdout &&
iterated_model.resampling.shuffle
message *= "The use of sample-shuffling in `Holdout` "*
"will significantly slow training as "*
"each increment of the iteration parameter "*
"will force iteration from scratch (cold restart). "
end

training_control_candidates = filter(iterated_model.controls) do c
c isa TrainingControl || !(c isa Control)
end
if isempty(training_control_candidates)
throw(ERR_MISSING_TRAINING_CONTROL)
end

return message
end
2 changes: 1 addition & 1 deletion test/_dummy_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export DummyIterativeModel, make_dummy
using Random
using Statistics
import StableRNGs.LehmerRNG
using CategoricalArrays
using MLJBase.CategoricalArrays
import Base.==

using MLJModelInterface
Expand Down
12 changes: 11 additions & 1 deletion test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@ struct Bar <: MLJBase.Deterministic end
@test iterated_model.measure == RootMeanSquaredError()
@test_logs IteratedModel(model=model, measure=mae)

iterated_model = @test_logs IteratedModel(model=model, resampling=nothing)
@test_logs IteratedModel(model=model, resampling=nothing)

@test_logs((:info, r"The use of sample"),
IteratedModel(model=model,
resampling=Holdout(rng=123),
measure=rms))

@test_throws(MLJIteration.ERR_MISSING_TRAINING_CONTROL,
IteratedModel(model=model,
resampling=nothing,
controls=[Patience(), NotANumber()]))
end

end
Expand Down

0 comments on commit 452f63a

Please sign in to comment.