Skip to content

Commit

Permalink
Add detail namespace for linear models (#5107)
Browse files Browse the repository at this point in the history
Related to #5084.

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5107
  • Loading branch information
lowener authored Mar 29, 2023
1 parent 453ae3b commit 75dd597
Show file tree
Hide file tree
Showing 16 changed files with 179 additions and 218 deletions.
14 changes: 10 additions & 4 deletions cpp/include/cuml/linear_model/glm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ void gemmPredict(const raft::handle_t& handle,
* overwritten by final result.
* @param f host pointer holding the final objective value
* @param num_iters host pointer holding the actual number of iterations taken
* @param sample_weight
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
 * for uniform weights)
* @param svr_eps epsilon parameter for svr
*/
template <typename T, typename I = int>
void qnFit(const raft::handle_t& cuml_handle,
Expand All @@ -163,7 +165,8 @@ void qnFit(const raft::handle_t& cuml_handle,
T* w0,
T* f,
int* num_iters,
T* sample_weight = nullptr);
T* sample_weight = nullptr,
T svr_eps = 0);

/**
* @brief Fit a GLM using quasi newton methods.
Expand All @@ -183,7 +186,9 @@ void qnFit(const raft::handle_t& cuml_handle,
* overwritten by final result.
* @param f host pointer holding the final objective value
* @param num_iters host pointer holding the actual number of iterations taken
* @param sample_weight
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
 * for uniform weights)
* @param svr_eps epsilon parameter for svr
*/
template <typename T, typename I = int>
void qnFitSparse(const raft::handle_t& cuml_handle,
Expand All @@ -199,7 +204,8 @@ void qnFitSparse(const raft::handle_t& cuml_handle,
T* w0,
T* f,
int* num_iters,
T* sample_weight = nullptr);
T* sample_weight = nullptr,
T svr_eps = 0);

/**
* @brief Obtain the confidence scores of samples
Expand Down
201 changes: 85 additions & 116 deletions cpp/src/glm/glm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,17 @@ void olsFit(const raft::handle_t& handle,
int algo,
float* sample_weight)
{
olsFit(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
handle.get_stream(),
algo,
sample_weight);
detail::olsFit(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
algo,
sample_weight);
}

void olsFit(const raft::handle_t& handle,
Expand All @@ -64,18 +63,17 @@ void olsFit(const raft::handle_t& handle,
int algo,
double* sample_weight)
{
olsFit(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
handle.get_stream(),
algo,
sample_weight);
detail::olsFit(handle,
input,
n_rows,
n_cols,
labels,
coef,
intercept,
fit_intercept,
normalize,
algo,
sample_weight);
}

void gemmPredict(const raft::handle_t& handle,
Expand All @@ -86,7 +84,7 @@ void gemmPredict(const raft::handle_t& handle,
float intercept,
float* preds)
{
gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds, handle.get_stream());
detail::gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds);
}

void gemmPredict(const raft::handle_t& handle,
Expand All @@ -97,7 +95,7 @@ void gemmPredict(const raft::handle_t& handle,
double intercept,
double* preds)
{
gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds, handle.get_stream());
detail::gemmPredict(handle, input, n_rows, n_cols, coef, intercept, preds);
}

void ridgeFit(const raft::handle_t& handle,
Expand All @@ -114,20 +112,19 @@ void ridgeFit(const raft::handle_t& handle,
int algo,
float* sample_weight)
{
ridgeFit(handle,
input,
n_rows,
n_cols,
labels,
alpha,
n_alpha,
coef,
intercept,
fit_intercept,
normalize,
handle.get_stream(),
algo,
sample_weight);
detail::ridgeFit(handle,
input,
n_rows,
n_cols,
labels,
alpha,
n_alpha,
coef,
intercept,
fit_intercept,
normalize,
algo,
sample_weight);
}

void ridgeFit(const raft::handle_t& handle,
Expand All @@ -144,20 +141,19 @@ void ridgeFit(const raft::handle_t& handle,
int algo,
double* sample_weight)
{
ridgeFit(handle,
input,
n_rows,
n_cols,
labels,
alpha,
n_alpha,
coef,
intercept,
fit_intercept,
normalize,
handle.get_stream(),
algo,
sample_weight);
detail::ridgeFit(handle,
input,
n_rows,
n_cols,
labels,
alpha,
n_alpha,
coef,
intercept,
fit_intercept,
normalize,
algo,
sample_weight);
}

template <typename T, typename I>
Expand All @@ -172,21 +168,11 @@ void qnFit(const raft::handle_t& cuml_handle,
T* w0,
T* f,
int* num_iters,
T* sample_weight)
T* sample_weight,
T svr_eps)
{
qnFit<T>(cuml_handle,
pams,
X,
X_col_major,
y,
N,
D,
C,
w0,
f,
num_iters,
cuml_handle.get_stream(),
sample_weight);
detail::qnFit<T>(
cuml_handle, pams, X, X_col_major, y, N, D, C, w0, f, num_iters, sample_weight, svr_eps);
}

template void qnFit<float>(const raft::handle_t&,
Expand All @@ -200,7 +186,8 @@ template void qnFit<float>(const raft::handle_t&,
float*,
float*,
int*,
float*);
float*,
float);
template void qnFit<double>(const raft::handle_t&,
const qn_params&,
double*,
Expand All @@ -212,7 +199,8 @@ template void qnFit<double>(const raft::handle_t&,
double*,
double*,
int*,
double*);
double*,
double);

template <typename T, typename I>
void qnFitSparse(const raft::handle_t& cuml_handle,
Expand All @@ -228,23 +216,24 @@ void qnFitSparse(const raft::handle_t& cuml_handle,
T* w0,
T* f,
int* num_iters,
T* sample_weight)
T* sample_weight,
T svr_eps)
{
qnFitSparse<T>(cuml_handle,
pams,
X_values,
X_cols,
X_row_ids,
X_nnz,
y,
N,
D,
C,
w0,
f,
num_iters,
cuml_handle.get_stream(),
sample_weight);
detail::qnFitSparse<T>(cuml_handle,
pams,
X_values,
X_cols,
X_row_ids,
X_nnz,
y,
N,
D,
C,
w0,
f,
num_iters,
sample_weight,
svr_eps);
}

template void qnFitSparse<float>(const raft::handle_t&,
Expand All @@ -260,7 +249,8 @@ template void qnFitSparse<float>(const raft::handle_t&,
float*,
float*,
int*,
float*);
float*,
float);
template void qnFitSparse<double>(const raft::handle_t&,
const qn_params&,
double*,
Expand All @@ -274,7 +264,8 @@ template void qnFitSparse<double>(const raft::handle_t&,
double*,
double*,
int*,
double*);
double*,
double);

template <typename T, typename I>
void qnDecisionFunction(const raft::handle_t& cuml_handle,
Expand All @@ -287,8 +278,7 @@ void qnDecisionFunction(const raft::handle_t& cuml_handle,
T* params,
T* scores)
{
qnDecisionFunction<T>(
cuml_handle, pams, X, X_col_major, N, D, C, params, scores, cuml_handle.get_stream());
detail::qnDecisionFunction<T>(cuml_handle, pams, X, X_col_major, N, D, C, params, scores);
}

template void qnDecisionFunction<float>(
Expand All @@ -309,18 +299,8 @@ void qnDecisionFunctionSparse(const raft::handle_t& cuml_handle,
T* params,
T* scores)
{
qnDecisionFunctionSparse<T>(cuml_handle,
pams,
X_values,
X_cols,
X_row_ids,
X_nnz,
N,
D,
C,
params,
scores,
cuml_handle.get_stream());
detail::qnDecisionFunctionSparse<T>(
cuml_handle, pams, X_values, X_cols, X_row_ids, X_nnz, N, D, C, params, scores);
}

template void qnDecisionFunctionSparse<float>(
Expand Down Expand Up @@ -348,8 +328,7 @@ void qnPredict(const raft::handle_t& cuml_handle,
T* params,
T* scores)
{
qnPredict<T>(
cuml_handle, pams, X, X_col_major, N, D, C, params, scores, cuml_handle.get_stream());
detail::qnPredict<T>(cuml_handle, pams, X, X_col_major, N, D, C, params, scores);
}

template void qnPredict<float>(
Expand All @@ -370,18 +349,8 @@ void qnPredictSparse(const raft::handle_t& cuml_handle,
T* params,
T* preds)
{
qnPredictSparse<T>(cuml_handle,
pams,
X_values,
X_cols,
X_row_ids,
X_nnz,
N,
D,
C,
params,
preds,
cuml_handle.get_stream());
detail::qnPredictSparse<T>(
cuml_handle, pams, X_values, X_cols, X_row_ids, X_nnz, N, D, C, params, preds);
}

template void qnPredictSparse<float>(
Expand Down
Loading

0 comments on commit 75dd597

Please sign in to comment.