diff --git a/keras/__init__.py b/keras/__init__.py
index 91fac8353f9..4df10c6c84a 100644
--- a/keras/__init__.py
+++ b/keras/__init__.py
@@ -49,6 +49,7 @@
 from keras.api import utils
 from keras.api import version
 from keras.api import visualization
+from keras.api import wrappers
 
 # END DO NOT EDIT.
diff --git a/keras/api/__init__.py b/keras/api/__init__.py
index ae3a063c368..2b24945d1c6 100644
--- a/keras/api/__init__.py
+++ b/keras/api/__init__.py
@@ -32,6 +32,7 @@
 from keras.api import tree
 from keras.api import utils
 from keras.api import visualization
+from keras.api import wrappers
 from keras.src.backend import Variable
 from keras.src.backend import device
 from keras.src.backend import name_scope
diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py
index 2feb6330464..cabed08d631 100644
--- a/keras/api/_tf_keras/keras/__init__.py
+++ b/keras/api/_tf_keras/keras/__init__.py
@@ -25,6 +25,7 @@
 from keras.api import tree
 from keras.api import utils
 from keras.api import visualization
+from keras.api import wrappers
 from keras.api._tf_keras.keras import backend
 from keras.api._tf_keras.keras import layers
 from keras.api._tf_keras.keras import losses
