Skip to content

Commit

Permalink
Merge pull request #62 from JuliaAI/constructor
Browse files Browse the repository at this point in the history
Overload `constructor` trait for `IteratedModel` types
  • Loading branch information
ablaom authored Jun 3, 2024
2 parents 03cd01b + 3bbfab0 commit 99a5dd2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.6.1"
version = "0.6.2"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
Expand All @@ -11,7 +11,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
IterationControl = "0.5"
MLJBase = "1.3"
MLJBase = "1.4"
julia = "1.6"

[extras]
Expand Down
6 changes: 2 additions & 4 deletions src/traits.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
MLJBase.is_wrapper(::Type{<:EitherIteratedModel}) = true
MLJBase.caches_data_by_default(::Type{<:EitherIteratedModel}) = false
MLJBase.load_path(::Type{<:DeterministicIteratedModel}) =
"MLJIteration.DeterministicIteratedModel"
MLJBase.load_path(::Type{<:ProbabilisticIteratedModel}) =
"MLJIteration.ProbabilisticIteratedModel"
MLJBase.load_path(::Type{<:EitherIteratedModel}) = "MLJIteration.IteratedModel"
MLJBase.constructor(::Type{<:EitherIteratedModel}) = IteratedModel
MLJBase.package_name(::Type{<:EitherIteratedModel}) = "MLJIteration"
MLJBase.package_uuid(::Type{<:EitherIteratedModel}) =
"614be32b-d00c-4edb-bd02-1eb411ab5e55"
Expand Down
3 changes: 2 additions & 1 deletion test/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ imodel = IteratedModel(model=model, measure=mae)
@test !MLJBase.caches_data_by_default(imodel)
@test !supports_weights(imodel)
@test !supports_class_weights(imodel)
@test load_path(imodel) == "MLJIteration.DeterministicIteratedModel"
@test load_path(imodel) == "MLJIteration.IteratedModel"
@test package_name(imodel) == "MLJIteration"
@test package_uuid(imodel) == "614be32b-d00c-4edb-bd02-1eb411ab5e55"
@test package_url(imodel) == "https://github.com/JuliaAI/MLJIteration.jl"
Expand All @@ -22,6 +22,7 @@ imodel = IteratedModel(model=model, measure=mae)
@test input_scitype(imodel) == input_scitype(model)
@test output_scitype(imodel) == output_scitype(model)
@test target_scitype(imodel) == target_scitype(model)
@test constructor(imodel) == IteratedModel

end

Expand Down

0 comments on commit 99a5dd2

Please sign in to comment.