From 9921c615ef418950d18b0ed866a36d28d88f0b4d Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Wed, 9 Feb 2022 15:15:18 +0100
Subject: [PATCH] Kernel ridge regression (#4492)

Sklearn reference implementation: https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09b/sklearn/kernel_ridge.py#L16

I've tried to avoid touching the C++/CUDA layer so far. Pairwise kernels are implemented based on a numba kernel for now. I've also used cupy's lapack wrapper to access cuSolver.

The implementation of `pairwise_kernels` here can be reused to very easily implement kernel PCA.

Todo:
- [x] Single-target fit/predict
- [x] Standard kernels implemented
- [x] Support custom kernels
- [x] Support sample weights
- [ ] ~~Support CSR X matrix. Maybe too difficult for this PR.~~
- [x] Multi-target fit/predict
- [x] Changed .py files to .pyx and moved them to the correct places.
- [x] Benchmarking on reasonably large files
- [x] Tests take less than 20s
- [x] Ensure correct handling of input/output array types (I think I need to be using CumlArray and maybe some decorators)
- [x] Documentation

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Micka (https://github.com/lowener)

URL: https://github.com/rapidsai/cuml/pull/4492
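
For reviewers, a minimal usage sketch of the API this PR adds (it mirrors the estimator docstring below; the array values are arbitrary):

```python
import cupy as cp
from cuml.kernel_ridge import KernelRidge

X = cp.random.rand(10, 5)
y = cp.random.rand(10)

# fit in the dual space, then predict via kernel evaluations
model = KernelRidge(kernel="rbf", alpha=0.5).fit(X, y)
pred = model.predict(X)
```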
---
 .gitignore                                |   1 +
 python/cuml/__init__.py                   |   4 +-
 python/cuml/kernel_ridge/__init__.py      |  18 ++
 python/cuml/kernel_ridge/kernel_ridge.pyx | 291 +++++++++++++++++++++
 python/cuml/metrics/__init__.py           |   4 +-
 python/cuml/metrics/pairwise_kernels.py   | 295 ++++++++++++++++++++++
 python/cuml/test/test_kernel_ridge.py     | 276 ++++++++++++++++++++
 7 files changed, 887 insertions(+), 2 deletions(-)
 create mode 100644 python/cuml/kernel_ridge/__init__.py
 create mode 100644 python/cuml/kernel_ridge/kernel_ridge.pyx
 create mode 100644 python/cuml/metrics/pairwise_kernels.py
 create mode 100644 python/cuml/test/test_kernel_ridge.py

diff --git a/.gitignore b/.gitignore
index 1a06104ca0..045470926e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,6 +29,7 @@ log
 .DS_Store
 dask-worker-space/
 tmp/
+.hypothesis

 ## files pickled in notebook when ran during python docstring generation
 docs/source/*.model
diff --git a/python/cuml/__init__.py b/python/cuml/__init__.py
index e535ee487d..4c5553b3e4 100644
--- a/python/cuml/__init__.py
+++ b/python/cuml/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2021, NVIDIA CORPORATION.
+# Copyright (c) 2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -45,6 +45,8 @@ from cuml.internals.global_settings import (
     GlobalSettings, _global_settings_data)

+from cuml.kernel_ridge.kernel_ridge import KernelRidge
+
 from cuml.linear_model.elastic_net import ElasticNet
 from cuml.linear_model.lasso import Lasso
 from cuml.linear_model.linear_regression import LinearRegression
diff --git a/python/cuml/kernel_ridge/__init__.py b/python/cuml/kernel_ridge/__init__.py
new file mode 100644
index 0000000000..e1d5692f5a
--- /dev/null
+++ b/python/cuml/kernel_ridge/__init__.py
@@ -0,0 +1,18 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from cuml.kernel_ridge.kernel_ridge import KernelRidge
diff --git a/python/cuml/kernel_ridge/kernel_ridge.pyx b/python/cuml/kernel_ridge/kernel_ridge.pyx
new file mode 100644
index 0000000000..4625096db8
--- /dev/null
+++ b/python/cuml/kernel_ridge/kernel_ridge.pyx
@@ -0,0 +1,291 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# distutils: language = c++
+
+import numpy as np
+import warnings
+from cupy import linalg
+import cupy as cp
+from cupyx import lapack, geterr, seterr
+from cuml.common.array_descriptor import CumlArrayDescriptor
+from cuml.common.base import Base
+from cuml.common.mixins import RegressorMixin
+from cuml.common.doc_utils import generate_docstring
+from cuml.common import input_to_cuml_array
+
+from cuml.metrics import pairwise_kernels
+
+
+# Cholesky solve with a fallback to least squares for singular problems
+def _safe_solve(K, y):
+    # cupy's default "linalg" error mode silently returns an array of NaNs
+    # for singular systems, so we temporarily set it to raise; the original
+    # mode is restored in the finally block even if the solve fails.
+    err_mode = geterr()["linalg"]
+    seterr(linalg="raise")
+    try:
+        dual_coef = lapack.posv(K, y)
+    except np.linalg.LinAlgError:
+        warnings.warn(
+            "Singular matrix in solving dual problem. Using "
+            "least-squares solution instead."
+        )
+        dual_coef = linalg.lstsq(K, y, rcond=None)[0]
+    finally:
+        seterr(linalg=err_mode)
+    return dual_coef
+
+
+def _solve_cholesky_kernel(K, y, alpha, sample_weight=None):
+    # dual_coef = inv(X X^t + alpha*Id) y
+    n_samples = K.shape[0]
+    n_targets = y.shape[1]
+
+    K = cp.array(K, dtype=np.float64)
+
+    alpha = cp.atleast_1d(alpha)
+    one_alpha = alpha.size == 1
+    has_sw = sample_weight is not None
+
+    if has_sw:
+        # Unlike other solvers, we need to support sample_weight directly
+        # because K might be a pre-computed kernel.
+        sw = cp.sqrt(cp.atleast_1d(sample_weight))
+        y = y * sw[:, cp.newaxis]
+        K *= cp.outer(sw, sw)
+
+    if one_alpha:
+        # Only one penalty, so we can solve all targets at once.
+        K.flat[:: n_samples + 1] += alpha[0]
+
+        dual_coef = _safe_solve(K, y)
+
+        if has_sw:
+            dual_coef *= sw[:, cp.newaxis]
+
+        return dual_coef
+    else:
+        # One penalty per target. We need to solve each target separately.
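+        # For target t this solves (K + alpha[t] * I) @ dual_coef_t = y_t;
+        # K.flat[:: n_samples + 1] views the diagonal of K, so alpha[t] is
+        # added before the solve and subtracted again afterwards, letting us
+        # reuse K in place rather than copying it per target.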
+        dual_coefs = cp.empty([n_targets, n_samples], K.dtype)
+
+        for dual_coef, target, current_alpha in zip(dual_coefs, y.T, alpha):
+            K.flat[:: n_samples + 1] += current_alpha
+
+            dual_coef[:] = _safe_solve(K, target).ravel()
+
+            K.flat[:: n_samples + 1] -= current_alpha
+
+        if has_sw:
+            dual_coefs *= sw[cp.newaxis, :]
+
+        return dual_coefs.T
+
+
+class KernelRidge(Base, RegressorMixin):
+    """
+    Kernel ridge regression (KRR) performs l2 regularised ridge regression
+    using the kernel trick. The kernel trick allows the estimator to learn a
+    linear function in the space induced by the kernel. This may be a
+    non-linear function in the original feature space (when a non-linear
+    kernel is used).
+    This estimator supports multi-output regression (when y is 2 dimensional).
+    See the sklearn user guide for more information.
+
+    Parameters
+    ----------
+    alpha : float or array-like of shape (n_targets,), default=1.0
+        Regularization strength; must be a positive float. Regularization
+        improves the conditioning of the problem and reduces the variance of
+        the estimates. Larger values specify stronger regularization.
+        If an array is passed, penalties are assumed to be specific
+        to the targets.
+    kernel : str or callable, default="linear"
+        Kernel mapping used internally. This parameter is directly passed to
+        :func:`~cuml.metrics.pairwise_kernels`.
+        If `kernel` is a string, it must be one of the metrics
+        in `cuml.metrics.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed".
+        If `kernel` is "precomputed", X is assumed to be a kernel matrix.
+        `kernel` may also be a callable numba device function. If so, it is
+        called on each pair of instances (rows) and the resulting value is
+        recorded. The callable should take two rows from X as input and
+        return the corresponding kernel value as a single number.
+    gamma : float, default=None
+        Gamma parameter for the RBF, laplacian, polynomial, exponential,
+        chi2 and sigmoid kernels. Interpretation of the default value is
+        left to the kernel; see the documentation for
+        sklearn.metrics.pairwise. Ignored by other kernels.
+    degree : float, default=3
+        Degree of the polynomial kernel. Ignored by other kernels.
+    coef0 : float, default=1
+        Zero coefficient for polynomial and sigmoid kernels.
+        Ignored by other kernels.
+    kernel_params : mapping of str to any, default=None
+        Additional parameters (keyword arguments) for the kernel function
+        passed as a callable object.
+    output_type : {'input', 'cudf', 'cupy', 'numpy', 'numba'}, default=None
+        Variable to control output type of the results and attributes of
+        the estimator. If None, it'll inherit the output type set at the
+        module level, `cuml.global_settings.output_type`.
+        See :ref:`output-data-type-configuration` for more info.
+    handle : cuml.Handle
+        Specifies the cuml.handle that holds internal CUDA state for
+        computations in this model. Most importantly, this specifies the
+        CUDA stream that will be used for the model's computations, so
+        users can run different models concurrently in different streams
+        by creating handles in several streams.
+        If it is None, a new one is created.
+    verbose : int or boolean, default=False
+        Sets logging level. It must be one of `cuml.common.logger.level_*`.
+        See :ref:`verbosity-levels` for more info.
+
+    Attributes
+    ----------
+    dual_coef_ : ndarray of shape (n_samples,) or (n_samples, n_targets)
+        Representation of weight vector(s) in kernel space.
+    X_fit_ : ndarray of shape (n_samples, n_features)
+        Training data, which is also required for prediction. If
+        kernel == "precomputed" this is instead the precomputed
+        training matrix, of shape (n_samples, n_samples).
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import cupy as cp
+        from cuml.kernel_ridge import KernelRidge
+        from numba import cuda
+        import math
+
+        n_samples, n_features = 10, 5
+        rng = cp.random.RandomState(0)
+        y = rng.randn(n_samples)
+        X = rng.randn(n_samples, n_features)
+
+        model = KernelRidge(kernel="poly").fit(X, y)
+        pred = model.predict(X)
+
+
+        @cuda.jit(device=True)
+        def custom_rbf_kernel(x, y, gamma=None):
+            if gamma is None:
+                gamma = 1.0 / len(x)
+            sum = 0.0
+            for i in range(len(x)):
+                sum += (x[i] - y[i]) ** 2
+            return math.exp(-gamma * sum)
+
+
+        model = KernelRidge(kernel=custom_rbf_kernel,
+                            kernel_params={"gamma": 2.0}).fit(X, y)
+        pred = model.predict(X)
+
+    """
+
+    dual_coef_ = CumlArrayDescriptor()
+
+    def __init__(
+        self,
+        *,
+        alpha=1,
+        kernel="linear",
+        gamma=None,
+        degree=3,
+        coef0=1,
+        kernel_params=None,
+        output_type=None,
+        handle=None,
+        verbose=False
+    ):
+        super().__init__(handle=handle, verbose=verbose,
+                         output_type=output_type)
+        self.alpha = cp.asarray(alpha)
+        self.kernel = kernel
+        self.gamma = gamma
+        self.degree = degree
+        self.coef0 = coef0
+        self.kernel_params = kernel_params
+
+    def get_param_names(self):
+        return super().get_param_names() + [
+            "alpha",
+            "kernel",
+            "gamma",
+            "degree",
+            "coef0",
+            "kernel_params",
+        ]
+
+    def _get_kernel(self, X, Y=None):
+        if isinstance(self.kernel, str):
+            params = {"gamma": self.gamma,
+                      "degree": self.degree, "coef0": self.coef0}
+        else:
+            params = self.kernel_params or {}
+        return pairwise_kernels(X, Y, metric=self.kernel,
+                                filter_params=True, **params)
+
+    @generate_docstring()
+    def fit(self, X, y, sample_weight=None,
+            convert_dtype=True) -> "KernelRidge":
+
+        ravel = False
+        if len(y.shape) == 1:
+            y = y.reshape(-1, 1)
+            ravel = True
+
+        X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array(
+            X, check_dtype=[np.float32, np.float64]
+        )
+
+        y_m, _, _, _ = input_to_cuml_array(
+            y,
+            check_dtype=self.dtype,
+            convert_to_dtype=(self.dtype if convert_dtype else None),
+            check_rows=n_rows,
+        )
+
+        if self.n_cols < 1:
+            msg = "X matrix must have at least one column"
+            raise TypeError(msg)
+
+        K = self._get_kernel(X_m)
+        self.dual_coef_ = _solve_cholesky_kernel(
+            K, cp.asarray(y_m), self.alpha, sample_weight
+        )
+
+        if ravel:
+            self.dual_coef_ = self.dual_coef_.ravel()
+        self.X_fit_ = X_m
+        return self
+
+    def predict(self, X):
+        """Predict using the kernel ridge model.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            Samples. If kernel == "precomputed" this is instead a
+            precomputed kernel matrix, shape = [n_samples,
+            n_samples_fitted], where n_samples_fitted is the number of
+            samples used in the fitting for this estimator.
+
+        Returns
+        -------
+        C : array of shape (n_samples,) or (n_samples, n_targets)
+            Returns predicted values.
+        """
+        X_m, _, _, _ = input_to_cuml_array(
+            X, check_dtype=[np.float32, np.float64])
+
+        K = self._get_kernel(X_m, self.X_fit_)
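+        # Predictions are a linear combination of kernel evaluations against
+        # the training rows: y_hat = K(X, X_fit_) @ dual_coef_.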
+        return cp.dot(cp.asarray(K), cp.asarray(self.dual_coef_))
diff --git a/python/cuml/metrics/__init__.py b/python/cuml/metrics/__init__.py
index c90db9c1c7..276d6068af 100644
--- a/python/cuml/metrics/__init__.py
+++ b/python/cuml/metrics/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2021, NVIDIA CORPORATION.
+# Copyright (c) 2019-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,5 +36,7 @@ from cuml.metrics.pairwise_distances import sparse_pairwise_distances
 from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_METRICS
 from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_SPARSE_METRICS
+from cuml.metrics.pairwise_kernels import pairwise_kernels
+from cuml.metrics.pairwise_kernels import PAIRWISE_KERNEL_FUNCTIONS
 from cuml.metrics.hinge_loss import hinge_loss
 from cuml.metrics.kl_divergence import kl_divergence
diff --git a/python/cuml/metrics/pairwise_kernels.py b/python/cuml/metrics/pairwise_kernels.py
new file mode 100644
index 0000000000..ef2d5e2e41
--- /dev/null
+++ b/python/cuml/metrics/pairwise_kernels.py
@@ -0,0 +1,295 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import inspect
+from numba import cuda
+import cupy as cp
+import numpy as np
+import cuml.internals
+from cuml.metrics import pairwise_distances
+
+
+def linear_kernel(X, Y):
+    return cp.dot(X, Y.T)
+
+
+def polynomial_kernel(X, Y, degree=3, gamma=None, coef0=1):
+    if gamma is None:
+        gamma = 1.0 / X.shape[1]
+    K = cp.dot(X, Y.T)
+    K *= gamma
+    K += coef0
+    K **= degree
+    return K
+
+
+def sigmoid_kernel(X, Y, gamma=None, coef0=1):
+    if gamma is None:
+        gamma = 1.0 / X.shape[1]
+
+    K = cp.dot(X, Y.T)
+    K *= gamma
+    K += coef0
+    cp.tanh(K, K)
+    return K
+
+
+def rbf_kernel(X, Y, gamma=None):
+    if gamma is None:
+        gamma = 1.0 / X.shape[1]
+
+    K = cp.asarray(pairwise_distances(X, Y, metric='sqeuclidean'))
+    K *= -gamma
+    cp.exp(K, K)
+    return K
+
+
+def laplacian_kernel(X, Y, gamma=None):
+    if gamma is None:
+        gamma = 1.0 / X.shape[1]
+
+    K = -gamma * cp.asarray(pairwise_distances(X, Y, metric='manhattan'))
+    cp.exp(K, K)
+    return K
+
+
+def cosine_similarity(X, Y):
+    K = 1.0 - cp.asarray(pairwise_distances(X, Y, metric='cosine'))
+    return cp.nan_to_num(K, copy=False)
+
+
+@cuda.jit(device=True)
+def additive_chi2_kernel_element(x, y):
+    res = 0.0
+    for i in range(len(x)):
+        nom = x[i] - y[i]
+        denom = x[i] + y[i]
+        if denom != 0.0:
+            res += nom * nom / denom
+    return -res
+
+
+def additive_chi2_kernel(X, Y):
+    return custom_kernel(X, Y, additive_chi2_kernel_element)
+
+
+def chi2_kernel(X, Y, gamma=1.0):
+    K = additive_chi2_kernel(X, Y)
+    K *= gamma
+    return cp.exp(K, K)
+
+
+PAIRWISE_KERNEL_FUNCTIONS = {
+    "linear": linear_kernel,
+    "additive_chi2": additive_chi2_kernel,
+    "chi2": chi2_kernel,
+    "cosine": cosine_similarity,
+    "laplacian": laplacian_kernel,
+    "polynomial": polynomial_kernel,
+    "poly": polynomial_kernel,
+    "rbf": rbf_kernel,
+    "sigmoid": sigmoid_kernel,
+}
+
+
+def _filter_params(func, filter_params, **kwds):
+    # get all the possible extra function arguments, excluding x, y
+    py_func = func.py_func if hasattr(func, 'py_func') else func
+    all_func_kwargs = list(inspect.signature(py_func).parameters.values())
+    if len(all_func_kwargs) < 2:
+        raise ValueError(
+            "Expected at least two arguments to kernel function.")
+
+    extra_arg_names = set(arg.name for arg in all_func_kwargs[2:])
+    if not filter_params:
+        if not set(kwds.keys()) <= extra_arg_names:
+            raise ValueError(
+                "kwds contains arguments not used by kernel function")
+    return {k: v for k, v in kwds.items() if k in extra_arg_names}
+
+
+def _kwds_to_tuple_args(func, **kwds):
+    # Returns keyword arguments formed as a tuple
+    # (numba kernels cannot deal with kwargs as a dict)
+    if not hasattr(func, "py_func"):
+        raise TypeError("Kernel function should be a numba device function.")
+
+    # get all the possible extra function arguments, excluding x, y
+    all_func_kwargs = list(
+        inspect.signature(func.py_func).parameters.values())
+    if len(all_func_kwargs) < 2:
+        raise ValueError(
+            "Expected at least two arguments to kernel function.")
+
+    all_func_kwargs = all_func_kwargs[2:]
+    if any(p.default is inspect.Parameter.empty for p in all_func_kwargs):
+        raise ValueError(
+            "Extra kernel parameters must be passed as keyword arguments.")
+    all_func_kwargs = [(k.name, k.default) for k in all_func_kwargs]
+
+    kwds_tuple = tuple(
+        kwds[k] if k in kwds.keys() else v for (k, v) in all_func_kwargs
+    )
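+    # For illustration: given a (hypothetical) device function
+    #     def my_kernel(x, y, gamma=1.0, degree=3): ...
+    # and kwds={"degree": 2}, this returns (1.0, 2) -- defaults filled in and
+    # overrides applied, in positional order after (x, y).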
+    return kwds_tuple
+
+
+_kernel_cache = {}
+
+
+def custom_kernel(X, Y, func, **kwds):
+    kwds_tuple = _kwds_to_tuple_args(func, **kwds)
+
+    def evaluate_pairwise_kernels(X, Y, K):
+        idx = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
+        X_m = X.shape[0]
+        Y_m = Y.shape[0]
+        row = idx // Y_m
+        col = idx % Y_m
+        if idx < X_m * Y_m:
+            if X is Y and row <= col:
+                # matrix is symmetric, reuse half the evaluations
+                k = func(X[row], Y[col], *kwds_tuple)
+                K[row, col] = k
+                K[col, row] = k
+            elif X is not Y:
+                k = func(X[row], Y[col], *kwds_tuple)
+                K[row, col] = k
+
+    if Y is None:
+        Y = X
+    if X.shape[1] != Y.shape[1]:
+        raise ValueError("X and Y have different dimensions.")
+
+    # Here we force K to use 64 bit, even if the input is 32 bit;
+    # a 32 bit K results in serious numerical stability problems
+    K = cp.zeros((X.shape[0], Y.shape[0]), dtype=np.float64)
+
+    key = (func, kwds_tuple, X.dtype, Y.dtype)
+    if key in _kernel_cache:
+        compiled_kernel = _kernel_cache[key]
+    else:
+        compiled_kernel = cuda.jit(evaluate_pairwise_kernels)
+        _kernel_cache[key] = compiled_kernel
+    compiled_kernel.forall(X.shape[0] * Y.shape[0])(X, Y, K)
+    return K
+
+
+@cuml.internals.api_return_array(get_output_type=True)
+def pairwise_kernels(X, Y=None, metric="linear", *,
+                     filter_params=False, convert_dtype=True, **kwds):
+    """Compute the kernel between arrays X and optional array Y.
+
+    This method takes either a vector array or a kernel matrix, and returns
+    a kernel matrix. If the input is a vector array, the kernels are
+    computed. If the input is a kernel matrix, it is returned instead.
+    This method provides a safe way to take a kernel matrix as input, while
+    preserving compatibility with many other algorithms that take a vector
+    array.
+    If Y is given (default is None), then the returned matrix is the pairwise
+    kernel between the arrays from both X and Y.
+
+    Valid values for metric are:
+        ['additive_chi2', 'chi2', 'linear', 'poly', 'polynomial', 'rbf',
+        'laplacian', 'sigmoid', 'cosine']
+
+    Parameters
+    ----------
+    X : Dense matrix (device or host) of shape (n_samples_X, n_samples_X) or \
+        (n_samples_X, n_features)
+        Array of pairwise kernels between samples, or a feature array.
+        The shape of the array should be (n_samples_X, n_samples_X) if
+        metric == "precomputed" and (n_samples_X, n_features) otherwise.
+        Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
+        ndarray, CUDA array interface compliant array like CuPy.
+    Y : Dense matrix (device or host) of shape (n_samples_Y, n_features),
+        default=None
+        A second feature array, used only if X has shape
+        (n_samples_X, n_features).
+        Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
+        ndarray, CUDA array interface compliant array like CuPy.
+    metric : str or callable (numba device function), default="linear"
+        The metric to use when calculating the kernel between instances in a
+        feature array.
+        If metric is "precomputed", X is assumed to be a kernel matrix.
+        Alternatively, if metric is a callable function, it is called on
+        each pair of instances (rows) and the resulting value recorded. The
+        callable should take two rows from X as input and return the
+        corresponding kernel value as a single number.
+    filter_params : bool, default=False
+        Whether to filter invalid parameters or not.
+    convert_dtype : bool, optional (default = True)
+        When set to True, the method will, when necessary, convert
+        Y to be the same data type as X if they differ. This
+        will increase memory used for the method.
+    **kwds : optional keyword parameters
+        Any further parameters are passed directly to the kernel function.
+
+    Returns
+    -------
+    K : ndarray of shape (n_samples_X, n_samples_X) or \
+        (n_samples_X, n_samples_Y)
+        A kernel matrix K such that K_{i, j} is the kernel between the
+        ith and jth vectors of the given matrix X, if Y is None.
+        If Y is not None, then K_{i, j} is the kernel between the ith array
+        from X and the jth array from Y.
+
+    Notes
+    -----
+    If metric is 'precomputed', Y is ignored and X is returned.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import cupy as cp
+        from cuml.metrics import pairwise_kernels
+        from numba import cuda
+        import math
+
+        X = cp.array([[2, 3], [3, 5], [5, 8]])
+        Y = cp.array([[1, 0], [2, 1]])
+
+        pairwise_kernels(X, Y, metric='linear')
+
+        @cuda.jit(device=True)
+        def custom_rbf_kernel(x, y, gamma=None):
+            if gamma is None:
+                gamma = 1.0 / len(x)
+            sum = 0.0
+            for i in range(len(x)):
+                sum += (x[i] - y[i]) ** 2
+            return math.exp(-gamma * sum)
+
+        pairwise_kernels(X, Y, metric=custom_rbf_kernel)
+    """
+    X = cp.asarray(X)
+    if Y is None:
+        Y = X
+    else:
+        Y = cp.asarray(Y)
+    if X.shape[1] != Y.shape[1]:
+        raise ValueError("X and Y have different dimensions.")
+
+    if metric == "precomputed":
+        return X
+
+    if metric in PAIRWISE_KERNEL_FUNCTIONS:
+        kwds = _filter_params(
+            PAIRWISE_KERNEL_FUNCTIONS[metric], filter_params, **kwds)
+        return PAIRWISE_KERNEL_FUNCTIONS[metric](X, Y, **kwds)
+    elif isinstance(metric, str):
+        raise ValueError("Unknown kernel %r" % metric)
+    else:
+        kwds = _filter_params(metric, filter_params, **kwds)
+        return custom_kernel(X, Y, metric, **kwds)
diff --git a/python/cuml/test/test_kernel_ridge.py b/python/cuml/test/test_kernel_ridge.py
new file mode 100644
index 0000000000..b9d08bac55
--- /dev/null
+++ b/python/cuml/test/test_kernel_ridge.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import cupy as cp
+from cupy import linalg
+import numpy as np
+from numba import cuda
+from cuml import KernelRidge as cuKernelRidge
+from cuml.metrics import pairwise_kernels, PAIRWISE_KERNEL_FUNCTIONS
+from sklearn.metrics.pairwise import pairwise_kernels as skl_pairwise_kernels
+import pytest
+import math
+import inspect
+from sklearn.kernel_ridge import KernelRidge as sklKernelRidge
+from hypothesis import given, settings, assume, strategies as st
+from hypothesis.extra.numpy import arrays
+
+
+def gradient_norm(X, y, model, K, sw=None):
+    if sw is None:
+        sw = cp.ones(X.shape[0])
+    else:
+        sw = cp.atleast_1d(cp.array(sw, dtype=np.float64))
+
+    X = cp.array(X, dtype=np.float64)
+    y = cp.array(y, dtype=np.float64)
+    K = cp.array(K, dtype=np.float64)
+    betas = cp.array(model.dual_coef_, dtype=np.float64).reshape(y.shape)
+
+    # initialise to NaN in case the loop below has 0 iterations
+    grads = cp.full_like(y, np.NAN)
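+    # Stationarity check: for each target the weighted KRR objective is
+    #   L(beta) = (K beta - y)^T W (K beta - y) + alpha * beta^T K beta,
+    # whose gradient (up to a factor of 2) is
+    #   -K W y + K W K beta + alpha K beta,
+    # which should be ~0 at the optimum. Note (K * sw) broadcasts sw across
+    # the columns of K, i.e. it forms K @ diag(sw).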
+ ): + pairwise_kernels(X, metric=bad_numba_kernel2) + + # Precomputed + assert np.allclose(X, pairwise_kernels(X, metric="precomputed")) + + +@cuda.jit(device=True) +def custom_kernel(x, y, custom_arg=5.0): + sum = 0.0 + for i in range(len(x)): + sum += (x[i] - y[i]) ** 2 + return math.exp(-custom_arg * sum) + 0.1 + + +test_kernels = sorted( + PAIRWISE_KERNEL_FUNCTIONS.keys()) + [custom_kernel] + + +@st.composite +def kernel_arg_strategy(draw): + kernel = draw(st.sampled_from(test_kernels)) + kernel_func = ( + PAIRWISE_KERNEL_FUNCTIONS[kernel] if isinstance( + kernel, str) else kernel + ) + # Inspect the function and generate some arguments + py_func = kernel_func.py_func if hasattr( + kernel_func, 'py_func') else kernel_func + all_func_kwargs = list( + inspect.signature(py_func).parameters.values())[ + 2: + ] + param = {} + for arg in all_func_kwargs: + # 50% chance we generate this parameter or leave it as default + if draw(st.booleans()): + continue + if isinstance(arg.default, float) or arg.default is None: + param[arg.name] = draw(st.floats(0.0, 5.0)) + if isinstance(arg.default, int): + param[arg.name] = draw(st.integers(1, 5)) + + return (kernel, param) + + +@st.composite +def array_strategy(draw): + X_m = draw(st.integers(1, 20)) + X_n = draw(st.integers(1, 10)) + dtype = draw(st.sampled_from([np.float64, np.float32])) + X = draw(arrays(dtype=dtype, shape=(X_m, X_n), + elements=st.floats(0, 5, width=32),)) + if draw(st.booleans()): + Y_m = draw(st.integers(1, 20)) + Y = draw( + arrays(dtype=dtype, shape=(Y_m, X_n), + elements=st.floats(0, 5, width=32),) + ) + else: + Y = None + return (X, Y) + + +@given(kernel_arg_strategy(), array_strategy()) +@settings(deadline=None) +def test_pairwise_kernels(kernel_arg, XY): + X, Y = XY + kernel, args = kernel_arg + K = pairwise_kernels(X, Y, metric=kernel, **args) + skl_kernel = kernel.py_func if hasattr(kernel, "py_func") else kernel + K_sklearn = skl_pairwise_kernels(X, Y, metric=skl_kernel, **args) + assert np.allclose(K, K_sklearn, atol=0.01, rtol=0.01) + + +@st.composite +def estimator_array_strategy(draw): + X_m = draw(st.integers(5, 20)) + X_n = draw(st.integers(2, 10)) + dtype = draw(st.sampled_from([np.float64, np.float32])) + rs = np.random.RandomState(draw(st.integers(1, 10))) + X = rs.rand(X_m, X_n).astype(dtype) + X_test = rs.rand(draw(st.integers(5, 20)), X_n).astype(dtype) + + n_targets = draw(st.integers(1, 3)) + a = draw( + arrays(dtype=dtype, shape=(X_n, n_targets), + elements=st.floats(0, 5, width=32),) + ) + y = X.dot(a) + + alpha = draw( + arrays(dtype=dtype, shape=(n_targets), + elements=st.floats(0, 5, width=32)) + ) + + sample_weight = draw( + st.one_of( + [ + st.just(None), + st.floats(0.1, 1.5), + arrays(dtype=np.float64, shape=X_m, + elements=st.floats(0.1, 5)), + ] + ) + ) + return (X, y, X_test, alpha, sample_weight) + + +@given( + kernel_arg_strategy(), + estimator_array_strategy(), + st.floats(1.0, 5.0), + st.integers(1, 5), + st.floats(1.0, 5.0), +) +@settings(deadline=None) +def test_estimator(kernel_arg, arrays, gamma, degree, coef0): + kernel, args = kernel_arg + X, y, X_test, alpha, sample_weight = arrays + model = cuKernelRidge( + kernel=kernel, + alpha=alpha, + gamma=gamma, + degree=degree, + coef0=coef0, + kernel_params=args, + ) + skl_kernel = kernel.py_func if hasattr(kernel, "py_func") else kernel + skl_model = sklKernelRidge( + kernel=skl_kernel, + alpha=alpha, + gamma=gamma, + degree=degree, + coef0=coef0, + kernel_params=args, + ) + if kernel == "chi2" or kernel == "additive_chi2": + # X must be 
+        X = X + abs(X.min()) + 1.0
+
+    model.fit(X, y, sample_weight)
+    pred = model.predict(X_test).get()
+    if X.dtype == np.float64:
+        # For a convex optimisation problem we should arrive at gradient
+        # norm 0 if the solution has converged correctly
+        K = model._get_kernel(X)
+        grad_norm = gradient_norm(X, y, model, K, sample_weight)
+        assert grad_norm < 0.1
+    try:
+        skl_model.fit(X, y, sample_weight)
+    except np.linalg.LinAlgError:
+        # sklearn can fail to fit multiclass models
+        # with singular kernel matrices
+        assume(False)
+
+    skl_pred = skl_model.predict(X_test)
+    assert np.allclose(pred, skl_pred, atol=1e-2, rtol=1e-2)
+
+
+def test_precomputed():
+    rs = np.random.RandomState(23)
+    X = rs.normal(size=(10, 10))
+    y = rs.normal(size=10)
+    K = pairwise_kernels(X)
+    precomputed_model = cuKernelRidge(kernel="precomputed")
+    precomputed_model.fit(K, y)
+    model = cuKernelRidge()
+    model.fit(X, y)
+    assert np.allclose(precomputed_model.dual_coef_, model.dual_coef_)
+    assert np.allclose(
+        precomputed_model.predict(K), model.predict(X), atol=1e-5, rtol=1e-5
+    )