Skip to content

Commit

Permalink
Merge pull request #137 from JuliaAI/patch-handling-weights-in-superv…
Browse files Browse the repository at this point in the history
…ised-fit-data-scitype

Fix handling weights in supervised fit data scitype
  • Loading branch information
ablaom authored Jan 28, 2022
2 parents 2a77b27 + c2b4a66 commit d207f88
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.3.5"
version = "1.3.6"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
6 changes: 3 additions & 3 deletions src/model_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ function supervised_fit_data_scitype(M)
T = target_scitype(M)
ret = Tuple{I, T}
if supports_weights(M)
W = AbstractVector{Union{Continuous, Count}} # weight scitype
W = AbstractVector{<:Union{Continuous, Count}} # weight scitype
return Union{ret, Tuple{I, T, W}}
elseif supports_class_weights(M)
W = AbstractDict{Finite, Union{Continuous, Count}}
W = AbstractDict{Finite, <:Union{Continuous, Count}}
return Union{ret, Tuple{I, T, W}}
end
return ret
Expand All @@ -67,7 +67,7 @@ StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) = Tuple{input_scitype(M)}
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
StatTraits.fit_data_scitype(M::Type{<:Supervised}) = supervised_fit_data_scitype(M)

# In special case of `UnsupervisedAnnotator`, we allow the target
# In special case of `UnsupervisedAnnotator`, we allow the target
# as an optional argument to `fit` (that is ignored) so that the
# `machine` constructor will accept it as a valid argument, which
# then enables *evaluation* of the detector with labeled data:
Expand Down

0 comments on commit d207f88

Please sign in to comment.