LinearSVM using QN solvers #4268

Merged
merged 55 commits
Nov 17, 2021
Commits
a7e9756
[WIP] Linear SVM via QN solver
achirkin Sep 30, 2021
f9d0b51
Implemented SVRXLoss gradient kernels
achirkin Sep 30, 2021
7454ac5
Merge branch 'branch-21.12' into fea-linear-svm
achirkin Oct 4, 2021
e2a29d3
Allow a stateful functor struct in a __device__ loss function
achirkin Oct 5, 2021
f99cdf5
Minimal SVR/SVC wrappers
achirkin Oct 5, 2021
30c244a
Proper pickling and param constraints for SVC/SVR
achirkin Oct 6, 2021
3589cbc
Refactor LinearSVM to support the standard ML hierarchy api
achirkin Oct 12, 2021
3b09274
[wip] added probabilistic and multiclass (ovr) fit/predict
achirkin Oct 14, 2021
26d9141
Fixed multiclass support bugs - seems to be working correctly now.
achirkin Oct 18, 2021
b6a5bb0
Probabilistic calibration - almost done.
achirkin Oct 19, 2021
1c0cdc6
WIP refactor the lifetime of C++-side model and arrays
achirkin Oct 23, 2021
b699e47
Merge branch 'branch-21.12' into fea-linear-svm
achirkin Oct 23, 2021
922be91
Remove printfs
achirkin Oct 23, 2021
46014d2
Make sure the estimator behaves well under pickling/cloning
achirkin Oct 23, 2021
e6ad562
Disable probability scale for now
achirkin Oct 23, 2021
06a8a77
Some cleanup and documentation
achirkin Oct 23, 2021
04b3826
Split svm/linear into three files
achirkin Oct 25, 2021
2bdba98
Remove unnecessary default in LinearSVC
achirkin Oct 25, 2021
7453d18
Remove an unnecessary step in the loss function
achirkin Nov 1, 2021
298e087
Initialize QN solver with zeros, because more complex preconditioners…
achirkin Nov 1, 2021
2fd4f6c
Remove unnecessary *args
achirkin Nov 1, 2021
2d7aa08
Add LinearSVM tests
achirkin Nov 1, 2021
01cf01c
Introduce an enum to represent QN loss types
achirkin Nov 1, 2021
7ee2c26
Add a comment on glm_base complications
achirkin Nov 1, 2021
15bf1d5
Add argument comments
achirkin Nov 2, 2021
8ec9c83
Add python tests
achirkin Nov 3, 2021
8dd4225
Make base doctests happy with LinearSVM
achirkin Nov 4, 2021
0052419
Add comments regarding the location of default parameter values
achirkin Nov 4, 2021
da40a82
Try to reduce the test time
achirkin Nov 4, 2021
36a747e
Merge branch 'branch-21.12' into fea-linear-svm
achirkin Nov 5, 2021
6793c0e
Use raft helpers in place of raw cuda
achirkin Nov 5, 2021
0e2bcd2
Add some documentation to the C++ code
achirkin Nov 5, 2021
bf36b0f
Test >32 classes
achirkin Nov 5, 2021
ec4d3dc
Run classification training in parallel in OVR setting
achirkin Nov 5, 2021
aef1e7a
Add proper support of probabilistic output
achirkin Nov 6, 2021
b5cab0c
Refactor for the stateless api
achirkin Nov 7, 2021
f161e92
Merge branch 'branch-21.12' into fea-linear-svm
achirkin Nov 7, 2021
92af99e
Remove verbose output from a test
achirkin Nov 7, 2021
bf82e0b
Use OpenMP instead of C++ threads
achirkin Nov 7, 2021
7806c2a
Update glm_svm.cuh
achirkin Nov 7, 2021
41a91e0
Update linear.pyx
achirkin Nov 7, 2021
e250e1e
Use size_t for input dimensions
achirkin Nov 10, 2021
40d02c9
Force glm/QN loss functions to be functors
achirkin Nov 10, 2021
49a8e27
Merge branch 'branch-21.12' into fea-linear-svm
achirkin Nov 15, 2021
53ecbad
Address more cpp comments
achirkin Nov 15, 2021
4b8a5d2
Address more python comments
achirkin Nov 15, 2021
8c8fc0c
Adapt to changes in raft
achirkin Nov 15, 2021
43f4e53
Even more docs and some test fixes
achirkin Nov 15, 2021
8bbfe89
A few more cosmetic changes.
achirkin Nov 16, 2021
3735548
Handling of multi_class and multiclass-strategy
achirkin Nov 16, 2021
3ba8a8f
Terminate sklearn if it takes too long to run in the tests
achirkin Nov 16, 2021
f2d6e65
Log time and force stream sync in tests
achirkin Nov 17, 2021
41b993b
Manage streams and openmp
achirkin Nov 17, 2021
6c34a02
Fix sphinx docs
achirkin Nov 17, 2021
ea0606c
Run sklearn in the current process if process forking is not available.
achirkin Nov 17, 2021
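
Several of the commits above (ec4d3dc, bf82e0b, 41b993b) deal with fitting the one-vs-rest (OVR) binary sub-problems concurrently, one CUDA stream per OpenMP worker. The PR's actual implementation lives in the C++ sources below; the following is only a rough, hypothetical sketch of that pattern, with fit_binary_classifier standing in as a placeholder for one per-class QN solve:

#include <algorithm>
#include <cstddef>
#include <vector>
#include <cuda_runtime.h>
#include <omp.h>

// Placeholder for one class-vs-rest binary solve on the given stream (not the PR's API).
void fit_binary_classifier(std::size_t classIdx, cudaStream_t stream)
{
  (void)classIdx;
  (void)stream;  // a single QN fit for class `classIdx` would be launched on `stream` here
}

void fit_ovr(std::size_t nClasses, int maxWorkers)
{
  // One stream per worker, so independent per-class solves can overlap on the GPU.
  int nWorkers = std::min<int>(maxWorkers, static_cast<int>(nClasses));
  std::vector<cudaStream_t> streams(nWorkers);
  for (auto& s : streams)
    cudaStreamCreate(&s);

#pragma omp parallel for num_threads(nWorkers)
  for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(nClasses); i++) {
    cudaStream_t s = streams[omp_get_thread_num()];
    fit_binary_classifier(static_cast<std::size_t>(i), s);
    cudaStreamSynchronize(s);  // each worker finishes its class before taking the next one
  }

  for (auto& s : streams)
    cudaStreamDestroy(s);
}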
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -309,6 +309,7 @@ if(BUILD_CUML_CPP_LIBRARY)
src/spectral/spectral.cu
src/svm/svc.cu
src/svm/svr.cu
src/svm/linear.cu
src/svm/ws_util.cu
src/tsa/auto_arima.cu
src/tsa/stationarity.cu
188 changes: 188 additions & 0 deletions cpp/include/cuml/svm/linear.hpp
@@ -0,0 +1,188 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/handle.hpp>

namespace ML {
namespace SVM {

struct LinearSVMParams {
/** The regularization term. */
enum Penalty {
/** Abs. value of the weights: `sum |w|` */
L1,
/** Squared value of the weights: `sum w^2` */
L2
};
/** The loss function. */
enum Loss {
/** `max(1 - y_i x_i w, 0)` */
HINGE,
/** `max(1 - y_i x_i w, 0)^2` */
SQUARED_HINGE,
/** `max(|y_i - x_i w| - epsilon, 0)` */
EPSILON_INSENSITIVE,
/** `max(|y_i - x_i w| - epsilon, 0)^2` */
SQUARED_EPSILON_INSENSITIVE
};

/** The regularization term. */
Penalty penalty = L2;
/** The loss function. */
Loss loss = HINGE;
/** Whether to fit the bias term. */
bool fit_intercept = true;
/** When true, the bias term is treated the same way as other data features.
* Enabling this feature forces an extra copy of the input data X.
*/
bool penalized_intercept = false;
/** Whether to estimate probabilities using Platt scaling (applicable to SVC). */
bool probability = false;
/** Maximum number of iterations for the underlying QN solver. */
int max_iter = 1000;
/**
* Maximum number of linesearch (inner loop) iterations for the underlying QN solver.
*/
int linesearch_max_iter = 100;
/**
* Number of vectors approximating the hessian for the underlying QN solver (l-bfgs).
*/
int lbfgs_memory = 5;
/** Triggers extra output when greater than zero. */
int verbose = 0;
/**
* The constant scaling factor of the main term in the loss function.
* (You can also think of it as the inverse scaling factor of the penalty term.)
*/
double C = 1.0;
/** The threshold on the gradient for the underlying QN solver. */
double grad_tol = 0.0001;
/** The threshold on the function change for the underlying QN solver. */
double change_tol = 0.00001;
/** The epsilon-sensitivity parameter (applicable to the SVM-regression (SVR) loss functions). */
double epsilon = 0.0;
};

template <typename T>
struct LinearSVMModel {
/**
* C-style (row-major) matrix of coefficients of size `(coefRows, coefCols)`
* where
* coefRows = nCols + (params.fit_intercept ? 1 : 0)
* coefCols = nClasses == 2 ? 1 : nClasses
*/
T* w;
/** Sorted, unique values of input array `y`. */
T* classes = nullptr;
/**
* C-style (row-major) matrix of the probabilistic model calibration coefficients.
* It's empty if `LinearSVMParams.probability == false`.
* Otherwise, its size is `(2, coefCols)`
* where
* coefCols = nClasses == 2 ? 1 : nClasses
*/
T* probScale = nullptr;
/** Number of classes (not applicable for regression). */
std::size_t nClasses = 0;
/** Number of rows of `w`, which is the number of data features plus maybe bias. */
std::size_t coefRows;

/** It's 1 for binary classification or regression; nClasses for multiclass. */
inline std::size_t coefCols() const { return nClasses <= 2 ? 1 : nClasses; }

/**
* @brief Allocate and fit the LinearSVM model.
*
* @param [in] handle the cuML handle.
* @param [in] params the model parameters.
* @param [in] X the input data matrix of size (nRows, nCols) in column-major format.
* @param [in] nRows the number of input samples.
* @param [in] nCols the number of feature dimensions.
* @param [in] y the target - a single vector of either real (regression) or
* categorical (classification) values (nRows, ).
* @param [in] sampleWeight the non-negative weights for the training sample (nRows, ).
* @return the trained model (don't forget to call `free` on it after use).
*/
static LinearSVMModel<T> fit(const raft::handle_t& handle,
const LinearSVMParams& params,
const T* X,
const std::size_t nRows,
const std::size_t nCols,
const T* y,
const T* sampleWeight);

/**
* @brief Explicitly allocate the data for the model without training it.
*
* @param [in] handle the cuML handle.
* @param [in] params the model parameters.
* @param [in] nCols the number of feature dimensions.
* @param [in] nClasses the number of classes in the dataset (not applicable for regression).
* @return the allocated (untrained) model (don't forget to call `free` on it after use).
*/
static LinearSVMModel<T> allocate(const raft::handle_t& handle,
const LinearSVMParams& params,
const std::size_t nCols,
const std::size_t nClasses = 0);

/** @brief Free the allocated memory. The model is not usable after the call of this method. */
static void free(const raft::handle_t& handle, LinearSVMModel<T>& model);

/**
* @brief Predict using the trained LinearSVM model.
*
* @param [in] handle the cuML handle.
* @param [in] params the model parameters.
* @param [in] model the trained model.
* @param [in] X the input data matrix of size (nRows, nCols) in column-major format.
* @param [in] nRows the number of input samples.
* @param [in] nCols the number of feature dimensions.
* @param [out] out the predictions (nRows, ).
*/
static void predict(const raft::handle_t& handle,
const LinearSVMParams& params,
const LinearSVMModel<T>& model,
const T* X,
const std::size_t nRows,
const std::size_t nCols,
T* out);

/**
* @brief For SVC, predict the probabilities for each outcome.
*
* @param [in] handle the cuML handle.
* @param [in] params the model parameters.
* @param [in] model the trained model.
* @param [in] X the input data matrix of size (nRows, nCols) in column-major format.
* @param [in] nRows the number of input samples.
* @param [in] nCols the number of feature dimensions.
* @param [in] log whether to output log-probabilities instead of probabilities.
* @param [out] out the estimated probabilities (nRows, nClasses) in row-major format.
*/
static void predictProba(const raft::handle_t& handle,
const LinearSVMParams& params,
const LinearSVMModel<T>& model,
const T* X,
const std::size_t nRows,
const std::size_t nCols,
const bool log,
T* out);
};

} // namespace SVM
} // namespace ML
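
For orientation, here is a minimal usage sketch of the API declared in this header. It is not part of the diff; the handle setup and the device buffers X, y, preds are assumed to exist already:

#include <cstddef>
#include <cuml/svm/linear.hpp>
#include <raft/handle.hpp>

void linear_svc_example(const raft::handle_t& handle,
                        const float* X,  // device, column-major, (nRows, nCols)
                        const float* y,  // device, class labels, (nRows,)
                        std::size_t nRows,
                        std::size_t nCols,
                        float* preds)  // device, (nRows,)
{
  ML::SVM::LinearSVMParams params;
  params.loss          = ML::SVM::LinearSVMParams::SQUARED_HINGE;
  params.penalty       = ML::SVM::LinearSVMParams::L2;
  params.fit_intercept = true;
  params.C             = 1.0;

  // fit() allocates the model's device buffers; free() releases them.
  auto model = ML::SVM::LinearSVMModel<float>::fit(
    handle, params, X, nRows, nCols, y, /* sampleWeight = */ nullptr);
  ML::SVM::LinearSVMModel<float>::predict(handle, params, model, X, nRows, nCols, preds);
  ML::SVM::LinearSVMModel<float>::free(handle, model);
}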
24 changes: 12 additions & 12 deletions cpp/src/glm/glm.cu
@@ -195,7 +195,7 @@ void qnFit(const raft::handle_t& cuml_handle,
w0,
f,
num_iters,
loss_type,
(QN_LOSS_TYPE)loss_type,
cuml_handle.get_stream(),
sample_weight);
}
@@ -241,7 +241,7 @@ void qnFit(const raft::handle_t& cuml_handle,
w0,
f,
num_iters,
loss_type,
(QN_LOSS_TYPE)loss_type,
cuml_handle.get_stream(),
sample_weight);
}
@@ -291,7 +291,7 @@ void qnFitSparse(const raft::handle_t& cuml_handle,
w0,
f,
num_iters,
loss_type,
(QN_LOSS_TYPE)loss_type,
cuml_handle.get_stream(),
sample_weight);
}
@@ -341,7 +341,7 @@ void qnFitSparse(const raft::handle_t& cuml_handle,
w0,
f,
num_iters,
loss_type,
(QN_LOSS_TYPE)loss_type,
cuml_handle.get_stream(),
sample_weight);
}
@@ -365,7 +365,7 @@ void qnDecisionFunction(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
preds,
cuml_handle.get_stream());
}
@@ -389,7 +389,7 @@ void qnDecisionFunction(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
scores,
cuml_handle.get_stream());
}
@@ -417,7 +417,7 @@ void qnDecisionFunctionSparse(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
scores,
cuml_handle.get_stream());
}
@@ -445,7 +445,7 @@ void qnDecisionFunctionSparse(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
scores,
cuml_handle.get_stream());
}
@@ -469,7 +469,7 @@ void qnPredict(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
scores,
cuml_handle.get_stream());
}
@@ -493,7 +493,7 @@ void qnPredict(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
preds,
cuml_handle.get_stream());
}
@@ -521,7 +521,7 @@ void qnPredictSparse(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
preds,
cuml_handle.get_stream());
}
@@ -549,7 +549,7 @@ void qnPredictSparse(const raft::handle_t& cuml_handle,
C,
fit_intercept,
params,
loss_type,
(QN_LOSS_TYPE)loss_type,
preds,
cuml_handle.get_stream());
}
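
The only change in this file is the repeated cast of the public int loss_type argument to the QN_LOSS_TYPE enum introduced in commit 01cf01c ("Introduce an enum to represent QN loss types"). The enum's actual definition is not shown in this diff; the snippet below is only an illustrative sketch of the int-to-enum dispatch pattern at the C-like API boundary, not the PR's definitions:

// Illustrative sketch only: the real enumerator names and values live in the QN headers.
enum QN_LOSS_TYPE_SKETCH {
  LOSS_LOGISTIC    = 0,
  LOSS_SQUARED_SVC = 1,
  LOSS_UNKNOWN     = 99,
};

inline QN_LOSS_TYPE_SKETCH to_loss_type(int loss_type)
{
  // Map the untyped API argument onto a named loss kind before dispatching
  // to the templated QN machinery; unknown values are rejected explicitly.
  switch (loss_type) {
    case 0: return LOSS_LOGISTIC;
    case 1: return LOSS_SQUARED_SVC;
    default: return LOSS_UNKNOWN;
  }
}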
51 changes: 35 additions & 16 deletions cpp/src/glm/qn/glm_base.cuh
@@ -128,36 +128,55 @@ struct GLMBase : GLMDims {
* 2. loss_val <- sum loss(Z)
*
* Default: elementwise application of loss and its derivative
*
* NB: for this method to work, loss implementations must have two functor fields `lz` and `dlz`.
* These two compute loss value and its derivative w.r.t. `z`.
*/
inline void getLossAndDZ(T* loss_val,
SimpleDenseMat<T>& Z,
const SimpleVec<T>& y,
cudaStream_t stream)
{
// Base impl assumes simple case C = 1
Loss* loss = static_cast<Loss*>(this);

// TODO would be nice to have a kernel that fuses these two steps
// This would be easy if mapThenSumReduce allowed outputting the result of
// map (supporting inplace)
auto lz_copy = static_cast<Loss*>(this)->lz;
auto dlz_copy = static_cast<Loss*>(this)->dlz;
if (this->sample_weights) { // Sample weights are in use
T normalization = 1.0 / this->weights_sum;
auto f_l = [=] __device__(const T y, const T z, const T weight) {
return loss->lz(y, z) * (weight * normalization);
};
raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data, Z.data, sample_weights);

auto f_dl = [=] __device__(const T y, const T z, const T weight) {
return weight * loss->dlz(y, z);
};
raft::linalg::map(Z.data, y.len, f_dl, stream, y.data, Z.data, sample_weights);
raft::linalg::mapThenSumReduce(
loss_val,
y.len,
[lz_copy, normalization] __device__(const T y, const T z, const T weight) {
return lz_copy(y, z) * (weight * normalization);
},
stream,
y.data,
Z.data,
sample_weights);
raft::linalg::map(
Z.data,
y.len,
[dlz_copy] __device__(const T y, const T z, const T weight) {
return weight * dlz_copy(y, z);
},
stream,
y.data,
Z.data,
sample_weights);
} else { // Sample weights are not used
T normalization = 1.0 / y.len;
auto f_l = [=] __device__(const T y, const T z) { return loss->lz(y, z) * normalization; };
raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data, Z.data);

auto f_dl = [=] __device__(const T y, const T z) { return loss->dlz(y, z); };
raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, f_dl, stream);
raft::linalg::mapThenSumReduce(
loss_val,
y.len,
[lz_copy, normalization] __device__(const T y, const T z) {
return lz_copy(y, z) * normalization;
},
stream,
y.data,
Z.data);
raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, dlz_copy, stream);
}
}
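
Per the new note above, loss implementations now expose two plain functor members, lz (loss value) and dlz (its derivative w.r.t. z), which getLossAndDZ copies by value into its device lambdas instead of dereferencing `this` on the device. A simplified, illustrative sketch of such a loss type (the real ones additionally derive from GLMBase and carry dimension and weight state):

// Illustrative sketch, not the PR's code: only the lz/dlz contract is shown.
template <typename T>
struct HingeLossSketch {
  // loss value: max(1 - y * z, 0)
  struct Lz {
    inline __device__ T operator()(const T y, const T z) const
    {
      T s = 1 - y * z;
      return s > 0 ? s : T(0);
    }
  } lz;

  // (sub)derivative of the loss w.r.t. z: -y where the hinge is active, 0 otherwise
  struct Dlz {
    inline __device__ T operator()(const T y, const T z) const { return y * z < 1 ? -y : T(0); }
  } dlz;
};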
