Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove specific model checks for general scitype check #731

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 16 additions & 66 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,17 @@ warn_scitype(model::Supervised, X) =
"input_scitype(model) = $(input_scitype(model))."

warn_generic_scitype_mismatch(S, F) =
"The scitype of `args` in `machine(model, args...; kwargs)` "*
"does not match the scitype "*
"expected by model's `fit` method.\n"*
"The number and/or types of data arguments do not " *
"match what the specified model supports. Commonly, " *
"but non exclusively, supervised models are constructed " *
"using the syntax `machine(model, X, y)` or `machine(model, X, y, w)` " *
"while most other models with `machine(model, X)`. " *
"Here `X` are features, `y` a target, and `w` sample or class weights. " *
"In general, data in `machine(model, data...)` must satisfy " *
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 102: Put "data" in back-quotes: "`data`".

"`scitype(data) <: MLJ.fit_data_scitype(model)` unless the " *
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can drop the qualifier MLJ. from MLJ.fit_data_scitype(model) as this is now exported.

"right-hand side is `Unknown`. Here, the scitype of `args` " *
"in `machine(model, args...; kwargs)` does not match the scitype " *
"expected by model's `fit` method.\n" *
" provided: $S\n expected by fit: $F"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we replace the last sentence (from the original string, which is using different notation) with:

"In the present case:\n"*
"scitype(data) = $S\n"*
"and\n"*
"fit_data_scitype(model) = $F"


warn_scitype(model::Supervised, X, y) =
Expand All @@ -117,56 +125,6 @@ err_length_mismatch(model::Supervised) = DimensionMismatch(
check(model::Any, args...; kwargs...) =
throw(ArgumentError("Expected a `Model` instance, got $model. "))

function check_supervised(model, full, args...)
nowarns = true

nargs = length(args)
nargs > 1 || throw(err_supervised_nargs())

full || return nowarns

X, y = args[1:2]

# checks on input type:
input_scitype(model) <: Unknown ||
elscitype(X) <: input_scitype(model) || begin
@warn warn_scitype(model, X)
nowarns=false
end

# checks on target type:
target_scitype(model) <: Unknown ||
elscitype(y) <: target_scitype(model) || begin
@warn warn_scitype(model, X, y)
nowarns=false
end

# checks on dimension matching:
scitype(X) == CallableReturning{Nothing} || nrows(X()) == nrows(y()) ||
throw(err_length_mismatch(model))

return nowarns

end

function check_unsupervised(model, full, args...)
nowarns = true

nargs = length(args)
nargs <= 1 || throw(err_unsupervised_nargs())

if full && nargs == 1
X = args[1]
# check input scitype
input_scitype(model) <: Unknown ||
elscitype(X) <: input_scitype(model) || begin
@warn warn_scitype(model, X)
nowarns=false
end
end
return nowarns
end

function check(model::Model, args...; full=false)
nowarns = true

Expand All @@ -179,21 +137,13 @@ function check(model::Model, args...; full=false)
@warn warn_generic_scitype_mismatch(S, F)
nowarns = false
end
end

function check(model::Union{Supervised, SupervisedAnnotator}, args... ; full = false)
check_supervised(model, full, args...)
end

function check(model::Unsupervised, args...; full=false)
check_unsupervised(model, full, args...)
end
if length(args) > 1
X, y = args[1:2]

function check(model::UnsupervisedAnnotator, args... ; full = false)
if length(args) <= 1
check_unsupervised(model, full, args...)
else
check_supervised(model, full, args...)
# checks on dimension matching:
scitype(X) == CallableReturning{Nothing} || nrows(X()) == nrows(y()) ||
throw(err_length_mismatch(model))
end
end

Expand Down