
FEAT add scikit-learn wrappers #20599

Merged: 26 commits merged into keras-team:master on Dec 12, 2024

Conversation

adrinjalali (Contributor)

Fixes #20399

This adds a minimal wrapper under `keras.wrappers`. It delegates all model-construction parameters to the function that generates the model, and therefore needs very few `__init__` parameters.
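
For illustration, a minimal sketch of the intended usage based on the description above (the class name and the `model_args` parameter reflect this point in the PR and were renamed during review; `build_model` and `hidden_dim` are made-up examples):

    import numpy as np
    from keras import layers, models
    from keras.wrappers import KerasClassifier  # name as proposed here

    def build_model(X, y, hidden_dim=32):
        # All architecture decisions live in this function; the wrapper forwards
        # construction parameters here instead of taking them as __init__ args.
        model = models.Sequential(
            [
                layers.Input(shape=(X.shape[1],)),
                layers.Dense(hidden_dim, activation="relu"),
                layers.Dense(len(np.unique(y)), activation="softmax"),
            ]
        )
        model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
        return model

    clf = KerasClassifier(model=build_model, model_args={"hidden_dim": 64})
    # clf.fit(X_train, y_train)  # then use it like any scikit-learn classifier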

There are a lot of useful features in https://github.com/adriangb/scikeras which I haven't included here, to make the review much easier. Happy to work on more features, in this PR or after, if they're not covered.

As for the CI, would we want to test this in a separate job in actions.yml, or include it in the build job? Also, should we test against multiple scikit-learn versions in the CI?

Also cc @adriangb @clstaudt @fchollet


codecov-commenter commented Dec 5, 2024

Codecov Report

Attention: Patch coverage is 76.63043% with 43 lines in your changes missing coverage. Please review.

Project coverage is 82.52%. Comparing base (90d36dc) to head (eb7a893).
Report is 6 commits behind head on master.

Files with missing lines                          Patch %   Lines
keras/src/wrappers/fixes.py                       54.76%    18 missing, 1 partial ⚠️
keras/src/wrappers/sklearn_wrapper.py             86.11%    12 missing, 3 partials ⚠️
keras/src/wrappers/utils.py                       71.42%    3 missing, 3 partials ⚠️
keras/api/_tf_keras/keras/wrappers/__init__.py    0.00%     3 missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20599      +/-   ##
==========================================
- Coverage   82.55%   82.52%   -0.03%     
==========================================
  Files         518      525       +7     
  Lines       48682    48948     +266     
  Branches     7592     7615      +23     
==========================================
+ Hits        40188    40393     +205     
- Misses       6669     6719      +50     
- Partials     1825     1836      +11     
Flag              Coverage Δ
keras             82.36% <76.63%> (-0.04%) ⬇️
keras-jax         65.69% <76.63%> (+0.08%) ⬆️
keras-numpy       60.66% <69.02%> (+0.07%) ⬆️
keras-tensorflow  66.52% <76.63%> (+0.04%) ⬆️
keras-torch       65.59% <76.63%> (+0.09%) ⬆️

Flags with carried forward coverage won't be shown.


@adriangb (Contributor) left a comment:

This is amazing! You've made some great improvements over the version in SciKeras.

I agree with the approach you're proposing here: rather than copying SciKeras over or trying to reproduce all of its features out of the gate, build it up as a better new version over time.

@fchollet (Collaborator) left a comment:

Thanks for the PR!


from keras.src.wrappers._sklearn import KerasClassifier
from keras.src.wrappers._sklearn import KerasRegressor
from keras.src.wrappers._sklearn import KerasTransformer
fchollet (Collaborator):

I have reservations about "KerasTransformer" due to the very specific meaning that Transformer has in deep learning. This is going to confuse the hell out of people who search "Transformer" on keras.io. This was also not part of the original implementation of sklearn wrappers we had in Keras.

adrinjalali (Contributor, Author):

Fair enough.

I've included this here because, in the meantime, using pre-trained models as a pipeline step to get embeddings (or any other kind of transformation) has become much more popular than it was a couple of years ago.

But I do understand the name is quite confusing. What do you think of naming these estimators SKLearn{Regressor, Classifier, Transformer}? That should make it clear that the scope of these estimators is scikit-learn. We can also add a clear note in the docstring that these have nothing to do with "transformers".
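
As a sketch of that pipeline use case, under the naming proposed above (the feature-extractor function is illustrative; a real use would load frozen pre-trained weights):

    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline
    from keras import layers, models
    from keras.wrappers import SKLearnTransformer  # proposed name

    def build_embedder(X, y=None):
        # Stand-in for a pre-trained model whose output is an embedding.
        model = models.Sequential(
            [
                layers.Input(shape=(X.shape[1],)),
                layers.Dense(16, activation="relu"),  # the "embedding"
            ]
        )
        model.compile(loss="mse")
        return model

    pipe = make_pipeline(SKLearnTransformer(model=build_embedder), LogisticRegression())
    # pipe.fit(X_train, y_train): embeddings feed the downstream classifier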

keras/src/wrappers/_sklearn.py (resolved)
deterministic state using this seed. Pass an int for reproducible
results across multiple function calls.

Attributes:
fchollet (Collaborator):

You can remove this

adrinjalali (Contributor, Author):

Sure, but how should I document public attributes which users can use to inspect the object after calling fit?

fchollet (Collaborator):

Why not plain `model`, btw?

adrinjalali (Contributor, Author):

The scikit-learn convention/API is that estimator arguments given by the user are never changed, and `est.fit(X, y)` is (ignoring randomness) the same as `est.fit(X, y); est.fit(X, y)`. Object attributes without a trailing underscore are the ones given to `__init__`; the ones set during `fit` get a trailing underscore.
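
(A generic scikit-learn illustration of that convention, with `X`, `y` assumed given:)

    from sklearn.linear_model import LogisticRegression

    est = LogisticRegression(C=1.0)  # __init__ params: no trailing underscore
    est.fit(X, y)                    # never mutates C; refitting gives the same result
    est.coef_                        # state learned during fit: trailing underscore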

keras/src/wrappers/_sklearn.py (resolved)
arguments. Other arguments must be accepted if passed as
`model_args` by the user.

warm_start: bool, default=False
fchollet (Collaborator):

Please use the same arg description format as the rest of the codebase

adrinjalali (Contributor, Author):

Hope this is okay now:

        warm_start: bool, defaults to False.
            Whether to reuse the model weights from the previous fit. If `True`,
            the given model won't be cloned and the weights from the previous
            fit will be reused.

keras/src/wrappers/_sklearn.py (resolved)
For use in pipelines with transformers that only accept
2D inputs, like OneHotEncoder and OrdinalEncoder.

Attributes
fchollet (Collaborator):

Please use the same docstring format as in the rest of the codebase.


clstaudt commented Dec 9, 2024

Thank you @adrinjalali. I use these wrappers all the time for bridging keras and scikit-learn. From my point of view they're an essential feature for keras.

@adrinjalali (Contributor, Author) left a comment:

Thanks for the review. A couple of questions here.




fchollet commented Dec 9, 2024

> But I do understand the name is quite confusing. What do you think of naming these estimators SKLearn{Regressor, Classifier, Transformer}? That should make it clear that the scope of these estimators is scikit-learn. We can also add a clear note in the docstring that these have nothing to do with "transformers".

Yes, let's do that. This is much better.
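
(Under the agreed naming, the public imports become the following; `SKLearnClassifier` is confirmed later in this thread, the other two paths follow the same pattern:)

    from keras.wrappers import SKLearnClassifier
    from keras.wrappers import SKLearnRegressor
    from keras.wrappers import SKLearnTransformer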

@fchollet (Collaborator) left a comment:

Thanks for the update!


keras/src/wrappers/_sklearn.py (resolved)

EXPECTED_FAILED_CHECKS = {
"SKLearnClassifier": {
"check_classifiers_regression_target": ("not an issue in sklearn>=1.6"),
fchollet (Collaborator):

Why are strings wrapped in tuples?

adrinjalali (Contributor, Author):

It's not a tuple; the parentheses are only there because the string used to be longer and needed to wrap onto a new line 😅
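
(For anyone puzzled by the exchange: in Python, parentheses alone only group; a trailing comma is what makes a tuple:)

    ("not an issue in sklearn>=1.6")   # just a str
    ("not an issue in sklearn>=1.6",)  # a 1-tuple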

import numpy as np

try:
import tensorflow as tf
fchollet (Collaborator):

We should not unconditionally try to import `tensorflow`, since that could cause issues with other backends.



@contextmanager
def tensorflow_random_state(seed: int) -> Generator[None, None, None]:
fchollet (Collaborator):

This should really not be necessary; instead, just use `set_random_seed` from Keras. If users need `TF_DETERMINISTIC_OPS`, that is something they should set manually, separately.
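
(For reference, `keras.utils.set_random_seed` is the one-call seeding utility referred to here; it seeds Python's `random`, NumPy, and the active backend:)

    import keras

    keras.utils.set_random_seed(42)  # seeds random, numpy, and the backend framework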

keras/src/wrappers/_sklearn.py (resolved)
values passed directly to the `fit` method take precedence over
these.
random_state : int, np.random.RandomState, or None, defaults to None.
Set the Tensorflow random number generators to a reproducible
fchollet (Collaborator):

This should work with all backends, not just TF. So no mention of TF, and no special-casing of TF in the code.

If callable, it must accept at least `X` and `y` as keyword
arguments. Other arguments must be accepted if passed as
`model_args` by the user.
warm_start: bool, defaults to False.
fchollet (Collaborator):

Use backticks around code keywords like `False`, `True`, `None`.

Whether to reuse the model weights from the previous fit. If `True`,
the given model won't be cloned and the weights from the previous
fit will be reused.
model_args: dict, defaults to None.
fchollet (Collaborator):

These are actually kwargs rather than args (args would be a tuple of values)

adrinjalali (Contributor, Author):

Technically yes, but I'm not sure `model_kwargs` is a better name here; renamed to `model_kwargs` anyway.

directly to the `fit` method of the scikit-learn wrapper. The
values passed directly to the `fit` method take precedence over
these.
random_state : int, np.random.RandomState, or None, defaults to None.
fchollet (Collaborator):

Why not `seed` (which is the standard arg name for this in Keras)?

@adrinjalali (Contributor, Author) left a comment:

Regarding randomness, I don't think it's a good idea to recommend `keras.utils.set_random_seed`, since it has a global side effect. Setting randomness should ideally be local to the object. Right now, with the current design of Keras and TF, it's not clear to me what the best recommendation is. But since this is a larger issue than this PR, I've removed `random_state` and the docs now point to the example where users can control the randomness themselves.
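
A hedged sketch of what "controlling the randomness themselves" can look like: pass explicit seeds where Keras accepts them, e.g. in initializers inside the model-building function (the initializer choice and layer sizes are illustrative):

    from keras import initializers, layers, models

    def build_model(X, y):
        # Per-initializer seeds keep randomness local to this model,
        # with no process-wide side effects.
        model = models.Sequential(
            [
                layers.Input(shape=(X.shape[1],)),
                layers.Dense(
                    32,
                    activation="relu",
                    kernel_initializer=initializers.GlorotUniform(seed=0),
                ),
                layers.Dense(1, kernel_initializer=initializers.GlorotUniform(seed=1)),
            ]
        )
        model.compile(loss="mse")
        return model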

keras/src/wrappers/_sklearn.py (resolved)


sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version < parse_version("1.6"):
fchollet (Collaborator):

It seems like it would be much easier and more maintainable to simply require a minimum sklearn version?

adrinjalali (Contributor, Author):

It's not the worst. I've included all version-specific code in a single fixes.py so that we know where to clean up later. Also, 1.6 was released just a few days ago, so I don't think it's a good idea to have that as the minimum required version. WDYT?

fchollet (Collaborator):

Ok, that's fine

def inverse_transform(self, y):
"""Revert the transformation of transform.

Parameters
fchollet (Collaborator):

Please update the docstrings in this file to use the standard format

adrinjalali (Contributor, Author):

Should be fixed now.

@fchollet (Collaborator) left a comment:

LGTM -- thank you for the neat contribution!

google-ml-butler bot added the kokoro:force-run and 'ready to pull' (Ready to be merged into the codebase) labels on Dec 12, 2024
fchollet merged commit 32a642d into keras-team:master on Dec 12, 2024
6 checks passed
google-ml-butler bot removed the 'ready to pull' and kokoro:force-run labels on Dec 12, 2024
adrinjalali (Contributor, Author):

Oh wow 🥳 thanks for the reviews. Happy to contribute docs, or more features, or be pinged on issues about this ❤️

adrinjalali deleted the wrapper branch on December 12, 2024, 20:48
@glemaitre:

Thanks @adrinjalali and everyone involved in this contribution.

@adriangb (Contributor):

Amazing work folks!

I'll update SciKeras to point folks towards these wrappers.

Very happy to see this come full circle 🥳.
