Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT add scikit-learn wrappers #20599

Merged
merged 26 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6178e35
FEAT add scikit-learn wrappers
adrinjalali Dec 5, 2024
9a4b999
import cleanup
adrinjalali Dec 5, 2024
e14ec2b
run black
adrinjalali Dec 5, 2024
77d0975
linters
adrinjalali Dec 5, 2024
5b28b90
lint
adrinjalali Dec 5, 2024
e450bf1
add scikit-learn to requirements-common
adrinjalali Dec 5, 2024
f73c5e2
generate public api
adrinjalali Dec 5, 2024
c5aa0ea
fix tests for sklearn 1.5
adrinjalali Dec 5, 2024
999c598
check fixes
adrinjalali Dec 5, 2024
5d9bd01
Merge remote-tracking branch 'upstream/master' into wrapper
adrinjalali Dec 5, 2024
a91b475
skip numpy tests
adrinjalali Dec 6, 2024
ab1d0ea
xfail instead of skip
adrinjalali Dec 6, 2024
4f3d7ad
apply review comments
adrinjalali Dec 9, 2024
425b940
Merge remote-tracking branch 'upstream/master' into wrapper
adrinjalali Dec 9, 2024
7f9dd79
change names to SKL* and add transformer example
adrinjalali Dec 10, 2024
9eebe2d
fix API and imports
adrinjalali Dec 10, 2024
8e784fa
fix for new sklearn
adrinjalali Dec 10, 2024
4c26931
sklearn1.6 test
adrinjalali Dec 10, 2024
6289a8f
review comments and remove random_state
adrinjalali Dec 10, 2024
a958d98
Merge remote-tracking branch 'upstream/master' into wrapper
adrinjalali Dec 10, 2024
3cad37f
add another skipped test
adrinjalali Dec 10, 2024
70b44e8
rename file
adrinjalali Dec 11, 2024
f506f97
change imports
adrinjalali Dec 11, 2024
a33501e
unindent
adrinjalali Dec 12, 2024
be8835f
Merge remote-tracking branch 'upstream/master' into wrapper
adrinjalali Dec 12, 2024
eb7a893
docstrings
adrinjalali Dec 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from keras.api import utils
from keras.api import version
from keras.api import visualization
from keras.api import wrappers

# END DO NOT EDIT.

Expand Down
1 change: 1 addition & 0 deletions keras/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.api import wrappers
from keras.src.backend import Variable
from keras.src.backend import device
from keras.src.backend import name_scope
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.api import wrappers
from keras.api._tf_keras.keras import backend
from keras.api._tf_keras.keras import layers
from keras.api._tf_keras.keras import losses
Expand Down
9 changes: 9 additions & 0 deletions keras/api/_tf_keras/keras/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer
9 changes: 9 additions & 0 deletions keras/api/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer
5 changes: 5 additions & 0 deletions keras/src/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer

