Kernel ridge regression (#4492)
scikit-learn reference implementation: https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09b/sklearn/kernel_ridge.py#L16

I've tried to avoid touching the C++/CUDA layer so far. Pairwise kernels are implemented with a Numba kernel for now, and I've used CuPy's LAPACK wrapper to access cuSolver.
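
Roughly, the solve follows a Cholesky-with-fallback pattern (the actual code is `_safe_solve` in the diff below; this standalone sketch with a hypothetical helper name just restates it):

```python
import numpy as np
import cupy as cp
from cupyx import lapack, geterr, seterr


def cholesky_solve_with_fallback(K, y):
    """Solve K x = y via cuSolver's Cholesky routine (posv), falling
    back to least squares when K is singular. Sketch only."""
    err_mode = geterr()["linalg"]
    # Raise on singular matrices instead of silently returning NaNs.
    seterr(linalg="raise")
    try:
        return lapack.posv(K, y)
    except np.linalg.LinAlgError:
        return cp.linalg.lstsq(K, y, rcond=None)[0]
    finally:
        seterr(linalg=err_mode)
```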

The implementation of `pairwise_kernels` here could also be reused to implement kernel PCA very easily.
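
A rough sketch of what that could look like (hypothetical helper; the centering and eigendecomposition steps are illustrative, not part of this PR):

```python
import cupy as cp
from cuml.metrics import pairwise_kernels


def kernel_pca_sketch(X, n_components=2, metric="rbf", gamma=1.0):
    """Hypothetical kernel PCA built on pairwise_kernels."""
    K = cp.asarray(pairwise_kernels(X, metric=metric, gamma=gamma))
    n = K.shape[0]
    # Center the kernel matrix in feature space.
    one_n = cp.full((n, n), 1.0 / n)
    K = K - one_n @ K - K @ one_n + one_n @ K @ one_n
    # eigh returns eigenvalues in ascending order; take the largest.
    eigvals, eigvecs = cp.linalg.eigh(K)
    eigvals = eigvals[::-1][:n_components]
    eigvecs = eigvecs[:, ::-1][:, :n_components]
    # Projection of the training data onto the principal components.
    return eigvecs * cp.sqrt(eigvals)
```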

Todo:
- [x] Single target fit/predict
- [x] Standard kernels implemented
- [x] Support custom kernels
- [x] Support sample weights
- [ ] ~~Support CSR X matrix. Maybe too difficult for this PR.~~
- [x] Multi-target fit/predict (see the usage sketch after this list)
- [x] Change .py files to .pyx and move them to the correct places.
- [x] Benchmarking on reasonably large files
- [x] Tests take less than 20s
- [x] Ensure correct handling of input/output array types (I think I need to be using CumlArray and maybe some decorators)
- [x] Documentation
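
As a usage sketch of the multi-target path with one penalty per target (illustrative data and values, not from the diff):

```python
import cupy as cp
from cuml.kernel_ridge import KernelRidge

# Illustrative data: 50 samples, 4 features, 2 regression targets.
X = cp.random.rand(50, 4)
y = cp.random.rand(50, 2)

# One ridge penalty per target; the solver adds each alpha to the
# kernel diagonal and solves that target's system separately.
model = KernelRidge(alpha=[0.1, 1.0], kernel="rbf", gamma=0.5).fit(X, y)
pred = model.predict(X)  # shape (50, 2)
```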

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Micka (https://github.com/lowener)

URL: #4492
RAMitchell authored Feb 9, 2022
1 parent 9b52960 commit 9921c61
Showing 7 changed files with 887 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -29,6 +29,7 @@ log
.DS_Store
dask-worker-space/
tmp/
.hypothesis

## files pickled in notebook when ran during python docstring generation
docs/source/*.model
4 changes: 3 additions & 1 deletion python/cuml/__init__.py
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -45,6 +45,8 @@
from cuml.internals.global_settings import (
    GlobalSettings, _global_settings_data)

from cuml.kernel_ridge.kernel_ridge import KernelRidge

from cuml.linear_model.elastic_net import ElasticNet
from cuml.linear_model.lasso import Lasso
from cuml.linear_model.linear_regression import LinearRegression
18 changes: 18 additions & 0 deletions python/cuml/kernel_ridge/__init__.py
@@ -0,0 +1,18 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


from cuml.kernel_ridge.kernel_ridge import KernelRidge
291 changes: 291 additions & 0 deletions python/cuml/kernel_ridge/kernel_ridge.pyx
@@ -0,0 +1,291 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# distutils: language = c++

import numpy as np
import warnings
from cupy import linalg
import cupy as cp
from cupyx import lapack, geterr, seterr
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.common.base import Base
from cuml.common.mixins import RegressorMixin
from cuml.common.doc_utils import generate_docstring
from cuml.common import input_to_cuml_array

from cuml.metrics import pairwise_kernels


# Cholesky solve with fallback to least squares for singular problems
def _safe_solve(K, y):
    try:
        # we need to set the error mode of cupy to raise,
        # otherwise we silently get an array of NaNs
        err_mode = geterr()["linalg"]
        seterr(linalg="raise")
        dual_coef = lapack.posv(K, y)
        seterr(linalg=err_mode)
    except np.linalg.LinAlgError:
        warnings.warn(
            "Singular matrix in solving dual problem. Using "
            "least-squares solution instead."
        )
        dual_coef = linalg.lstsq(K, y, rcond=None)[0]
    return dual_coef


def _solve_cholesky_kernel(K, y, alpha, sample_weight=None):
    # dual_coef = inv(X X^t + alpha*Id) y
    n_samples = K.shape[0]
    n_targets = y.shape[1]

    K = cp.array(K, dtype=np.float64)

    alpha = cp.atleast_1d(alpha)
    one_alpha = alpha.size == 1
    has_sw = sample_weight is not None

    if has_sw:
        # Unlike other solvers, we need to support sample_weight directly
        # because K might be a pre-computed kernel.
        sw = cp.sqrt(cp.atleast_1d(sample_weight))
        y = y * sw[:, cp.newaxis]
        K *= cp.outer(sw, sw)

    if one_alpha:
        # Only one penalty: all targets can be solved in a single call.
        # Add alpha to the diagonal of K in place.
        K.flat[:: n_samples + 1] += alpha[0]

        dual_coef = _safe_solve(K, y)

        if has_sw:
            dual_coef *= sw[:, cp.newaxis]

        return dual_coef
    else:
        # One penalty per target. We need to solve each target separately.
        dual_coefs = cp.empty([n_targets, n_samples], K.dtype)

        for dual_coef, target, current_alpha in zip(dual_coefs, y.T, alpha):
            # Add this target's penalty to the diagonal, solve, then
            # restore K for the next target.
            K.flat[:: n_samples + 1] += current_alpha

            dual_coef[:] = _safe_solve(K, target).ravel()

            K.flat[:: n_samples + 1] -= current_alpha

        if has_sw:
            dual_coefs *= sw[cp.newaxis, :]

        return dual_coefs.T


class KernelRidge(Base, RegressorMixin):
    """
    Kernel ridge regression (KRR) performs l2 regularised ridge regression
    using the kernel trick. The kernel trick allows the estimator to learn a
    linear function in the space induced by the kernel. This may be a
    non-linear function in the original feature space (when a non-linear
    kernel is used).
    This estimator supports multi-output regression (when y is 2-dimensional).
    See the sklearn user guide for more information.

    Parameters
    ----------
    alpha : float or array-like of shape (n_targets,), default=1.0
        Regularization strength; must be a positive float. Regularization
        improves the conditioning of the problem and reduces the variance of
        the estimates. Larger values specify stronger regularization.
        If an array is passed, penalties are assumed to be specific
        to the targets.
    kernel : str or callable, default="linear"
        Kernel mapping used internally. This parameter is directly passed to
        :func:`~cuml.metrics.pairwise_kernels`.
        If `kernel` is a string, it must be one of the metrics
        in `cuml.metrics.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed".
        If `kernel` is "precomputed", X is assumed to be a kernel matrix.
        `kernel` may also be a callable Numba device function. If so, it is
        called on each pair of instances (rows) and the resulting value
        recorded. The callable should take two rows from X as input and
        return the corresponding kernel value as a single number.
    gamma : float, default=None
        Gamma parameter for the RBF, laplacian, polynomial, exponential,
        chi2 and sigmoid kernels. Interpretation of the default value is
        left to the kernel; see the documentation for
        sklearn.metrics.pairwise. Ignored by other kernels.
    degree : float, default=3
        Degree of the polynomial kernel. Ignored by other kernels.
    coef0 : float, default=1
        Zero coefficient for polynomial and sigmoid kernels.
        Ignored by other kernels.
    kernel_params : mapping of str to any, default=None
        Additional parameters (keyword arguments) for the kernel function
        passed as a callable object.
    output_type : {'input', 'cudf', 'cupy', 'numpy', 'numba'}, default=None
        Variable to control output type of the results and attributes of
        the estimator. If None, it'll inherit the output type set at the
        module level, `cuml.global_settings.output_type`.
        See :ref:`output-data-type-configuration` for more info.
    handle : cuml.Handle
        Specifies the cuml.handle that holds internal CUDA state for
        computations in this model. Most importantly, this specifies the
        CUDA stream that will be used for the model's computations, so
        users can run different models concurrently in different streams
        by creating handles in several streams.
        If it is None, a new one is created.
    verbose : int or boolean, default=False
        Sets logging level. It must be one of `cuml.common.logger.level_*`.
        See :ref:`verbosity-levels` for more info.

    Attributes
    ----------
    dual_coef_ : ndarray of shape (n_samples,) or (n_samples, n_targets)
        Representation of weight vector(s) in kernel space.
    X_fit_ : ndarray of shape (n_samples, n_features)
        Training data, which is also required for prediction. If
        kernel == "precomputed" this is instead the precomputed
        training matrix, of shape (n_samples, n_samples).

    Examples
    --------

    .. code-block:: python

        import cupy as cp
        from cuml.kernel_ridge import KernelRidge
        from numba import cuda
        import math

        n_samples, n_features = 10, 5
        rng = cp.random.RandomState(0)
        y = rng.randn(n_samples)
        X = rng.randn(n_samples, n_features)

        model = KernelRidge(kernel="poly").fit(X, y)
        pred = model.predict(X)

        @cuda.jit(device=True)
        def custom_rbf_kernel(x, y, gamma=None):
            if gamma is None:
                gamma = 1.0 / len(x)
            sum = 0.0
            for i in range(len(x)):
                sum += (x[i] - y[i]) ** 2
            return math.exp(-gamma * sum)

        model = KernelRidge(kernel=custom_rbf_kernel,
                            kernel_params={"gamma": 2.0}).fit(X, y)
        pred = model.predict(X)
    """

    dual_coef_ = CumlArrayDescriptor()

    def __init__(
        self,
        *,
        alpha=1,
        kernel="linear",
        gamma=None,
        degree=3,
        coef0=1,
        kernel_params=None,
        output_type=None,
        handle=None,
        verbose=False
    ):
        super().__init__(handle=handle, verbose=verbose,
                         output_type=output_type)
        self.alpha = cp.asarray(alpha)
        self.kernel = kernel
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
        self.kernel_params = kernel_params

    def get_param_names(self):
        return super().get_param_names() + [
            "alpha",
            "kernel",
            "gamma",
            "degree",
            "coef0",
            "kernel_params",
        ]

    def _get_kernel(self, X, Y=None):
        if isinstance(self.kernel, str):
            params = {"gamma": self.gamma,
                      "degree": self.degree, "coef0": self.coef0}
        else:
            params = self.kernel_params or {}
        return pairwise_kernels(X, Y, metric=self.kernel,
                                filter_params=True, **params)

    @generate_docstring()
    def fit(self, X, y, sample_weight=None,
            convert_dtype=True) -> "KernelRidge":

        ravel = False
        if len(y.shape) == 1:
            y = y.reshape(-1, 1)
            ravel = True

        X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array(
            X, check_dtype=[np.float32, np.float64]
        )

        y_m, _, _, _ = input_to_cuml_array(
            y,
            check_dtype=self.dtype,
            convert_to_dtype=(self.dtype if convert_dtype else None),
            check_rows=n_rows,
        )

        if self.n_cols < 1:
            msg = "X matrix must have at least one column"
            raise TypeError(msg)

        K = self._get_kernel(X_m)
        self.dual_coef_ = _solve_cholesky_kernel(
            K, cp.asarray(y_m), self.alpha, sample_weight
        )

        if ravel:
            self.dual_coef_ = self.dual_coef_.ravel()
        self.X_fit_ = X_m
        return self

    def predict(self, X):
        """Predict using the kernel ridge model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples. If kernel == "precomputed" this is instead a
            precomputed kernel matrix, shape = [n_samples,
            n_samples_fitted], where n_samples_fitted is the number of
            samples used in the fitting for this estimator.

        Returns
        -------
        C : array of shape (n_samples,) or (n_samples, n_targets)
            Returns predicted values.
        """
        X_m, _, _, _ = input_to_cuml_array(
            X, check_dtype=[np.float32, np.float64])

        K = self._get_kernel(X_m, self.X_fit_)
        return cp.dot(cp.asarray(K), cp.asarray(self.dual_coef_))
4 changes: 3 additions & 1 deletion python/cuml/metrics/__init__.py
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,5 +36,7 @@
from cuml.metrics.pairwise_distances import sparse_pairwise_distances
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_METRICS
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_SPARSE_METRICS
from cuml.metrics.pairwise_kernels import pairwise_kernels
from cuml.metrics.pairwise_kernels import PAIRWISE_KERNEL_FUNCTIONS
from cuml.metrics.hinge_loss import hinge_loss
from cuml.metrics.kl_divergence import kl_divergence