Kernel ridge regression #4492
Merged
Changes from 13 commits
19 commits
11114f6 Basic working implementation (RAMitchell)
47ec1cf Parameter dispatch (RAMitchell)
9067a37 Implement prediction. (RAMitchell)
1f8ecac Cache jit compile (RAMitchell)
a7ec29c Numerical stability (RAMitchell)
272d2f8 Sample weights (RAMitchell)
9ab20c8 Remove deadline (RAMitchell)
675fce6 Cythonise, move files (RAMitchell)
7057929 Docs (RAMitchell)
e7a75db Lint (RAMitchell)
c3bd2c4 Merge branch 'branch-22.04' of https://github.com/rapidsai/cuml into … (RAMitchell)
7f970bf Lint (RAMitchell)
2c3f249 Copyright (RAMitchell)
483e164 Fix predict (RAMitchell)
f245462 Pass base tests (RAMitchell)
a6dc27d Change kernel implementations (RAMitchell)
bfab33b Don't test float32 gradient - too unstable. (RAMitchell)
b9b1346 Copyright (RAMitchell)
c9955f0 Review comments (RAMitchell)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
|
||
from cuml.kernel_ridge.kernel_ridge import KernelRidge |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
# | ||
# Copyright (c) 2019-2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# distutils: language = c++ | ||
|
||
import numpy as np | ||
import warnings | ||
from cupy import linalg | ||
import cupy as cp | ||
from cupyx import lapack, geterr, seterr | ||
from cuml.common.array_descriptor import CumlArrayDescriptor | ||
from cuml.common.base import Base | ||
from cuml.common.mixins import RegressorMixin | ||
from cuml.common.doc_utils import generate_docstring | ||
from cuml.common import input_to_cuml_array | ||
|
||
from cuml.metrics import pairwise_kernels | ||
|
||
|
||
# cholesky solve with fallback to least squares for singular problems | ||
def _safe_solve(K, y): | ||
try: | ||
# we need to set the error mode of cupy to raise | ||
# otherwise we silently get an array of NaNs | ||
err_mode = geterr()["linalg"] | ||
seterr(linalg="raise") | ||
dual_coef = lapack.posv(K, y) | ||
seterr(linalg=err_mode) | ||
except np.linalg.LinAlgError: | ||
warnings.warn( | ||
"Singular matrix in solving dual problem. Using " | ||
"least-squares solution instead." | ||
) | ||
dual_coef = linalg.lstsq(K, y, rcond=None)[0] | ||
return dual_coef | ||
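The fallback pattern in `_safe_solve` can be tried on the CPU with a minimal NumPy sketch. This is an illustration only, not the code under review: `safe_solve` is a hypothetical name, `np.linalg.cholesky` stands in for `lapack.posv`, and no error-mode juggling is needed because NumPy raises `LinAlgError` by default.

```python
import warnings

import numpy as np


def safe_solve(K, y):
    """Solve K @ coef = y via Cholesky, falling back to least squares.

    Hypothetical CPU analogue of the helper in the diff.
    """
    try:
        # Raises LinAlgError when K is not positive definite.
        L = np.linalg.cholesky(K)
        # K = L @ L.T, so solve L z = y and then L.T coef = z.
        # A general solver is used here for brevity instead of a
        # dedicated triangular solver.
        z = np.linalg.solve(L, y)
        return np.linalg.solve(L.T, z)
    except np.linalg.LinAlgError:
        warnings.warn(
            "Singular matrix in solving dual problem. Using "
            "least-squares solution instead."
        )
        return np.linalg.lstsq(K, y, rcond=None)[0]


K_pd = np.array([[4.0, 1.0], [1.0, 3.0]])   # positive definite
K_bad = np.array([[1.0, 2.0], [2.0, 1.0]])  # symmetric but indefinite
y = np.array([[1.0], [2.0]])

coef_pd = safe_solve(K_pd, y)    # Cholesky path
coef_bad = safe_solve(K_bad, y)  # emits the warning, least-squares path
```

One observation on the CuPy version above: if `lapack.posv` raises, the `seterr(linalg=err_mode)` line is skipped, so restoring the mode in a `finally` clause may be worth considering.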
|
||
|
||
def _solve_cholesky_kernel(K, y, alpha, sample_weight=None): | ||
# dual_coef = inv(X X^t + alpha*Id) y | ||
n_samples = K.shape[0] | ||
n_targets = y.shape[1] | ||
|
||
K = cp.array(K, dtype=np.float64) | ||
|
||
alpha = cp.atleast_1d(alpha) | ||
one_alpha = alpha.size == 1 | ||
has_sw = sample_weight is not None | ||
|
||
if has_sw: | ||
# Unlike other solvers, we need to support sample_weight directly | ||
# because K might be a pre-computed kernel. | ||
sw = cp.sqrt(cp.atleast_1d(sample_weight)) | ||
y = y * sw[:, cp.newaxis] | ||
K *= cp.outer(sw, sw) | ||
|
||
if one_alpha: | ||
# Only one penalty, so all targets can be solved in a single call. | ||
K.flat[:: n_samples + 1] += alpha[0] | ||
|
||
dual_coef = _safe_solve(K, y) | ||
|
||
if has_sw: | ||
dual_coef *= sw[:, cp.newaxis] | ||
|
||
return dual_coef | ||
else: | ||
# One penalty per target. We need to solve each target separately. | ||
dual_coefs = cp.empty([n_targets, n_samples], K.dtype) | ||
|
||
for dual_coef, target, current_alpha in zip(dual_coefs, y.T, alpha): | ||
K.flat[:: n_samples + 1] += current_alpha | ||
|
||
dual_coef[:] = _safe_solve(K, target).ravel() | ||
|
||
K.flat[:: n_samples + 1] -= current_alpha | ||
|
||
if has_sw: | ||
dual_coefs *= sw[cp.newaxis, :] | ||
|
||
return dual_coefs.T | ||
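The sample-weight handling in `_solve_cholesky_kernel` rescales `K` and `y` by the square roots of the weights, solves, and rescales the result. A small NumPy check (a sketch with made-up data, not part of the PR) suggests this is equivalent to solving `(K + alpha * inv(W)) dual_coef = y` directly, where `W` is the diagonal matrix of sample weights. All variable names below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
A = rng.standard_normal((n, n))
K = A @ A.T                          # symmetric positive semi-definite "kernel"
y = rng.standard_normal((n, 1))
w = rng.uniform(0.5, 2.0, size=n)    # hypothetical sample weights
alpha = 0.5

# Rescaling trick, mirroring the steps in the diff.
sw = np.sqrt(w)
K_s = K * np.outer(sw, sw)           # S K S with S = diag(sqrt(w))
K_s.flat[:: n + 1] += alpha          # add alpha to the diagonal in place
dual_scaled = np.linalg.solve(K_s, y * sw[:, None]) * sw[:, None]

# Direct weighted formulation for comparison.
dual_direct = np.linalg.solve(K + alpha * np.diag(1.0 / w), y)

print(np.allclose(dual_scaled, dual_direct))
```

The strided assignment `K.flat[:: n + 1]` touches exactly the diagonal entries of a square matrix, which is why the per-target loop can add and then subtract each `current_alpha` without copying `K`.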
|
||
|
||
class KernelRidge(Base, RegressorMixin): | ||
"""Kernel ridge regression. | ||
Kernel ridge regression (KRR) performs l2 regularised ridge regression | ||
using the kernel trick. The kernel trick allows the estimator to learn a | ||
linear function in the space induced by the kernel. This may be a | ||
non-linear function in the original feature space (when a non-linear | ||
kernel is used). | ||
This estimator supports multi-output regression (when y is 2 dimensional). | ||
See the sklearn user guide for more information. | ||
|
||
Parameters | ||
---------- | ||
alpha : float or array-like of shape (n_targets,), default=1.0 | ||
Regularization strength; must be a positive float. Regularization | ||
improves the conditioning of the problem and reduces the variance of | ||
the estimates. Larger values specify stronger regularization. | ||
If an array is passed, penalties are assumed to be specific | ||
to the targets. | ||
kernel : str or callable, default="linear" | ||
Kernel mapping used internally. This parameter is directly passed to | ||
:func:`~cuml.metrics.pairwise_kernels`. | ||
If `kernel` is a string, it must be one of the metrics | ||
in `cuml.metrics.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed". | ||
If `kernel` is "precomputed", X is assumed to be a kernel matrix. | ||
`kernel` may be a callable numba device function. If so, it is called on | ||
each pair of instances (rows) and the resulting value recorded. The | ||
callable should take two rows from X as input and return the | ||
corresponding kernel value as a single number. | ||
gamma : float, default=None | ||
Gamma parameter for the RBF, laplacian, polynomial, exponential chi2 | ||
and sigmoid kernels. Interpretation of the default value is left to | ||
the kernel; see the documentation for sklearn.metrics.pairwise. | ||
Ignored by other kernels. | ||
degree : float, default=3 | ||
Degree of the polynomial kernel. Ignored by other kernels. | ||
coef0 : float, default=1 | ||
Zero coefficient for polynomial and sigmoid kernels. | ||
Ignored by other kernels. | ||
kernel_params : mapping of str to any, default=None | ||
Additional parameters (keyword arguments) for kernel function passed | ||
as callable object. | ||
Attributes | ||
---------- | ||
dual_coef_ : ndarray of shape (n_samples,) or (n_samples, n_targets) | ||
Representation of weight vector(s) in kernel space | ||
X_fit_ : ndarray of shape (n_samples, n_features) | ||
Training data, which is also required for prediction. If | ||
kernel == "precomputed" this is instead the precomputed | ||
training matrix, of shape (n_samples, n_samples). | ||
|
||
Examples | ||
-------- | ||
|
||
.. code-block:: python | ||
|
||
import cupy as cp | ||
from cuml.kernel_ridge import KernelRidge | ||
from numba import cuda | ||
import math | ||
|
||
n_samples, n_features = 10, 5 | ||
rng = cp.random.RandomState(0) | ||
y = rng.randn(n_samples) | ||
X = rng.randn(n_samples, n_features) | ||
|
||
model = KernelRidge(kernel="poly").fit(X, y) | ||
pred = model.predict(X) | ||
|
||
|
||
@cuda.jit(device=True) | ||
def custom_rbf_kernel(x, y, gamma=None): | ||
if gamma is None: | ||
gamma = 1.0 / len(x) | ||
sum = 0.0 | ||
for i in range(len(x)): | ||
sum += (x[i] - y[i]) ** 2 | ||
return math.exp(-gamma * sum) | ||
|
||
|
||
model = KernelRidge(kernel=custom_rbf_kernel, | ||
kernel_params={"gamma": 2.0}).fit(X, y) | ||
pred = model.predict(X) | ||
|
||
""" | ||
|
||
dual_coef_ = CumlArrayDescriptor() | ||
|
||
def __init__( | ||
self, | ||
*, | ||
alpha=1, | ||
kernel="linear", | ||
gamma=None, | ||
degree=3, | ||
coef0=1, | ||
kernel_params=None, | ||
handle=None, | ||
output_type=None, | ||
verbose=False | ||
): | ||
super().__init__(handle=handle, verbose=verbose, | ||
output_type=output_type) | ||
self.alpha = cp.asarray(alpha) | ||
self.kernel = kernel | ||
self.gamma = gamma | ||
self.degree = degree | ||
self.coef0 = coef0 | ||
self.kernel_params = kernel_params | ||
|
||
def _get_kernel(self, X, Y=None): | ||
if isinstance(self.kernel, str): | ||
params = {"gamma": self.gamma, | ||
"degree": self.degree, "coef0": self.coef0} | ||
else: | ||
params = self.kernel_params or {} | ||
return pairwise_kernels(X, Y, metric=self.kernel, | ||
filter_params=True, **params) | ||
|
||
@generate_docstring() | ||
def fit(self, X, y, sample_weight=None, | ||
convert_dtype=True) -> "KernelRidge": | ||
|
||
ravel = False | ||
if len(y.shape) == 1: | ||
y = y.reshape(-1, 1) | ||
ravel = True | ||
|
||
X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array( | ||
X, check_dtype=[np.float32, np.float64] | ||
) | ||
|
||
y_m, _, _, _ = input_to_cuml_array( | ||
y, | ||
check_dtype=self.dtype, | ||
convert_to_dtype=(self.dtype if convert_dtype else None), | ||
check_rows=n_rows, | ||
) | ||
|
||
if self.n_cols < 1: | ||
msg = "X matrix must have at least one column" | ||
raise TypeError(msg) | ||
|
||
# Y defaults to None, so this is the training kernel K(X, X) | ||
K = self._get_kernel(X_m) | ||
self.dual_coef_ = _solve_cholesky_kernel( | ||
K, cp.asarray(y_m), self.alpha, sample_weight | ||
) | ||
|
||
if ravel: | ||
self.dual_coef_ = self.dual_coef_.ravel() | ||
self.X_fit_ = X_m | ||
return self | ||
|
||
def predict(self, X): | ||
"""Predict using the kernel ridge model. | ||
Parameters | ||
---------- | ||
X : array-like of shape (n_samples, n_features) | ||
Samples. If kernel == "precomputed" this is instead a | ||
precomputed kernel matrix, shape = [n_samples, | ||
n_samples_fitted], where n_samples_fitted is the number of | ||
samples used in the fitting for this estimator. | ||
Returns | ||
------- | ||
C : array of shape (n_samples,) or (n_samples, n_targets) | ||
Returns predicted values. | ||
""" | ||
X_m, _, _, _ = input_to_cuml_array( | ||
X, check_dtype=[np.float32, np.float64]) | ||
|
||
K = self._get_kernel(X_m, self.X_fit_) | ||
return cp.dot(cp.asarray(K), cp.asarray(self.dual_coef_)) |
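For readers following `fit` and `predict`, the underlying computation can be sketched in plain NumPy. This is a CPU illustration under simplifying assumptions (single penalty, no sample weights), and the helper names `rbf_kernel`, `krr_fit` and `krr_predict` are hypothetical, not cuML API. The point it shows is that prediction needs the cross kernel between the new rows and the rows stored in `X_fit_`.

```python
import numpy as np


def rbf_kernel(A, B, gamma):
    # Squared Euclidean distance between every row of A and every row of B.
    sq = (A ** 2).sum(1)[:, None] + (B ** 2).sum(1)[None, :] - 2.0 * A @ B.T
    return np.exp(-gamma * np.maximum(sq, 0.0))


def krr_fit(K_train, y, alpha):
    # dual_coef = inv(K + alpha * Id) y
    n = K_train.shape[0]
    return np.linalg.solve(K_train + alpha * np.eye(n), y)


def krr_predict(K_cross, dual_coef):
    # K_cross has shape (n_new, n_fit): kernel of new rows against training rows.
    return K_cross @ dual_coef


rng = np.random.default_rng(0)
X_fit = rng.standard_normal((10, 5))
y = rng.standard_normal(10)
X_new = rng.standard_normal((4, 5))

dual_coef = krr_fit(rbf_kernel(X_fit, X_fit, gamma=0.2), y, alpha=1.0)
pred = krr_predict(rbf_kernel(X_new, X_fit, gamma=0.2), dual_coef)
print(pred.shape)  # (4,)
```

With a linear kernel the same two helpers reproduce ordinary ridge regression, which gives a convenient sanity check for any implementation of the dual solve.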
Just noticed this: we should remove 2019 from the copyright header, since this is a new file.