__all__ = ["SKLearnClassifier", "SKLearnRegressor", "SKLearnTransformer"]
119 changes: 119 additions & 0 deletions keras/src/wrappers/fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import sklearn
from packaging.version import parse as parse_version
from sklearn import get_config

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version < parse_version("1.6"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it would be much easier and more maintainable to simply require a minimum sklearn version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not the worst. And I've included all version specific code in a single fixes.py so that we know later where to clean up. Also, 1.6 was just released a few days ago, so I don't think it's a good idea to have that as a minimum required version. WDYT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that's fine


def patched_more_tags(estimator, expected_failed_checks):
    # NOTE: only used on scikit-learn < 1.6, where `parametrize_with_checks`
    # has no `expected_failed_checks` argument; the expected failures are
    # injected via the legacy `_xfail_checks` tag instead.
    import copy

    from sklearn.utils._tags import _safe_tags

    # Snapshot the estimator's current tags so the patched method below can
    # extend them without mutating the dict returned by `_safe_tags`.
    original_tags = copy.deepcopy(_safe_tags(estimator))

    def patched_more_tags(self):
        original_tags.update({"_xfail_checks": expected_failed_checks})
        return original_tags

    # Patch at the class level: scikit-learn looks `_more_tags` up on the
    # class, so an instance attribute would not be picked up.
    estimator.__class__._more_tags = patched_more_tags
    return estimator

def parametrize_with_checks(
    estimators,
    *,
    legacy: bool = True,
    expected_failed_checks=None,
):
    # Compatibility shim for scikit-learn < 1.6, emulating the 1.6 signature
    # of `sklearn.utils.estimator_checks.parametrize_with_checks`.
    # `legacy` is not supported and ignored.
    from sklearn.utils.estimator_checks import parametrize_with_checks  # noqa: F401, I001

    # Translate `expected_failed_checks` (a callable mapping estimator ->
    # {check_name: reason}) into the legacy `_xfail_checks` tag mechanism.
    estimators = [
        patched_more_tags(estimator, expected_failed_checks(estimator))
        for estimator in estimators
    ]

    return parametrize_with_checks(estimators)
else:
from sklearn.utils.estimator_checks import parametrize_with_checks # noqa: F401, I001


def _validate_data(estimator, *args, **kwargs):
    """Validate the input data.

    Wrapper for `sklearn.utils.validation.validate_data` (scikit-learn
    >= 1.6) or `BaseEstimator._validate_data` (older versions), whichever
    is available.

    TODO: remove when minimum scikit-learn version is 1.6
    """
    try:
        # scikit-learn >= 1.6
        from sklearn.utils.validation import validate_data
    except ImportError:
        # scikit-learn < 1.6: fall back to the estimator method.
        return estimator._validate_data(*args, **kwargs)

    # Call outside the `try` so an error raised by the validation itself is
    # not confused with the import failure (the original also had a dead
    # `except: raise` clause, removed here).
    return validate_data(estimator, *args, **kwargs)


def type_of_target(y, input_name="", *, raise_unknown=False):
# fix for raise_unknown which is introduced in scikit-learn 1.6
from sklearn.utils.multiclass import type_of_target

def _raise_or_return(target_type):
"""Depending on the value of raise_unknown, either raise an error or
return 'unknown'.
"""
if raise_unknown and target_type == "unknown":
input = input_name if input_name else "data"
raise ValueError(f"Unknown label type for {input}: {y!r}")
else:
return target_type

target_type = type_of_target(y, input_name=input_name)
return _raise_or_return(target_type)


def _routing_enabled():
    """Return whether scikit-learn metadata routing is enabled.

    Returns:
        enabled : bool
            True when the `enable_metadata_routing` config flag is set;
            False when it is unset or absent.

    TODO: remove when the config key is no longer available in scikit-learn
    """
    config = get_config()
    return bool(config.get("enable_metadata_routing", False))


def _raise_for_params(params, owner, method):
    """Raise an error if metadata routing is not enabled and params are passed.

    Parameters:
        params : dict
            The metadata passed to a method.
        owner : object
            The object to which the method belongs.
        method : str
            The name of the method, e.g. "fit".

    Raises:
        ValueError
            If metadata routing is not enabled and params are passed.
    """
    # Nothing to complain about: routing is on, or no extra params given.
    if _routing_enabled() or not params:
        return

    if method:
        caller = f"{owner.__class__.__name__}.{method}"
    else:
        caller = owner.__class__.__name__
    raise ValueError(
        f"Passing extra keyword arguments to {caller} is only supported if"
        " enable_metadata_routing=True, which you can set using"
        " `sklearn.set_config`. See the User Guide"
        " <https://scikit-learn.org/stable/metadata_routing.html> for more"
        f" details. Extra parameters passed are: {set(params)}"
    )
119 changes: 119 additions & 0 deletions keras/src/wrappers/sklearn_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Tests using Scikit-Learn's bundled estimator_checks."""

from contextlib import contextmanager

import pytest

import keras
from keras.src.backend import floatx
from keras.src.backend import set_floatx
from keras.src.layers import Dense
from keras.src.layers import Input
from keras.src.models import Model
from keras.src.wrappers import SKLearnClassifier
from keras.src.wrappers import SKLearnRegressor
from keras.src.wrappers import SKLearnTransformer
from keras.src.wrappers.fixes import parametrize_with_checks


def dynamic_model(X, y, loss, layers=None):
    """Create a basic MLP whose output width is inferred from `y`.

    Args:
        X: 2D input array; only `X.shape[1]` (number of features) is used.
        y: target array; `y.shape[1]` (or 1 if 1D) sets the output width.
        loss: loss passed to `Model.compile`.
        layers: list of hidden layer sizes; defaults to `[10]`.

    Returns:
        A compiled Keras `Model` with a softmax output layer.
    """
    # Avoid a mutable default argument (previously `layers=[10]`), which
    # would be shared across calls.
    if layers is None:
        layers = [10]

    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer_size in layers:
        hidden = Dense(layer_size, activation="relu")(hidden)

    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
    out = [Dense(n_outputs, activation="softmax")(hidden)]
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model


@contextmanager
def use_floatx(x: str):
    """Temporarily set the keras global float precision.

    The previous `floatx` value is restored on exit, even if the
    managed block raises.
    """
    previous = floatx()
    set_floatx(x)
    try:
        yield
    finally:
        set_floatx(previous)


# Per-wrapper-class mapping of sklearn check name -> reason it is expected to
# fail; consumed by `parametrize_with_checks(expected_failed_checks=...)`
# below.
EXPECTED_FAILED_CHECKS = {
    "SKLearnClassifier": {
        "check_classifiers_regression_target": "not an issue in sklearn>=1.6",
        "check_parameters_default_constructible": (
            "not an issue in sklearn>=1.6"
        ),
        "check_classifiers_one_label_sample_weights": (
            "0 sample weight is not ignored"
        ),
        "check_classifiers_classes": (
            "with small test cases the estimator returns not all classes "
            "sometimes"
        ),
        "check_classifier_data_not_an_array": (
            "This test assumes reproducibility in fit."
        ),
        "check_supervised_y_2d": "This test assumes reproducibility in fit.",
        "check_fit_idempotent": "This test assumes reproducibility in fit.",
    },
    "SKLearnRegressor": {
        "check_parameters_default_constructible": (
            "not an issue in sklearn>=1.6"
        ),
    },
    "SKLearnTransformer": {
        "check_parameters_default_constructible": (
            "not an issue in sklearn>=1.6"
        ),
    },
}


@parametrize_with_checks(
    estimators=[
        SKLearnClassifier(
            model=dynamic_model,
            model_kwargs={
                "loss": "categorical_crossentropy",
                "layers": [20, 20, 20],
            },
            fit_kwargs={"epochs": 5},
        ),
        SKLearnRegressor(
            model=dynamic_model,
            model_kwargs={"loss": "mse"},
        ),
        SKLearnTransformer(
            model=dynamic_model,
            model_kwargs={"loss": "mse"},
        ),
    ],
    # Look up the expected failures for the concrete wrapper class under
    # test.
    expected_failed_checks=lambda estimator: EXPECTED_FAILED_CHECKS[
        type(estimator).__name__
    ],
)
def test_sklearn_estimator_checks(estimator, check):
    """Checks that can be passed with sklearn's default tolerances
    and in a single epoch.
    """
    try:
        check(estimator)
    except Exception as exc:
        # NOTE(review): on the numpy backend, presumably some required ops
        # are unimplemented — xfail rather than error. The string check
        # catches NotImplementedError re-raised wrapped in another
        # exception type.
        if keras.config.backend() == "numpy" and (
            isinstance(exc, NotImplementedError)
            or "NotImplementedError" in str(exc)
        ):
            pytest.xfail("Backend not implemented")
        else:
            raise
Loading
Loading