
update type checking to deal with unsup transformers with target #705

Closed — pazzo83 wants to merge 6 commits into dev

Conversation

Collaborator

@pazzo83 pazzo83 commented Dec 31, 2021

Updates for type checking when dealing with an unsupervised transformer that relies on a target (a la https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest)


codecov-commenter commented Jan 1, 2022

Codecov Report

Merging #705 (ac4c171) into dev (5d8c78c) will decrease coverage by 0.02%.
The diff coverage is 75.00%.


@@            Coverage Diff             @@
##              dev     #705      +/-   ##
==========================================
- Coverage   86.53%   86.51%   -0.03%     
==========================================
  Files          36       36              
  Lines        3401     3403       +2     
==========================================
+ Hits         2943     2944       +1     
- Misses        458      459       +1     
Impacted Files Coverage Δ
src/machines.jl 83.67% <75.00%> (-0.35%) ⬇️


Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 5d8c78c...ac4c171.

@OkonSamuel OkonSamuel (Member) left a comment

Please add a test to avoid us accidentally breaking this in the future.

@@ -186,7 +186,11 @@ function check(model::Union{Supervised, SupervisedAnnotator}, args... ; full = f
end

function check(model::Unsupervised, args...; full=false)
    check_unsupervised(model, full, args...)
    if fit_data_scitype(model) <: NTuple{2, Any}
        check_supervised(model, full, args...)
@OkonSamuel OkonSamuel (Member) commented Jan 1, 2022

I believe we should add a check here, just after line 189, to make sure length(args) > 1 and throw an informative error message. I know we already have such a check implemented in the check_supervised function, but that throws err_supervised_nargs(), which I think could confuse the user.
Another option may be to modify err_supervised_nargs() to err_nargs(model) and throw an informative error depending on the model supertype.

@pazzo83 (Collaborator, Author) replied

Good call - I wasn't sure whether to check for exactly two args or just more than one (given my check on the type being a particular NTuple). What do you think?

@ablaom ablaom (Member) commented Jan 9, 2022

I can see that there could be incorrect assumptions here about the form of the fit_data_scitype trait.

In particular, note that this could be a Union type, for models that support multiple signatures (for example, with or without weights):

julia> MLJ.fit_data_scitype(KNNRegressor)
Union{Tuple{Table{var"#s28"} where var"#s28"<:(AbstractVector{var"#s29"} where var"#s29"<:Continuous), AbstractVector{Continuous}}, Tuple{Table{var"#s28"} where var"#s28"<:(AbstractVector{var"#s29"} where var"#s29"<:Continuous), AbstractVector{Continuous}, AbstractVector{Union{Continuous, Count}}}}

So extracting the allowed number of arguments is not as simple as "matching" it to some NTuple type as I see here. Things are looking complicated and that's not ideal. So maybe we need a rethink here.

Our desire for a more general API is at odds with the current attempt to provide an exhaustively informative error message. Perhaps we should simplify our checks and, in the interests of keeping logic simple and maintainable, we:

  • throw out all the Supervised/Unsupervised/Static case distinctions
  • just check that the provided signature matches the fit_data_scitype trait (unless the latter is Unknown) and if not, throw a generic error. Something like:

ArgumentError("The number and/or types of data arguments do not match what the specified model supports. Commonly, but not exclusively, supervised models are constructed using the syntax machine(model, X, y) or machine(model, X, y, w), while most other models use machine(model, X). In general, data in machine(model, data...) must satisfy scitype(data) <: MLJ.fit_data_scitype(model), unless the right-hand side is Unknown.")

@OkonSamuel @pazzo83 Thoughts?

cc @davnn
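
For concreteness, the simplified check proposed above might be sketched as follows. This is a hypothetical sketch only: the `check` signature and the `fit_data_scitype`/`scitype` names are taken from the diff and discussion in this thread, not from an actual implementation.

```julia
# Hypothetical sketch of the simplified check: drop the
# Supervised/Unsupervised/Static case distinctions and compare the
# provided data directly against the model's declared signature.
function check(model, args...)
    S = fit_data_scitype(model)
    # A model that declares no signature (Unknown) opts out of checking.
    S == Unknown && return nothing
    # `S` may be a Union of tuple scitypes (e.g. signatures with and
    # without sample weights); a single subtype test covers them all.
    scitype(args) <: S && return nothing
    throw(ArgumentError(
        "The number and/or types of data arguments do not match what " *
        "the specified model supports. Data in machine(model, data...) " *
        "must satisfy scitype(data) <: MLJ.fit_data_scitype(model)."))
end
```

One appeal of this shape is that the Union case from the KNNRegressor example above needs no special handling: the subtype test accepts any of the signatures in the Union.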

@pazzo83 (Collaborator, Author) replied

That makes a lot of sense to me. For my particular use case of a feature transformer with a target, it doesn't squarely fit into the Supervised/Unsupervised paradigm. I think it is fine to make sure the way a model is being used matches its defined signature (via fit_data_scitype); that adds a lot more flexibility for those who want to add model types to the MLJ ecosystem that don't explicitly follow the standard definitions of supervised or unsupervised models.

If we decide to go this route, I can just update this PR or close it and start a new one!

@ablaom ablaom (Member) commented Jan 10, 2022

Great. How about you sleep on it for a night or two. If no-one else pitches in, start a new PR (there's a merge conflict anyhow).

Thanks for staying flexible on this.

A Contributor commented

  • throw out all the Supervised/Unsupervised/Static case distinctions

I think that's the way to go, preferably all over MLJ in the longer term. It would be nice if we could infer the model type based on the fit signature, something like (assuming verbosity is a kwarg in the following example):

mutable struct Model end
function fit(::Model, X; a,b,c) end
function fit(::Model, X, y; a,b,c) end
function fit(::Model, X, y, w; a,b,c) end

# static has no fit
# ...
# unsupervised fit
Base.hasmethod(fit, Tuple{Model, Union{}})
# supervised fit
Base.hasmethod(fit, Tuple{Model, Union{}, Union{}})
# supervised, weighted fit
Base.hasmethod(fit, Tuple{Model, Union{}, Union{}, Union{}})

The limitation of this approach is that we cannot define weighted static/unsupervised models, that would require us to move y and w to kwargs, differentiate them somehow by type, or define fit_data_scitype explicitly.

A Member replied

or define fit_data_scitype explicitly

Yeah, probably that.

This is probably not the place for a general design discussion, but here is a snapshot of my current thinking, which is open to change: We move away, in the base API design, from the idea that there are a few well-defined model types altogether. Rather, a model implementation explicitly declares the following principal traits, in every case:

  • the scitypes of the training arguments (fit_data_scitype)
  • the operations that are explicitly defined or overloaded (generally a subset of [predict, transform])
  • the scitypes of the outputs of these operations (currently called target_scitype, output_scitype, respectively)
  • whether predict returns :probabilistic, :deterministic or :interval (prediction_type).

Other traits (e.g., is_supervised, supports_weights) are used to declare more subtle behaviour, like the role of a particular training argument, if present ...
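
Under such a design, an implementation like the target-dependent feature selector that motivated this PR might declare its traits roughly as follows. This is a hypothetical illustration: the model name `SelectKBestTransformer` and the particular scitypes are assumptions, not part of any existing API.

```julia
# Hypothetical trait declarations for a transformer whose `fit` uses a
# target, following the trait-first design sketched above. The model
# name and the chosen scitypes are illustrative assumptions.
fit_data_scitype(::Type{<:SelectKBestTransformer}) =
    Tuple{Table(Continuous), AbstractVector{<:Finite}}  # fit takes (X, y)
output_scitype(::Type{<:SelectKBestTransformer}) = Table(Continuous)
is_supervised(::Type{<:SelectKBestTransformer}) = false  # still a transformer
```

Here the (X, y) training signature is declared directly via fit_data_scitype, while is_supervised remains false to indicate that the model transforms rather than predicts.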

@pazzo83 (Collaborator, Author) replied

Thanks for all the feedback! What I will do is close this PR and spin up a new one reflecting the discussion here.
