Skip to content

Commit

Permalink
Merge pull request #43 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.4.2 release
  • Loading branch information
ablaom authored Jan 27, 2022
2 parents ab5548e + 8915445 commit 1e738e6
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.6'
- '1'
os:
- ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.4.1"
version = "0.4.2"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
Expand All @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[compat]
IterationControl = "0.5"
MLJBase = "0.18.8, 0.19"
julia = "1"
julia = "1.6"

[extras]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Expand Down
63 changes: 46 additions & 17 deletions src/constructors.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
const ERR_MISSING_TRAINING_CONTROL =
ArgumentError("At least one control must be a training control "*
"(have type `$TrainingControl`) or be a "*
"custom control that calls IterationControl.train!. ")

const IterationResamplingTypes =
Union{Holdout,Nothing,MLJBase.TrainTestPairs}

Expand Down Expand Up @@ -37,11 +32,15 @@ mutable struct ProbabilisticIteratedModel{M<:Probabilistic} <: MLJBase.Probabili
cache::Bool
end

const ERR_MISSING_TRAINING_CONTROL =
ArgumentError("At least one control must be a training control "*
"(have type `$TrainingControl`) or be a "*
"custom control that calls IterationControl.train!. ")

const ERR_TOO_MANY_ARGUMENTS =
ArgumentError("At most one non-keyword argument allowed. ")
const EitherIteratedModel{M} =
Union{DeterministicIteratedModel{M},ProbabilisticIteratedModel{M}}

const ERR_NO_MODEL =
ArgumentError("You need to specify model=... ")
const ERR_NOT_SUPERVISED =
ArgumentError("Only `Deterministic` and `Probabilistic` "*
"model types supported.")
Expand All @@ -53,6 +52,13 @@ const ERR_NEED_PARAMETER =
"parameter. Please specify `iteration_parameter=...`. This "*
"must be a `Symbol` or, in the case of a nested parameter, "*
"an `Expr` (as in `booster.nrounds`). ")
const ERR_MODEL_UNSPECIFIED = ArgumentError(
"Expecting atomic model as argument, or as keyword argument `model=...`, "*
"but neither detected. ")


err_bad_iteration_parameter(p) =
ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ")

"""
IteratedModel(model=nothing,
Expand Down Expand Up @@ -169,7 +175,8 @@ updated to the last value used in the preceding `fit!(mach)` call. Then
repeated application of the (updated) controls begin anew.
"""
function IteratedModel(; model=nothing,
function IteratedModel(args...;
model=nothing,
control=CONTROLS_DEFAULT,
controls=control,
resampling=MLJBase.Holdout(),
Expand All @@ -183,10 +190,18 @@ function IteratedModel(; model=nothing,
iteration_parameter=nothing,
cache=true)

model == nothing && throw(ERR_NO_MODEL)
length(args) < 2 || throw(ArgumentError("At most one non-keyword argument allowed. "))
if length(args) === 1
atom = first(args)
model === nothing ||
@warn "Using `model=$atom`. Ignoring specification `model=$model`. "
else
model === nothing && throw(ERR_MODEL_UNSPECIFIED)
atom = model
end

if model isa Deterministic
iterated_model = DeterministicIteratedModel(model,
if atom isa Deterministic
iterated_model = DeterministicIteratedModel(atom,
controls,
resampling,
measure,
Expand All @@ -197,8 +212,8 @@ function IteratedModel(; model=nothing,
check_measure,
iteration_parameter,
cache)
elseif model isa Probabilistic
iterated_model = ProbabilisticIteratedModel(model,
elseif atom isa Probabilistic
iterated_model = ProbabilisticIteratedModel(atom,
controls,
resampling,
measure,
Expand All @@ -220,6 +235,8 @@ function IteratedModel(; model=nothing,

end



function MLJBase.clean!(iterated_model::EitherIteratedModel)
message = ""
if iterated_model.measure === nothing &&
Expand All @@ -232,9 +249,21 @@ function MLJBase.clean!(iterated_model::EitherIteratedModel)
"Setting measure=$(iterated_model.measure). "
end
end
iterated_model.iteration_parameter === nothing &&
iteration_parameter(iterated_model.model) === nothing &&
throw(ERR_NEED_PARAMETER)
if iterated_model.iteration_parameter === nothing
iterated_model.iteration_parameter = iteration_parameter(iterated_model.model)
if iterated_model.iteration_parameter === nothing
throw(ERR_NEED_PARAMETER)
else
message *= "No iteration parameter specified. "*
"Setting iteration_parameter=:($(iterated_model.iteration_parameter)). "
end
end
try
MLJBase.recursive_getproperty(iterated_model.model,
iterated_model.iteration_parameter)
catch
throw(err_bad_iteration_parameter(iterated_model.iteration_parameter))
end

resampling = iterated_model.resampling

Expand Down
6 changes: 1 addition & 5 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,7 @@ end
function MLJBase.fit(iterated_model::EitherIteratedModel, verbosity, data...)

model = deepcopy(iterated_model.model)

# get name of iteration parameter:
_iter = MLJBase.iteration_parameter(model)
iteration_param = _iter === nothing ?
iterated_model.iteration_parameter : _iter
iteration_param = iterated_model.iteration_parameter

# instantiate `train_mach`:
mach = if iterated_model.resampling === nothing
Expand Down
2 changes: 1 addition & 1 deletion src/ic_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end
# overloading `expose`- for `resampling === nothing`:
IterationControl.expose(ic_model::ICModel) = ic_model.mach

# overloading `expose`- for `resampling isa Holdout` or
# overloading `expose`- for `resampling isa Holdout` or
# other resampling strategy:
IterationControl.expose(ic_model::ICModel{<:Machine{<:Resampler}}) =
MLJBase.fitted_params(ic_model.mach).machine
Expand Down
25 changes: 19 additions & 6 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,28 @@ using Test

struct Foo <: MLJBase.Unsupervised end
struct Bar <: MLJBase.Deterministic end
struct FooBar <: MLJBase.Deterministic end

@testset "constructors" begin
model = DummyIterativeModel()
@test_throws MLJIteration.ERR_NO_MODEL IteratedModel()
@test_throws MLJIteration.ERR_TOO_MANY_ARGUMENTS IteratedModel(1, 2)
@test_throws MLJIteration.ERR_MODEL_UNSPECIFIED IteratedModel()
@test_throws MLJIteration.ERR_NOT_SUPERVISED IteratedModel(model=Foo())
@test_throws MLJIteration.ERR_NOT_SUPERVISED IteratedModel(model=Int)
@test_throws MLJIteration.ERR_NEED_MEASURE IteratedModel(model=Bar())
@test_throws MLJIteration.ERR_NEED_PARAMETER IteratedModel(model=Bar(),
measure=rms)
iterated_model = @test_logs((:info, r"No measure"),
measure=rms)
iterated_model = @test_logs((:info, "No measure specified. Setting "*
"measure=RootMeanSquaredError(). No "*
"iteration parameter specified. "*
"Setting iteration_parameter=:(n). "),
IteratedModel(model=model))
@test iterated_model.measure == RootMeanSquaredError()
@test_logs IteratedModel(model=model, measure=mae)
@test iterated_model.iteration_parameter == :n
@test_logs IteratedModel(model=model, measure=mae, iteration_parameter=:n)
@test_logs IteratedModel(model, measure=mae, iteration_parameter=:n)

@test_logs IteratedModel(model=model, resampling=nothing)
@test_logs IteratedModel(model=model, resampling=nothing, iteration_parameter=:n)

@test_logs((:info, r"`resampling` must be"),
IteratedModel(model=model,
Expand All @@ -34,12 +41,18 @@ struct Bar <: MLJBase.Deterministic end
measure=rms))
@test_logs IteratedModel(model=model,
resampling=[([1, 2], [3, 4]),],
measure=rms)
measure=rms,
iteration_parameter=:n)

@test_throws(MLJIteration.ERR_MISSING_TRAINING_CONTROL,
IteratedModel(model=model,
resampling=nothing,
controls=[Patience(), InvalidValue()]))

@test_throws(MLJIteration.err_bad_iteration_parameter(:goo),
IteratedModel(model=model,
measure=mae,
iteration_parameter=:goo))
end

end
Expand Down

0 comments on commit 1e738e6

Please sign in to comment.