update type checking to deal with unsup transformers with target #705
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##              dev     #705      +/-   ##
==========================================
- Coverage   86.53%   86.51%   -0.03%
==========================================
  Files          36       36
  Lines        3401     3403       +2
==========================================
+ Hits         2943     2944       +1
- Misses        458      459       +1
```
Continue to review full report at Codecov.
Please add a test to avoid us accidentally breaking this in the future.
```julia
@@ -186,7 +186,11 @@ function check(model::Union{Supervised, SupervisedAnnotator}, args... ; full = f
end

function check(model::Unsupervised, args...; full=false)
    check_unsupervised(model, full, args...)
    if fit_data_scitype(model) <: NTuple{2, Any}
        check_supervised(model, full, args...)
```
I believe we should add a check here, just after line 189, to make sure `length(args) > 1`, and throw an informative error message. I know we already have such a check implemented in the `check_supervised` function, but that throws `err_supervised_nargs()`, which I think could confuse the user. Another option may be to modify `err_supervised_nargs()` to `err_nargs(model)` and throw an informative error depending on the model supertype.
Good call - I wasn't sure whether to check explicitly for two args or just more than one (given my check on the type being a particular `NTuple`) - what do you think?
I can see that there could be incorrect assumptions here about the form of the `fit_data_scitype` trait. In particular, note that this could be a `Union` type, for models that support multiple signatures (for example, with or without weights):

```julia
julia> MLJ.fit_data_scitype(KNNRegressor)
Union{Tuple{Table{var"#s28"} where var"#s28"<:(AbstractVector{var"#s29"} where var"#s29"<:Continuous), AbstractVector{Continuous}}, Tuple{Table{var"#s28"} where var"#s28"<:(AbstractVector{var"#s29"} where var"#s29"<:Continuous), AbstractVector{Continuous}, AbstractVector{Union{Continuous, Count}}}}
```

So extracting the allowed number of arguments is not as simple as "matching" it to some `NTuple` type, as I see here. Things are looking complicated, and that's not ideal. So maybe we need a rethink here.
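To illustrate the issue, here is a minimal sketch, using plain Julia `Tuple` types as stand-ins for real scitypes, of why a `Union`-valued trait defeats a subtype test against a single fixed-length tuple:

```julia
# Stand-in for a trait that allows two signatures, (X, y) or (X, y, w).
# Plain Tuple types are used here for illustration only.
const TwoSignatures = Union{Tuple{Any,Any}, Tuple{Any,Any,Any}}

# A naive check against a fixed-length tuple fails, because the Union as a
# whole is not a subtype of NTuple{2,Any}:
TwoSignatures <: NTuple{2,Any}          # false

# Each member of the Union would have to be inspected separately:
Tuple{Any,Any} <: NTuple{2,Any}         # true
Tuple{Any,Any,Any} <: NTuple{3,Any}     # true
```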
Our desire for a more general API is at odds with the current attempt to provide an exhaustively informative error message. Perhaps we should simplify our checks and, in the interests of keeping logic simple and maintainable, we:

- throw out all the `Supervised`/`Unsupervised`/`Static` case distinctions - just check that the provided signature matches the `fit_data_scitype` trait (unless the latter is `Unknown`) and, if not, throw a generic error. Something like:

```julia
ArgumentError("The number and/or types of data arguments do not match what the specified model supports. Commonly, but non-exclusively, supervised models are constructed using the syntax `machine(model, X, y)` or `machine(model, X, y, w)`, while most other models with `machine(model, X)`. In general, `data` in `machine(model, data...)` must satisfy `scitype(data) <: MLJ.fit_data_scitype(model)`, unless the right-hand side is `Unknown`.")
```
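A minimal sketch of that simplified check, with the MLJ trait machinery (`fit_data_scitype`, `scitype`, `Unknown`) mocked out by plain Julia stand-ins for illustration:

```julia
# Mock stand-ins for the real MLJ machinery, for illustration only.
abstract type Unknown end
scitype(data...) = Tuple{typeof.(data)...}   # crude stand-in for scitype

struct MockModel
    traits::Dict{Symbol,Any}
end

fit_data_scitype(model::MockModel) = get(model.traits, :fit_data_scitype, Unknown)

function check(model, data...)
    S = fit_data_scitype(model)
    S === Unknown && return nothing   # nothing to check against
    scitype(data...) <: S || throw(ArgumentError(
        "The number and/or types of data arguments do not match " *
        "what the specified model supports."))
    return nothing
end

model = MockModel(Dict{Symbol,Any}(
    :fit_data_scitype => Tuple{AbstractMatrix, AbstractVector}))
check(model, rand(3, 2), rand(3))   # passes
# check(model, rand(3, 2))          # would throw the generic ArgumentError
```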
@OkonSamuel @pazzo83 Thoughts?
cc @davnn
That makes a lot of sense to me. For my particular use case of the feature transformer with a target, it doesn't squarely fit into the Supervised/Unsupervised paradigm. I think it is fine to make sure the way a model is being used matches its defined signature (via `fit_data_scitype`) - that adds a lot more flexibility for those who want to add certain model types to the MLJ ecosystem that might not explicitly follow the standard definitions for supervised or unsupervised models.
If we decide to go this route, I can just update this PR or close it and start a new one!
Great. How about you sleep on it for a night or two. If no-one else pitches in, start a new PR (there's a merge conflict anyhow).
Thanks for staying flexible on this.
> throw out all the `Supervised`/`Unsupervised`/`Static` case distinctions

I think that's the way to go, preferably all over MLJ in the longer term. It would be nice if we could infer the model type based on the fit signature, something like (assuming `verbosity` is a kwarg in the following example):
```julia
mutable struct Model end

function fit(::Model, X; a, b, c) end
function fit(::Model, X, y; a, b, c) end
function fit(::Model, X, y, w; a, b, c) end

# static has no fit
# ...

# unsupervised fit
Base.hasmethod(fit, Tuple{Model, Union{}})

# supervised fit
Base.hasmethod(fit, Tuple{Model, Union{}, Union{}})

# supervised, weighted fit
Base.hasmethod(fit, Tuple{Model, Union{}, Union{}, Union{}})
```
The limitation of this approach is that we cannot define weighted static/unsupervised models; that would require us to move `y` and `w` to kwargs, differentiate them somehow by type, or define `fit_data_scitype` explicitly.
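For instance, a weighted unsupervised model could still declare its allowed signatures by defining the trait explicitly. This is a hypothetical sketch, with plain Julia `Tuple` types standing in for real scitypes and `WeightedDetector` an invented example type:

```julia
mutable struct WeightedDetector end

# No supervised-style fit methods are needed; the allowed signatures,
# (X,) or (X, w), are declared directly on the trait. Plain Tuple types
# stand in for real scitypes here.
fit_data_scitype(::Type{WeightedDetector}) =
    Union{Tuple{AbstractMatrix}, Tuple{AbstractMatrix, AbstractVector}}

# A machine-construction check could then test the provided data:
scitype(data...) = Tuple{typeof.(data)...}   # crude stand-in

scitype(rand(3, 2)) <: fit_data_scitype(WeightedDetector)            # true
scitype(rand(3, 2), rand(3)) <: fit_data_scitype(WeightedDetector)   # true
```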
> or define `fit_data_scitype` explicitly

Yeah, probably that.
This is probably not the place for a general design discussion, but here is a snapshot of my current thinking, which is open to change: we move away, in the base API design, from the idea that there are a few well-defined model types altogether. Rather, a model implementation explicitly declares the following principal traits, in every case:

- the scitypes of the training arguments (`fit_data_scitype`)
- the operations that are explicitly defined or overloaded (generally a subset of [`predict`, `transform`])
- the scitypes of the outputs of these operations (currently called `target_scitype` and `output_scitype`, respectively)
- whether `predict` returns `:probabilistic`, `:deterministic` or `:interval` (`prediction_type`)

Other traits (eg, `is_supervised`, `supports_weights`) are used to declare more subtle behaviour, like the role of a particular training argument, if present ...
Thanks for all the feedback! What I will do is close this PR and spin up a new one reflecting the discussion here.
Updates for type checking when dealing with an unsupervised transformer that relies on a target (a la https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest)