Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kernel ridge regression #4492

Merged
merged 19 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ log
.DS_Store
dask-worker-space/
tmp/
.hypothesis

## files pickled in notebook when ran during python docstring generation
docs/source/*.model
Expand Down
4 changes: 3 additions & 1 deletion python/cuml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -45,6 +45,8 @@
from cuml.internals.global_settings import (
GlobalSettings, _global_settings_data)

from cuml.kernel_ridge.kernel_ridge import KernelRidge

from cuml.linear_model.elastic_net import ElasticNet
from cuml.linear_model.lasso import Lasso
from cuml.linear_model.linear_regression import LinearRegression
Expand Down
18 changes: 18 additions & 0 deletions python/cuml/kernel_ridge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from cuml.kernel_ridge.kernel_ridge import KernelRidge
266 changes: 266 additions & 0 deletions python/cuml/kernel_ridge/kernel_ridge.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noticed this — we should remove 2019 since this is a new file.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# distutils: language = c++

import numpy as np
import warnings
from cupy import linalg
import cupy as cp
from cupyx import lapack, geterr, seterr
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.common.base import Base
from cuml.common.mixins import RegressorMixin
from cuml.common.doc_utils import generate_docstring
from cuml.common import input_to_cuml_array

from cuml.metrics import pairwise_kernels


def _safe_solve(K, y):
    """Cholesky solve with a least-squares fallback for singular problems.

    Parameters
    ----------
    K : array
        Left-hand-side matrix handed to ``lapack.posv``.
    y : array
        Right-hand side(s).

    Returns
    -------
    dual_coef : array
        Result of ``lapack.posv(K, y)``; if that raises
        ``np.linalg.LinAlgError``, the first element of
        ``linalg.lstsq(K, y, rcond=None)`` instead (a warning is emitted).
    """
    # We need to set the error mode of cupy to raise, otherwise we silently
    # get an array of NaNs. The previous mode is remembered so it can be
    # restored whatever the outcome of the solve.
    previous_mode = geterr()["linalg"]
    seterr(linalg="raise")
    try:
        return lapack.posv(K, y)
    except np.linalg.LinAlgError:
        warnings.warn(
            "Singular matrix in solving dual problem. Using "
            "least-squares solution instead."
        )
    finally:
        # Runs on success, on the fallback path and on unexpected errors, so
        # the process-wide cupy error mode is never left at "raise".
        seterr(linalg=previous_mode)
    return linalg.lstsq(K, y, rcond=None)[0]


def _solve_cholesky_kernel(K, y, alpha, sample_weight=None):
    """Solve the dual ridge problem ``(K + alpha * Id) dual_coef = y``.

    Parameters
    ----------
    K : array of shape (n_samples, n_samples)
        Kernel matrix. It is copied to a float64 cupy array before being
        modified, so the argument itself is left untouched.
    y : array of shape (n_samples, n_targets)
        Targets; must be two dimensional.
    alpha : scalar or array of shape (n_targets,)
        Penalty. A single value is shared by every target, otherwise one
        value per target is expected.
    sample_weight : array of shape (n_samples,), optional
        Per-sample weights, applied through their square roots to both
        ``K`` and ``y``.

    Returns
    -------
    dual_coef : array of shape (n_samples, n_targets)

    Raises
    ------
    ValueError
        If several penalties are given and their number differs from the
        number of targets.
    """
    # dual_coef = inv(X X^t + alpha*Id) y
    n_samples = K.shape[0]
    n_targets = y.shape[1]

    K = cp.array(K, dtype=np.float64)

    alpha = cp.atleast_1d(alpha)
    one_alpha = alpha.size == 1
    has_sw = sample_weight is not None

    if not one_alpha and alpha.size != n_targets:
        # Without this check the zip() below silently truncates and rows of
        # the (uninitialised) result buffer could be returned as-is.
        raise ValueError(
            "Number of targets and number of penalties do not correspond: "
            "%d != %d" % (n_targets, alpha.size)
        )

    if has_sw:
        # Unlike other solvers, we need to support sample_weight directly
        # because K might be a pre-computed kernel.
        sw = cp.sqrt(cp.atleast_1d(sample_weight))
        y = y * sw[:, cp.newaxis]
        K *= cp.outer(sw, sw)

    if one_alpha:
        # Only one penalty, we can solve multi-target problems in one time.
        # A stride of n_samples + 1 over the flat view walks the diagonal.
        K.flat[:: n_samples + 1] += alpha[0]

        dual_coef = _safe_solve(K, y)

        if has_sw:
            dual_coef *= sw[:, cp.newaxis]

        return dual_coef

    # One penalty per target. We need to solve each target separately.
    dual_coefs = cp.empty([n_targets, n_samples], K.dtype)

    for dual_coef, target, current_alpha in zip(dual_coefs, y.T, alpha):
        # Add the penalty to the diagonal, solve, then take it off again so
        # the same K buffer can be reused for the next target.
        K.flat[:: n_samples + 1] += current_alpha

        dual_coef[:] = _safe_solve(K, target).ravel()

        K.flat[:: n_samples + 1] -= current_alpha

    if has_sw:
        dual_coefs *= sw[cp.newaxis, :]

    return dual_coefs.T


class KernelRidge(Base, RegressorMixin):
    """Kernel ridge regression.

    Kernel ridge regression (KRR) performs l2 regularised ridge regression
    using the kernel trick. The kernel trick allows the estimator to learn a
    linear function in the space induced by the kernel. This may be a
    non-linear function in the original feature space (when a non-linear
    kernel is used).
    This estimator supports multi-output regression (when y is 2 dimensional).
    See the sklearn user guide for more information.

    Parameters
    ----------
    alpha : float or array-like of shape (n_targets,), default=1.0
        Regularization strength; must be a positive float. Regularization
        improves the conditioning of the problem and reduces the variance of
        the estimates. Larger values specify stronger regularization.
        If an array is passed, penalties are assumed to be specific
        to the targets.
    kernel : str or callable, default="linear"
        Kernel mapping used internally. This parameter is directly passed to
        :func:`~cuml.metrics.pairwise_kernels`.
        If `kernel` is a string, it must be one of the metrics
        in `cuml.metrics.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed".
        If `kernel` is "precomputed", X is assumed to be a kernel matrix.
        `kernel` may be a callable numba device function. If so, it is called
        on each pair of instances (rows) and the resulting value recorded.
        The callable should take two rows from X as input and return the
        corresponding kernel value as a single number.
    gamma : float, default=None
        Gamma parameter for the RBF, laplacian, polynomial, exponential chi2
        and sigmoid kernels. Interpretation of the default value is left to
        the kernel; see the documentation for sklearn.metrics.pairwise.
        Ignored by other kernels.
    degree : float, default=3
        Degree of the polynomial kernel. Ignored by other kernels.
    coef0 : float, default=1
        Zero coefficient for polynomial and sigmoid kernels.
        Ignored by other kernels.
    kernel_params : mapping of str to any, default=None
        Additional parameters (keyword arguments) for kernel function passed
        as callable object.
    handle, output_type, verbose
        Forwarded unchanged to the ``Base`` constructor.

    Attributes
    ----------
    dual_coef_ : ndarray of shape (n_samples,) or (n_samples, n_targets)
        Representation of weight vector(s) in kernel space
    X_fit_ : ndarray of shape (n_samples, n_features)
        Training data, which is also required for prediction. If
        kernel == "precomputed" this is instead the precomputed
        training matrix, of shape (n_samples, n_samples).

    Examples
    --------

    .. code-block:: python

        import cupy as cp
        from cuml.kernel_ridge import KernelRidge
        from numba import cuda
        import math

        n_samples, n_features = 10, 5
        rng = cp.random.RandomState(0)
        y = rng.randn(n_samples)
        X = rng.randn(n_samples, n_features)

        model = KernelRidge(kernel="poly").fit(X, y)
        pred = model.predict(X)


        @cuda.jit(device=True)
        def custom_rbf_kernel(x, y, gamma=None):
            if gamma is None:
                gamma = 1.0 / len(x)
            sum = 0.0
            for i in range(len(x)):
                sum += (x[i] - y[i]) ** 2
            return math.exp(-gamma * sum)


        model = KernelRidge(kernel=custom_rbf_kernel,
                            kernel_params={"gamma": 2.0}).fit(X, y)
        pred = model.predict(X)

    """

    dual_coef_ = CumlArrayDescriptor()

    def __init__(
        self,
        *,
        alpha=1,
        kernel="linear",
        gamma=None,
        degree=3,
        coef0=1,
        kernel_params=None,
        handle=None,
        output_type=None,
        verbose=False
    ):
        super().__init__(handle=handle, verbose=verbose,
                         output_type=output_type)
        self.alpha = cp.asarray(alpha)
        self.kernel = kernel
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
        self.kernel_params = kernel_params

    def _get_kernel(self, X, Y=None):
        """Evaluate the configured kernel between the rows of X and Y.

        Y is forwarded to ``pairwise_kernels`` as given; ``fit`` leaves it as
        None and ``predict`` passes the stored training data.
        """
        if isinstance(self.kernel, str):
            # Named kernels take their hyper-parameters from the estimator;
            # filter_params=True is passed so pairwise_kernels can drop the
            # ones the chosen kernel does not use.
            params = {"gamma": self.gamma,
                      "degree": self.degree, "coef0": self.coef0}
        else:
            # Callable kernels only receive the user supplied keyword args
            params = self.kernel_params or {}
        # Y must be forwarded: without it the kernel would always be computed
        # between X and itself, which is wrong at prediction time.
        return pairwise_kernels(X, Y, metric=self.kernel,
                                filter_params=True, **params)

    @generate_docstring()
    def fit(self, X, y, sample_weight=None,
            convert_dtype=True) -> "KernelRidge":

        # A 1D target is solved as a single-column problem; remember to
        # flatten the coefficients back afterwards.
        ravel = False
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
            ravel = True

        X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array(
            X, check_dtype=[np.float32, np.float64]
        )

        y_m, _, _, _ = input_to_cuml_array(
            y,
            check_dtype=self.dtype,
            convert_to_dtype=(self.dtype if convert_dtype else None),
            check_rows=n_rows,
        )

        if self.n_cols < 1:
            msg = "X matrix must have at least a column"
            raise TypeError(msg)

        # Kernel of the training data against itself (no Y argument)
        K = self._get_kernel(X_m)
        self.dual_coef_ = _solve_cholesky_kernel(
            K, cp.asarray(y_m), self.alpha, sample_weight
        )

        if ravel:
            self.dual_coef_ = self.dual_coef_.ravel()
        # Kept because predict() evaluates the kernel against it
        self.X_fit_ = X_m
        return self

    def predict(self, X):
        """Predict using the kernel ridge model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples. If kernel == "precomputed" this is instead a
            precomputed kernel matrix, shape = [n_samples,
            n_samples_fitted], where n_samples_fitted is the number of
            samples used in the fitting for this estimator.

        Returns
        -------
        C : array of shape (n_samples,) or (n_samples, n_targets)
            Returns predicted values.
        """
        X_m, _, _, _ = input_to_cuml_array(
            X, check_dtype=[np.float32, np.float64])

        # Kernel between the query samples and the stored training samples
        K = self._get_kernel(X_m, self.X_fit_)
        return cp.dot(cp.asarray(K), cp.asarray(self.dual_coef_))
4 changes: 3 additions & 1 deletion python/cuml/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,5 +36,7 @@
from cuml.metrics.pairwise_distances import sparse_pairwise_distances
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_METRICS
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_SPARSE_METRICS
from cuml.metrics.pairwise_kernels import pairwise_kernels
from cuml.metrics.pairwise_kernels import PAIRWISE_KERNEL_FUNCTIONS
from cuml.metrics.hinge_loss import hinge_loss
from cuml.metrics.kl_divergence import kl_divergence
Loading