diff --git a/src/machines.jl b/src/machines.jl index c251b22b..34364f98 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -93,9 +93,17 @@ warn_scitype(model::Supervised, X) = "input_scitype(model) = $(input_scitype(model))." warn_generic_scitype_mismatch(S, F) = - "The scitype of `args` in `machine(model, args...; kwargs)` "* - "does not match the scitype "* - "expected by model's `fit` method.\n"* + "The number and/or types of data arguments do not " * + "match what the specified model supports. Commonly, " * + "but non exclusively, supervised models are constructed " * + "using the syntax `machine(model, X, y)` or `machine(model, X, y, w)` " * + "while most other models with `machine(model, X)`. " * + "Here `X` are features, `y` a target, and `w` sample or class weights. " * + "In general, data in `machine(model, data...)` must satisfy " * + "`scitype(data) <: MLJ.fit_data_scitype(model)` unless the " * + "right-hand side is `Unknown`. Here, the scitype of `args` " * + "in `machine(model, args...; kwargs)` does not match the scitype " * + "expected by model's `fit` method.\n" * " provided: $S\n expected by fit: $F" warn_scitype(model::Supervised, X, y) = @@ -117,56 +125,6 @@ err_length_mismatch(model::Supervised) = DimensionMismatch( check(model::Any, args...; kwargs...) = throw(ArgumentError("Expected a `Model` instance, got $model. ")) -function check_supervised(model, full, args...) - nowarns = true - - nargs = length(args) - nargs > 1 || throw(err_supervised_nargs()) - - full || return nowarns - - X, y = args[1:2] - - # checks on input type: - input_scitype(model) <: Unknown || - elscitype(X) <: input_scitype(model) || begin - @warn warn_scitype(model, X) - nowarns=false - end - - # checks on target type: - target_scitype(model) <: Unknown || - elscitype(y) <: target_scitype(model) || begin - @warn warn_scitype(model, X, y) - nowarns=false - end - - # checks on dimension matching: - scitype(X) == CallableReturning{Nothing} || nrows(X()) == nrows(y()) || - throw(err_length_mismatch(model)) - - return nowarns - -end - -function check_unsupervised(model, full, args...) - nowarns = true - - nargs = length(args) - nargs <= 1 || throw(err_unsupervised_nargs()) - - if full && nargs == 1 - X = args[1] - # check input scitype - input_scitype(model) <: Unknown || - elscitype(X) <: input_scitype(model) || begin - @warn warn_scitype(model, X) - nowarns=false - end - end - return nowarns -end - function check(model::Model, args...; full=false) nowarns = true @@ -179,21 +137,13 @@ function check(model::Model, args...; full=false) @warn warn_generic_scitype_mismatch(S, F) nowarns = false end -end - -function check(model::Union{Supervised, SupervisedAnnotator}, args... ; full = false) - check_supervised(model, full, args...) -end -function check(model::Unsupervised, args...; full=false) - check_unsupervised(model, full, args...) -end + if length(args) > 1 + X, y = args[1:2] -function check(model::UnsupervisedAnnotator, args... ; full = false) - if length(args) <= 1 - check_unsupervised(model, full, args...) - else - check_supervised(model, full, args...) + # checks on dimension matching: + scitype(X) == CallableReturning{Nothing} || nrows(X()) == nrows(y()) || + throw(err_length_mismatch(model)) end end