FEAT add scikit-learn wrappers #20599
Conversation
Codecov Report

Coverage Diff:
## master #20599 +/- ##
==========================================
- Coverage 82.55% 82.52% -0.03%
==========================================
Files 518 525 +7
Lines 48682 48948 +266
Branches 7592 7615 +23
==========================================
+ Hits 40188 40393 +205
- Misses 6669 6719 +50
- Partials 1825 1836 +11
This is amazing! You've made some great improvements over the version in SciKeras.
I agree with the approach you're proposing here: rather than just copying over SciKeras or trying to reproduce all of its features out of the gate, build it up as a better version over time.
Thanks for the PR!
keras/api/wrappers/__init__.py (outdated)

from keras.src.wrappers._sklearn import KerasClassifier
from keras.src.wrappers._sklearn import KerasRegressor
from keras.src.wrappers._sklearn import KerasTransformer
I have reservations about "KerasTransformer" due to the very specific meaning that Transformer has in deep learning. This is going to confuse the hell out of people who search "Transformer" on keras.io. This was also not part of the original implementation of sklearn wrappers we had in Keras.
Fair enough.
I've included this here since, in the meantime, using pre-trained models as a step in a pipeline to get embeddings or any other kind of transformation has become much more popular than it was a couple of years ago.
But I do understand the name is quite confusing. What do you think of naming these estimators `SKLearn{Regressor, Classifier, Transformer}`? That should make it clear that the scope of these estimators is scikit-learn. We can also add a clear note in the docstring that these have nothing to do with "transformers".
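To make the intended use case concrete, here is a rough sketch of the kind of pipeline this third wrapper targets, using the `SKLearn*` names proposed above (an assumption at this point, not a final API): a Keras model acts as an embedding/feature step ahead of a scikit-learn estimator, which has nothing to do with Transformer architectures.

```python
import keras
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline


def get_embedding_model(X, y):
    # Hypothetical model-building function: the model's output is used as an
    # embedding for the downstream estimator, not as a final prediction.
    model = keras.Sequential(
        [
            keras.layers.Input(shape=(X.shape[1],)),
            keras.layers.Dense(32, activation="relu"),
            keras.layers.Dense(8),
        ]
    )
    model.compile(loss="mse", optimizer="adam")
    return model


# The wrapper exposes `transform`, so its embeddings can feed any
# scikit-learn estimator further down the pipeline.
pipe = make_pipeline(
    keras.wrappers.SKLearnTransformer(model=get_embedding_model),
    LogisticRegression(),
)
```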
keras/src/wrappers/_sklearn.py (outdated)

        deterministic state using this seed. Pass an int for reproducible
        results across multiple function calls.

    Attributes:
You can remove this
Sure, but how should I document public attributes which users can use to inspect the object after calling `fit`?
Why not plain `model` btw?
The scikit-learn convention / API is that the arguments given to an estimator are never changed, and `est.fit(X, y)` is (ignoring randomness) the same as `est.fit(X, y); est.fit(X, y)`. Object attributes with no trailing underscore are the ones given to `__init__`, and the ones set during `fit` have a trailing underscore.
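For readers unfamiliar with that convention, a minimal, self-contained illustration (a toy estimator, not the wrapper's actual code):

```python
from sklearn.base import BaseEstimator


class ToyEstimator(BaseEstimator):
    def __init__(self, alpha=1.0):
        self.alpha = alpha  # __init__ argument: stored as-is, never modified

    def fit(self, X, y):
        # State learned from data gets a trailing underscore.
        self.coef_ = [self.alpha] * len(X[0])
        return self


est = ToyEstimator(alpha=0.5)
est.fit([[1, 2], [3, 4]], [0, 1])
print(est.alpha)  # 0.5, unchanged by fit
print(est.coef_)  # [0.5, 0.5], set during fit
```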
keras/src/wrappers/_sklearn.py (outdated)

        arguments. Other arguments must be accepted if passed as
        `model_args` by the user.

    warm_start: bool, default=False
Please use the same arg description format as the rest of the codebase
Hope this is okay now:
warm_start: bool, defaults to False.
Whether to reuse the model weights from the previous fit. If `True`,
the given model won't be cloned and the weights from the previous
fit will be reused.
keras/src/wrappers/utils.py (outdated)

    For use in pipelines with transformers that only accept
    2D inputs, like OneHotEncoder and OrdinalEncoder.

    Attributes
Please use the same docstring format as in the rest of the codebase.
Thank you @adrinjalali. I use these wrappers all the time for bridging keras and scikit-learn. From my point of view they're an essential feature for keras.
Thanks for the review. A couple of questions here.
Yes, let's do that. This is much better.
Thanks for the update!
keras/src/wrappers/sklearn_test.py (outdated)

EXPECTED_FAILED_CHECKS = {
    "SKLearnClassifier": {
        "check_classifiers_regression_target": ("not an issue in sklearn>=1.6"),
Why are strings wrapped in tuples?
It's not a tuple; the parentheses are there because the string used to be longer and needed to wrap onto a new line 😅
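For clarity, the distinction being discussed (plain parentheses only group an expression; a trailing comma is what makes a tuple):

```python
reason = ("not an issue in sklearn>=1.6")     # a str: parentheses just group
as_tuple = ("not an issue in sklearn>=1.6",)  # a tuple: note the trailing comma
print(type(reason), type(as_tuple))           # <class 'str'> <class 'tuple'>
```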
keras/src/wrappers/random_state.py (outdated)

import numpy as np

try:
    import tensorflow as tf
We should not unconditionally try to import tensorflow since that could cause issues with other backends
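One way to avoid the unconditional import (a sketch, not the code in this PR) is to only touch TensorFlow when it is the active backend:

```python
import keras

if keras.backend.backend() == "tensorflow":
    import tensorflow as tf

    def seed_backend(seed):
        # Only seed TensorFlow's global generator on the TF backend.
        tf.random.set_seed(seed)
else:
    def seed_backend(seed):
        # Nothing TF-specific to do on other backends.
        pass
```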
keras/src/wrappers/random_state.py (outdated)

@contextmanager
def tensorflow_random_state(seed: int) -> Generator[None, None, None]:
This should really not be necessary; instead just use `set_random_seed` from Keras. If users need `TF_DETERMINISTIC_OPS`, that is something they should set manually, separately.
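In other words, the suggestion boils down to something like this (illustrative, not code from this PR):

```python
import os

import keras

keras.utils.set_random_seed(42)           # seeds Python, NumPy and the backend RNGs
os.environ["TF_DETERMINISTIC_OPS"] = "1"  # optional, TF-specific, set by the user
```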
keras/src/wrappers/_sklearn.py (outdated)

        values passed directly to the `fit` method take precedence over
        these.
    random_state : int, np.random.RandomState, or None, defaults to None.
        Set the Tensorflow random number generators to a reproducible
This should work with all backends, not just TF. So no mention of TF and no special casing of TF in the code
keras/src/wrappers/_sklearn.py (outdated)

        If callable, it must accept at least `X` and `y` as keyword
        arguments. Other arguments must be accepted if passed as
        `model_args` by the user.
    warm_start: bool, defaults to False.
Use backticks around code keywords like `False`, `True`, `None`.
keras/src/wrappers/_sklearn.py (outdated)

        Whether to reuse the model weights from the previous fit. If `True`,
        the given model won't be cloned and the weights from the previous
        fit will be reused.
    model_args: dict, defaults to None.
These are actually kwargs rather than args (args would be a tuple of values)
Technically yes, but I'm not sure if `model_kwargs` is a better name here; renamed to `model_kwargs`.
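For context on the args/kwargs distinction raised here (a generic Python illustration, unrelated to the wrapper's internals): positional extras arrive as a tuple, keyword extras as a dict, which is why a dict of named model-building parameters reads more naturally as "kwargs".

```python
def build_model(X, y, *args, **kwargs):
    print(args)    # positional extras: a tuple, e.g. ()
    print(kwargs)  # keyword extras: a dict, e.g. {'hidden_units': 32}


build_model(None, None, hidden_units=32)
```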
keras/src/wrappers/_sklearn.py (outdated)

        directly to the `fit` method of the scikit-learn wrapper. The
        values passed directly to the `fit` method take precedence over
        these.
    random_state : int, np.random.RandomState, or None, defaults to None.
Why not `seed` (which is the standard arg name for this in Keras)?
Regarding randomness, I don't think it's a good idea to recommend `keras.utils.set_random_seed` since it has a global side effect. Setting randomness should ideally be local to the object. Right now, with the current design of Keras and TF, it's not clear to me what the best recommendation is. But since this is a larger issue than this PR, I've removed `random_seed` and the docs now point to the example where users can control the randomness themselves.
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version < parse_version("1.6"):
It seems like it would be much easier and more maintainable to simply require a minimum sklearn version?
It's not the worst. And I've included all version-specific code in a single `fixes.py` so that we know later where to clean up. Also, 1.6 was just released a few days ago, so I don't think it's a good idea to have that as a minimum required version. WDYT?
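Roughly the pattern being described, as a sketch of what such a `fixes.py` shim can look like (not the file's actual contents):

```python
import sklearn
from packaging.version import parse as parse_version

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version < parse_version("1.6"):
    # Older scikit-learn: provide a backported helper or workaround here.
    ...
else:
    # 1.6 and newer: rely on scikit-learn's own implementation.
    ...
```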
Ok, that's fine
keras/src/wrappers/utils.py (outdated)

    def inverse_transform(self, y):
        """Revert the transformation of transform.

        Parameters
Please update the docstrings in this file to use the standard format
Should be fixed now.
LGTM -- thank you for the neat contribution!
Oh wow 🥳 thanks for the reviews. Happy to contribute docs, or more features, or be pinged on issues about this ❤️
Thanks @adrinjalali and everyone involved in this contribution.
Amazing work folks! I'll update SciKeras to point folks towards these wrappers. Very happy to see this come full circle 🥳.
Fixes #20399

This adds a minimal wrapper under `keras.wrappers`. It delegates all model construction parameters to the function generating the model, and therefore doesn't require many `__init__` params at all. There are a lot of useful features under https://github.com/adriangb/scikeras which I haven't included here to make the review much easier. Happy to work on more features in this PR or after, if they're not covered.

As for the CI, would we want to test this in a separate job in `actions.yml`, or do we want to include it in the `build` job? Also, should we test against multiple `scikit-learn` versions in the CI?

also cc @adriangb @clstaudt @fchollet
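For reviewers who want to try it out, here is a rough usage sketch, assuming the wrappers end up exposed as `keras.wrappers.SKLearnClassifier` with the `model`/`model_kwargs` contract discussed above (details may differ from the merged code):

```python
import numpy as np
import keras
from sklearn.model_selection import cross_val_score


def dynamic_model(X, y, loss, layers=(10,)):
    # Build a small MLP whose input/output shapes are derived from the data.
    n_classes = len(np.unique(y))
    model = keras.Sequential([keras.layers.Input(shape=(X.shape[1],))])
    for size in layers:
        model.add(keras.layers.Dense(size, activation="relu"))
    model.add(keras.layers.Dense(n_classes, activation="softmax"))
    model.compile(loss=loss, optimizer="rmsprop")
    return model


X = np.random.uniform(size=(100, 10)).astype("float32")
y = np.random.randint(0, 2, size=(100,))

clf = keras.wrappers.SKLearnClassifier(
    model=dynamic_model,
    model_kwargs={"loss": "categorical_crossentropy", "layers": (20, 20)},
)
print(cross_val_score(clf, X, y, cv=3))
```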