diff --git a/keras/api/_tf_keras/keras/wrappers/__init__.py b/keras/api/_tf_keras/keras/wrappers/__init__.py
new file mode 100644
index 00000000000..1d9b8747c6e
--- /dev/null
+++ b/keras/api/_tf_keras/keras/wrappers/__init__.py
@@ -0,0 +1,9 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier
+from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor
+from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer
diff --git a/keras/api/wrappers/__init__.py b/keras/api/wrappers/__init__.py
new file mode 100644
index 00000000000..1d9b8747c6e
--- /dev/null
+++ b/keras/api/wrappers/__init__.py
@@ -0,0 +1,9 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+""" + +from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier +from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor +from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer diff --git a/keras/src/wrappers/__init__.py b/keras/src/wrappers/__init__.py new file mode 100644 index 00000000000..8c55aa752f5 --- /dev/null +++ b/keras/src/wrappers/__init__.py @@ -0,0 +1,5 @@ +from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier +from keras.src.wrappers.sklearn_wrapper import SKLearnRegressor +from keras.src.wrappers.sklearn_wrapper import SKLearnTransformer + +__all__ = ["SKLearnClassifier", "SKLearnRegressor", "SKLearnTransformer"] diff --git a/keras/src/wrappers/fixes.py b/keras/src/wrappers/fixes.py new file mode 100644 index 00000000000..20e91cfa2a1 --- /dev/null +++ b/keras/src/wrappers/fixes.py @@ -0,0 +1,119 @@ +import sklearn +from packaging.version import parse as parse_version +from sklearn import get_config + +sklearn_version = parse_version(parse_version(sklearn.__version__).base_version) + +if sklearn_version < parse_version("1.6"): + + def patched_more_tags(estimator, expected_failed_checks): + import copy + + from sklearn.utils._tags import _safe_tags + + original_tags = copy.deepcopy(_safe_tags(estimator)) + + def patched_more_tags(self): + original_tags.update({"_xfail_checks": expected_failed_checks}) + return original_tags + + estimator.__class__._more_tags = patched_more_tags + return estimator + + def parametrize_with_checks( + estimators, + *, + legacy: bool = True, + expected_failed_checks=None, + ): + # legacy is not supported and ignored + from sklearn.utils.estimator_checks import parametrize_with_checks # noqa: F401, I001 + + estimators = [ + patched_more_tags(estimator, expected_failed_checks(estimator)) + for estimator in estimators + ] + + return parametrize_with_checks(estimators) +else: + from sklearn.utils.estimator_checks import parametrize_with_checks # noqa: F401, I001 + + +def _validate_data(estimator, *args, **kwargs): + """Validate the input data. + + wrapper for sklearn.utils.validation.validate_data or + BaseEstimator._validate_data depending on the scikit-learn version. + + TODO: remove when minimum scikit-learn version is 1.6 + """ + try: + # scikit-learn >= 1.6 + from sklearn.utils.validation import validate_data + + return validate_data(estimator, *args, **kwargs) + except ImportError: + return estimator._validate_data(*args, **kwargs) + except: + raise + + +def type_of_target(y, input_name="", *, raise_unknown=False): + # fix for raise_unknown which is introduced in scikit-learn 1.6 + from sklearn.utils.multiclass import type_of_target + + def _raise_or_return(target_type): + """Depending on the value of raise_unknown, either raise an error or + return 'unknown'. + """ + if raise_unknown and target_type == "unknown": + input = input_name if input_name else "data" + raise ValueError(f"Unknown label type for {input}: {y!r}") + else: + return target_type + + target_type = type_of_target(y, input_name=input_name) + return _raise_or_return(target_type) + + +def _routing_enabled(): + """Return whether metadata routing is enabled. + + Returns: + enabled : bool + Whether metadata routing is enabled. If the config is not set, it + defaults to False. + + TODO: remove when the config key is no longer available in scikit-learn + """ + return get_config().get("enable_metadata_routing", False) + + +def _raise_for_params(params, owner, method): + """Raise an error if metadata routing is not enabled and params are passed. 
+
+    Parameters:
+        params : dict
+            The metadata passed to a method.
+        owner : object
+            The object to which the method belongs.
+        method : str
+            The name of the method, e.g. "fit".
+
+    Raises:
+        ValueError
+            If metadata routing is not enabled and params are passed.
+    """
+    caller = (
+        f"{owner.__class__.__name__}.{method}"
+        if method
+        else owner.__class__.__name__
+    )
+    if not _routing_enabled() and params:
+        raise ValueError(
+            f"Passing extra keyword arguments to {caller} is only supported if"
+            " enable_metadata_routing=True, which you can set using"
+            " `sklearn.set_config`. See the User Guide for more"
+            f" details. Extra parameters passed are: {set(params)}"
+        )
diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py
new file mode 100644
index 00000000000..8f2f4e1caf8
--- /dev/null
+++ b/keras/src/wrappers/sklearn_test.py
@@ -0,0 +1,119 @@
+"""Tests using Scikit-Learn's bundled estimator_checks."""
+
+from contextlib import contextmanager
+
+import pytest
+
+import keras
+from keras.src.backend import floatx
+from keras.src.backend import set_floatx
+from keras.src.layers import Dense
+from keras.src.layers import Input
+from keras.src.models import Model
+from keras.src.wrappers import SKLearnClassifier
+from keras.src.wrappers import SKLearnRegressor
+from keras.src.wrappers import SKLearnTransformer
+from keras.src.wrappers.fixes import parametrize_with_checks
+
+
+def dynamic_model(X, y, loss, layers=[10]):
+    """Creates a basic MLP classifier, dynamically choosing the
+    binary/multiclass classification loss and output activations.
+    """
+    n_features_in = X.shape[1]
+    inp = Input(shape=(n_features_in,))
+
+    hidden = inp
+    for layer_size in layers:
+        hidden = Dense(layer_size, activation="relu")(hidden)
+
+    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
+    out = [Dense(n_outputs, activation="softmax")(hidden)]
+    model = Model(inp, out)
+    model.compile(loss=loss, optimizer="rmsprop")
+
+    return model
+
+
+@contextmanager
+def use_floatx(x: str):
+    """Context manager to temporarily set the keras backend precision."""
+    _floatx = floatx()
+    set_floatx(x)
+    try:
+        yield
+    finally:
+        set_floatx(_floatx)
+
+
+EXPECTED_FAILED_CHECKS = {
+    "SKLearnClassifier": {
+        "check_classifiers_regression_target": "not an issue in sklearn>=1.6",
+        "check_parameters_default_constructible": (
+            "not an issue in sklearn>=1.6"
+        ),
+        "check_classifiers_one_label_sample_weights": (
+            "0 sample weight is not ignored"
+        ),
+        "check_classifiers_classes": (
+            "with small test cases the estimator sometimes does not return "
+            "all classes"
+        ),
+        "check_classifier_data_not_an_array": (
+            "This test assumes reproducibility in fit."
+        ),
+        "check_supervised_y_2d": "This test assumes reproducibility in fit.",
+        "check_fit_idempotent": "This test assumes reproducibility in fit.",
+    },
+    "SKLearnRegressor": {
+        "check_parameters_default_constructible": (
+            "not an issue in sklearn>=1.6"
+        ),
+    },
+    "SKLearnTransformer": {
+        "check_parameters_default_constructible": (
+            "not an issue in sklearn>=1.6"
+        ),
+    },
+}
+
+
+@parametrize_with_checks(
+    estimators=[
+        SKLearnClassifier(
+            model=dynamic_model,
+            model_kwargs={
+                "loss": "categorical_crossentropy",
+                "layers": [20, 20, 20],
+            },
+            fit_kwargs={"epochs": 5},
+        ),
+        SKLearnRegressor(
+            model=dynamic_model,
+            model_kwargs={"loss": "mse"},
+        ),
+        SKLearnTransformer(
+            model=dynamic_model,
+            model_kwargs={"loss": "mse"},
+        ),
+    ],
+    expected_failed_checks=lambda estimator: EXPECTED_FAILED_CHECKS[
+        type(estimator).__name__
+    ],
+)
+def test_sklearn_estimator_checks(estimator, check):
+    """Checks that can be passed with sklearn's default tolerances
+    and in a single epoch.
+    """
+    try:
+        check(estimator)
+    except Exception as exc:
+        if keras.config.backend() == "numpy" and (
+            isinstance(exc, NotImplementedError)
+            or "NotImplementedError" in str(exc)
+        ):
+            pytest.xfail("Backend not implemented")
+        else:
+            raise
diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py
new file mode 100644
index 00000000000..80fa57b2feb
--- /dev/null
+++ b/keras/src/wrappers/sklearn_wrapper.py
@@ -0,0 +1,470 @@
+import copy
+
+import numpy as np
+from sklearn.base import BaseEstimator
+from sklearn.base import ClassifierMixin
+from sklearn.base import RegressorMixin
+from sklearn.base import TransformerMixin
+from sklearn.base import check_is_fitted
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import OneHotEncoder
+from sklearn.utils.metadata_routing import MetadataRequest
+
+from keras.src.api_export import keras_export
+from keras.src.models.cloning import clone_model
+from keras.src.models.model import Model
+from keras.src.wrappers.fixes import _routing_enabled
+from keras.src.wrappers.fixes import _validate_data
+from keras.src.wrappers.fixes import type_of_target
+from keras.src.wrappers.utils import TargetReshaper
+from keras.src.wrappers.utils import _check_model
+
+
+class SKLBase(BaseEstimator):
+    """Base class for scikit-learn wrappers.
+
+    Note that there are sources of randomness in model initialization and
+    training. Refer to [Reproducibility in Keras Models](
+    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+    control randomness.
+
+    Args:
+        model: `Model`.
+            An instance of `Model`, or a callable returning such an object.
+            Note that if input is a `Model`, it will be cloned using
+            `keras.models.clone_model` before being fitted, unless
+            `warm_start=True`.
+            The `Model` instance needs to be passed as already compiled.
+            If callable, it must accept at least `X` and `y` as keyword
+            arguments. Other arguments must be accepted if passed as
+            `model_kwargs` by the user.
+        warm_start: bool, defaults to `False`.
+            Whether to reuse the model weights from the previous fit. If
+            `True`, the given model won't be cloned and the weights from the
+            previous fit will be reused.
+        model_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model`, if `model` is callable.
+        fit_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model.fit`. These can also be passed
+            directly to the `fit` method of the scikit-learn wrapper. The
+            values passed directly to the `fit` method take precedence over
+            these.
+
+    Attributes:
+        model_ : `Model`
+            The fitted model.
+        history_ : dict
+            The history of the fit, returned by `model.fit`.
+    """
+
+    def __init__(
+        self,
+        model,
+        warm_start=False,
+        model_kwargs=None,
+        fit_kwargs=None,
+    ):
+        self.model = model
+        self.warm_start = warm_start
+        self.model_kwargs = model_kwargs
+        self.fit_kwargs = fit_kwargs
+
+    def _more_tags(self):
+        return {"non_deterministic": True}
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.non_deterministic = True
+        return tags
+
+    def __sklearn_clone__(self):
+        """Return a deep copy of the model.
+
+        This is used by the `sklearn.base.clone` function.
+        """
+        model = (
+            self.model if callable(self.model) else copy.deepcopy(self.model)
+        )
+        return type(self)(
+            model=model,
+            warm_start=self.warm_start,
+            model_kwargs=self.model_kwargs,
+            fit_kwargs=self.fit_kwargs,
+        )
+
+    @property
+    def epoch_(self) -> int:
+        """The current training epoch."""
+        return getattr(self, "history_", {}).get("epoch", 0)
+
+    def set_fit_request(self, **kwargs):
+        """Set requested parameters by the fit method.
+
+        Please see [scikit-learn's metadata routing](
+        https://scikit-learn.org/stable/metadata_routing.html) for more
+        details.
+
+        Arguments:
+            kwargs : dict
+                Arguments should be of the form `param_name=alias`, and
+                `alias` can be one of `{True, False, None, str}`.
+
+        Returns:
+            self
+        """
+        if not _routing_enabled():
+            raise RuntimeError(
+                "This method is only available when metadata routing is "
+                "enabled. You can enable it using "
+                "sklearn.set_config(enable_metadata_routing=True)."
+            )
+
+        self._metadata_request = MetadataRequest(owner=self.__class__.__name__)
+        for param, alias in kwargs.items():
+            # requests are for `fit`, not `score`
+            self._metadata_request.fit.add_request(param=param, alias=alias)
+        return self
+
+    def _get_model(self, X, y):
+        if isinstance(self.model, Model):
+            return clone_model(self.model)
+        else:
+            args = self.model_kwargs or {}
+            return self.model(X=X, y=y, **args)
+
+    def fit(self, X, y, **kwargs):
+        """Fit the model.
+
+        Args:
+            X: array-like, shape=(n_samples, n_features)
+                The input samples.
+            y: array-like, shape=(n_samples,) or (n_samples, n_outputs)
+                The targets.
+            **kwargs: keyword arguments passed to `model.fit`
+        """
+        X, y = _validate_data(self, X, y)
+        y = self._process_target(y, reset=True)
+        model = self._get_model(X, y)
+        _check_model(model)
+
+        # copy to avoid mutating the user-provided `fit_kwargs` dict
+        fit_kwargs = dict(self.fit_kwargs or {})
+        fit_kwargs.update(kwargs)
+        self.history_ = model.fit(X, y, **fit_kwargs)
+
+        self.model_ = model
+        return self
+
+    def predict(self, X):
+        """Predict using the model."""
+        check_is_fitted(self)
+        X = _validate_data(self, X, reset=False)
+        raw_output = self.model_.predict(X)
+        return self._reverse_process_target(raw_output)
+
+    def _process_target(self, y, reset=False):
+        """Regressors are NOOP here, classifiers do OHE."""
+        # This is here to raise the right error in case of invalid target
+        type_of_target(y, raise_unknown=True)
+        if reset:
+            self._target_encoder = TargetReshaper().fit(y)
+        return self._target_encoder.transform(y)
+
+    def _reverse_process_target(self, y):
+        """Regressors are NOOP here, classifiers reverse OHE."""
+        return self._target_encoder.inverse_transform(y)
+
+
+@keras_export("keras.wrappers.SKLearnClassifier")
+class SKLearnClassifier(ClassifierMixin, SKLBase):
+    """scikit-learn compatible classifier wrapper for Keras models.
+
+    Note that there are sources of randomness in model initialization and
+    training. Refer to [Reproducibility in Keras Models](
+    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+    control randomness.
+
+    Args:
+        model: `Model`.
+            An instance of `Model`, or a callable returning such an object.
+            Note that if input is a `Model`, it will be cloned using
+            `keras.models.clone_model` before being fitted, unless
+            `warm_start=True`.
+            The `Model` instance needs to be passed as already compiled.
+            If callable, it must accept at least `X` and `y` as keyword
+            arguments. Other arguments must be accepted if passed as
+            `model_kwargs` by the user.
+        warm_start: bool, defaults to `False`.
+            Whether to reuse the model weights from the previous fit. If
+            `True`, the given model won't be cloned and the weights from the
+            previous fit will be reused.
+        model_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model`, if `model` is callable.
+        fit_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model.fit`. These can also be passed
+            directly to the `fit` method of the scikit-learn wrapper. The
+            values passed directly to the `fit` method take precedence over
+            these.
+
+    Attributes:
+        model_ : `Model`
+            The fitted model.
+        history_ : dict
+            The history of the fit, returned by `model.fit`.
+        classes_ : array-like, shape=(n_classes,)
+            The class labels.
+
+    Example:
+        Here we use a function which creates a basic MLP model dynamically
+        choosing the input and output shapes. We will use this to create our
+        scikit-learn model.
+
+        ``` python
+        from keras.layers import Dense, Input
+        from keras.models import Model
+
+        def dynamic_model(X, y, loss, layers=[10]):
+            # Creates a basic MLP model dynamically choosing the input and
+            # output shapes.
+            n_features_in = X.shape[1]
+            inp = Input(shape=(n_features_in,))
+
+            hidden = inp
+            for layer_size in layers:
+                hidden = Dense(layer_size, activation="relu")(hidden)
+
+            n_outputs = y.shape[1] if len(y.shape) > 1 else 1
+            out = [Dense(n_outputs, activation="softmax")(hidden)]
+            model = Model(inp, out)
+            model.compile(loss=loss, optimizer="rmsprop")
+
+            return model
+        ```
+
+        You can then use this function to create a scikit-learn compatible
+        model and fit it on some data.
+
+        ``` python
+        from sklearn.datasets import make_classification
+        from keras.wrappers import SKLearnClassifier
+
+        X, y = make_classification(
+            n_samples=1000, n_features=10, n_informative=5, n_classes=3
+        )
+        est = SKLearnClassifier(
+            model=dynamic_model,
+            model_kwargs={
+                "loss": "categorical_crossentropy",
+                "layers": [20, 20, 20],
+            },
+        )
+
+        est.fit(X, y, epochs=5)
+        ```
+    """
+
+    def _process_target(self, y, reset=False):
+        """Classifiers do OHE."""
+        target_type = type_of_target(y, raise_unknown=True)
+        if target_type not in ["binary", "multiclass"]:
+            raise ValueError(
+                "Only binary and multiclass target types are supported."
+                f" Target type: {target_type}"
+            )
+        if reset:
+            self._target_encoder = make_pipeline(
+                TargetReshaper(), OneHotEncoder(sparse_output=False)
+            ).fit(y)
+            self.classes_ = np.unique(y)
+            if len(self.classes_) == 1:
+                raise ValueError(
+                    "Classifier can't train when only one class is present."
+                )
+        return self._target_encoder.transform(y)
+
+    def _more_tags(self):
+        # required to be compatible with scikit-learn<1.6
+        return {"poor_score": True}
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.classifier_tags.poor_score = True
+        return tags
+
+
+@keras_export("keras.wrappers.SKLearnRegressor")
+class SKLearnRegressor(RegressorMixin, SKLBase):
+    """scikit-learn compatible regressor wrapper for Keras models.
+
+    Note that there are sources of randomness in model initialization and
+    training. Refer to [Reproducibility in Keras Models](
+    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+    control randomness.
+
+    Args:
+        model: `Model`.
+            An instance of `Model`, or a callable returning such an object.
+            Note that if input is a `Model`, it will be cloned using
+            `keras.models.clone_model` before being fitted, unless
+            `warm_start=True`.
+            The `Model` instance needs to be passed as already compiled.
+            If callable, it must accept at least `X` and `y` as keyword
+            arguments. Other arguments must be accepted if passed as
+            `model_kwargs` by the user.
+        warm_start: bool, defaults to `False`.
+            Whether to reuse the model weights from the previous fit. If
+            `True`, the given model won't be cloned and the weights from the
+            previous fit will be reused.
+        model_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model`, if `model` is callable.
+        fit_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model.fit`. These can also be passed
+            directly to the `fit` method of the scikit-learn wrapper. The
+            values passed directly to the `fit` method take precedence over
+            these.
+
+    Attributes:
+        model_ : `Model`
+            The fitted model.
+        history_ : dict
+            The history of the fit, returned by `model.fit`.
+
+    Example:
+        Here we use a function which creates a basic MLP model dynamically
+        choosing the input and output shapes. We will use this to create our
+        scikit-learn model.
+
+        ``` python
+        from keras.layers import Dense, Input
+        from keras.models import Model
+
+        def dynamic_model(X, y, loss, layers=[10]):
+            # Creates a basic MLP model dynamically choosing the input and
+            # output shapes. The output layer is linear, as is usual for
+            # regression.
+            n_features_in = X.shape[1]
+            inp = Input(shape=(n_features_in,))
+
+            hidden = inp
+            for layer_size in layers:
+                hidden = Dense(layer_size, activation="relu")(hidden)
+
+            n_outputs = y.shape[1] if len(y.shape) > 1 else 1
+            out = [Dense(n_outputs)(hidden)]
+            model = Model(inp, out)
+            model.compile(loss=loss, optimizer="rmsprop")
+
+            return model
+        ```
+
+        You can then use this function to create a scikit-learn compatible
+        model and fit it on some data.
+
+        ``` python
+        from sklearn.datasets import make_regression
+        from keras.wrappers import SKLearnRegressor
+
+        X, y = make_regression(n_samples=1000, n_features=10)
+        est = SKLearnRegressor(
+            model=dynamic_model,
+            model_kwargs={
+                "loss": "mse",
+                "layers": [20, 20, 20],
+            },
+        )
+
+        est.fit(X, y, epochs=5)
+        ```
+    """
+
+    def _more_tags(self):
+        # required to be compatible with scikit-learn<1.6
+        return {"poor_score": True}
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.regressor_tags.poor_score = True
+        return tags
+
+
+@keras_export("keras.wrappers.SKLearnTransformer")
+class SKLearnTransformer(TransformerMixin, SKLBase):
+    """scikit-learn compatible transformer wrapper for Keras models.
+
+    Note that this is a scikit-learn compatible transformer, and not a
+    transformer in the deep learning sense.
+
+    Also note that there are sources of randomness in model initialization and
+    training. Refer to [Reproducibility in Keras Models](
+    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to
+    control randomness.
+
+    Args:
+        model: `Model`.
+            An instance of `Model`, or a callable returning such an object.
+            Note that if input is a `Model`, it will be cloned using
+            `keras.models.clone_model` before being fitted, unless
+            `warm_start=True`.
+            The `Model` instance needs to be passed as already compiled.
+            If callable, it must accept at least `X` and `y` as keyword
+            arguments. Other arguments must be accepted if passed as
+            `model_kwargs` by the user.
+        warm_start: bool, defaults to `False`.
+            Whether to reuse the model weights from the previous fit. If
+            `True`, the given model won't be cloned and the weights from the
+            previous fit will be reused.
+        model_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model`, if `model` is callable.
+        fit_kwargs: dict, defaults to `None`.
+            Keyword arguments passed to `model.fit`. These can also be passed
+            directly to the `fit` method of the scikit-learn wrapper. The
+            values passed directly to the `fit` method take precedence over
+            these.
+
+    Attributes:
+        model_ : `Model`
+            The fitted model.
+        history_ : dict
+            The history of the fit, returned by `model.fit`.
+
+    Example:
+        A common use case for a scikit-learn transformer is to have a step
+        which gives you the embedding of your data. Here we assume
+        `my_package.my_model` is a Keras model which takes the input and gives
+        embeddings of the data, and `my_package.my_data` is your dataset
+        loader.
+
+        ``` python
+        from my_package import my_model, my_data
+        from keras.wrappers import SKLearnTransformer
+        from sklearn.frozen import FrozenEstimator  # requires scikit-learn>=1.6
+        from sklearn.pipeline import make_pipeline
+        from sklearn.ensemble import HistGradientBoostingClassifier
+
+        X, y = my_data()
+
+        trs = FrozenEstimator(SKLearnTransformer(model=my_model))
+        pipe = make_pipeline(trs, HistGradientBoostingClassifier())
+        pipe.fit(X, y)
+        ```
+
+        Note that in the above example, `FrozenEstimator` prevents any further
+        training of the transformer step in the pipeline, which is useful if
+        you don't want to change the embedding model at hand.
+    """
+
+    def transform(self, X):
+        """Transform the data.
+
+        Args:
+            X: array-like, shape=(n_samples, n_features)
+                The input samples.
+
+        Returns:
+            X_transformed: array-like, shape=(n_samples, n_features)
+                The transformed data.
+        """
+        check_is_fitted(self)
+        X = _validate_data(self, X, reset=False)
+        return self.model_.predict(X)
+
+    def _more_tags(self):
+        # required to be compatible with scikit-learn<1.6
+        return {
+            "preserves_dtype": [],
+        }
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.transformer_tags.preserves_dtype = []
+        return tags
diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py
new file mode 100644
index 00000000000..62ae6d0eb9b
--- /dev/null
+++ b/keras/src/wrappers/utils.py
@@ -0,0 +1,71 @@
+from sklearn.base import BaseEstimator
+from sklearn.base import TransformerMixin
+from sklearn.base import check_is_fitted
+from sklearn.utils._array_api import get_namespace
+
+
+def _check_model(model):
+    """Check that the model is compiled."""
+    # raise if the user gave us an un-compiled model
+    if not model.compiled or not model.loss or not model.optimizer:
+        raise RuntimeError(
+            "Given model needs to be compiled, and have a loss and an "
+            "optimizer."
+        )
+
+
+class TargetReshaper(TransformerMixin, BaseEstimator):
+    """Convert 1D targets to 2D and back.
+
+    For use in pipelines with transformers that only accept
+    2D inputs, like OneHotEncoder and OrdinalEncoder.
+
+    Attributes:
+        ndim_ : int
+            Dimensions of y that the transformer was trained on.
+    """
+
+    def fit(self, y):
+        """Fit the transformer to a target y.
+
+        Args:
+            y : np.ndarray
+                The target data to be transformed.
+
+        Returns:
+            TargetReshaper
+                A reference to the current instance of TargetReshaper.
+ """ + self.ndim_ = y.ndim + return self + + def transform(self, y): + """Makes 1D y 2D. + + Args: + y : np.ndarray + Target y to be transformed. + + Returns: + np.ndarray + A numpy array, of dimension at least 2. + """ + if y.ndim == 1: + return y.reshape(-1, 1) + return y + + def inverse_transform(self, y): + """Revert the transformation of transform. + + Args: + y: np.ndarray + Transformed numpy array. + + Returns: + np.ndarray + If the transformer was fit to a 1D numpy array, + and a 2D numpy array with a singleton second dimension + is passed, it will be squeezed back to 1D. Otherwise, it + will eb left untouched. + """ + check_is_fitted(self) + xp, _ = get_namespace(y) + if self.ndim_ == 1 and y.ndim == 2: + return xp.squeeze(y, axis=1) + return y diff --git a/requirements-common.txt b/requirements-common.txt index eff32ae05a8..2d1ec92d911 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -3,6 +3,7 @@ ruff pytest numpy scipy +scikit-learn pandas absl-py requests