Fix all deprecated function calls in TUs where warnings are errors #5692

Merged
merged 3 commits on Dec 7, 2023
Changes from all commits
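Every file below follows the same mechanical pattern: the deprecated raft::myXxx math wrappers and the raft::Sum<T>() reduction functor are swapped for their current spellings. A rough before/after sketch of the mapping — the header paths and the stable_softplus helper are illustrative assumptions, not part of this diff:

// Assumed header locations for the replacement API (illustration only):
#include <raft/core/math.hpp>       // raft::log, raft::exp, raft::sqrt, raft::abs, raft::max
#include <raft/core/operators.hpp>  // raft::add_op

// old (deprecated)          -> new (used throughout this PR)
//   raft::myLog(x)          -> raft::log(x)
//   raft::myExp(x)          -> raft::exp(x)
//   raft::mySqrt(x)         -> raft::sqrt(x)
//   raft::myAbs(x)          -> raft::abs(x)
//   raft::myMax<T>(a, b)    -> raft::max<T>(a, b)
//   raft::Sum<T>()          -> raft::add_op()
//   ceildiv (via using)     -> raft::ceildiv

template <typename T>
__device__ T stable_softplus(T x)  // hypothetical helper, same shape as the call sites below
{
  // numerically stable softplus: log(1 + exp(x))
  return raft::log(T(1) + raft::exp(x < T(0) ? x : -x)) + raft::max<T>(x, T(0));
}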
4 changes: 2 additions & 2 deletions cpp/src/arima/batched_arima.cu
@@ -336,7 +336,7 @@ __global__ void sum_of_squares_kernel(const DataT* d_y,
// Compute log-likelihood and write it to global memory
if (threadIdx.x == 0) {
d_loglike[blockIdx.x] =
-0.5 * static_cast<DataT>(n_obs) * raft::myLog(ssq / static_cast<DataT>(n_obs - start_sum));
-0.5 * static_cast<DataT>(n_obs) * raft::log(ssq / static_cast<DataT>(n_obs - start_sum));
}
}

@@ -1000,4 +1000,4 @@ void estimate_x0(raft::handle_t& handle,
_start_params(handle, params, bm_yd, bm_exog_diff, order);
}

} // namespace ML
} // namespace ML
4 changes: 2 additions & 2 deletions cpp/src/dbscan/vertexdeg/algo.cuh
@@ -122,7 +122,7 @@ void launcher(const raft::handle_t& handle,
stream,
false,
[] __device__(bool adj_ij, index_t idx) { return static_cast<index_t>(adj_ij); },
raft::Sum<index_t>(),
raft::add_op(),
[d_nnz] __device__(index_t degree) {
atomicAdd(d_nnz, degree);
return degree;
@@ -143,7 +143,7 @@ void launcher(const raft::handle_t& handle,
[sample_weight] __device__(bool adj_ij, index_t j) {
return adj_ij ? sample_weight[j] : (value_t)0;
},
raft::Sum<value_t>());
raft::add_op());
RAFT_CUDA_TRY(cudaPeekAtLastError());
}
}
4 changes: 2 additions & 2 deletions cpp/src/dbscan/vertexdeg/precomputed.cuh
@@ -75,7 +75,7 @@ void launcher(const raft::handle_t& handle,
stream,
false,
[] __device__(bool adj_ij, long_index_t idx) { return static_cast<index_t>(adj_ij); },
raft::Sum<index_t>(),
raft::add_op(),
[d_nnz] __device__(index_t degree) {
atomicAdd(d_nnz, degree);
return degree;
@@ -96,7 +96,7 @@ void launcher(const raft::handle_t& handle,
[sample_weight] __device__(bool adj_ij, long_index_t j) {
return adj_ij ? sample_weight[j] : (value_t)0;
},
raft::Sum<value_t>());
raft::add_op());
RAFT_CUDA_TRY(cudaPeekAtLastError());
}
}
4 changes: 2 additions & 2 deletions cpp/src/decisiontree/batched-levelalgo/builder.cuh
@@ -414,8 +414,8 @@ struct Builder {
// unique samples from 'n' is given by the following equation: log(1 - k/n)/log(1 - 1/n) ref:
// https://stats.stackexchange.com/questions/296005/the-expected-number-of-unique-elements-drawn-with-replacement
IdxT n_parallel_samples =
std::ceil(raft::myLog(1 - double(dataset.n_sampled_cols) / double(dataset.N)) /
(raft::myLog(1 - 1.f / double(dataset.N))));
std::ceil(raft::log(1 - double(dataset.n_sampled_cols) / double(dataset.N)) /
(raft::log(1 - 1.f / double(dataset.N))));
// maximum sampling work possible by all threads in a block :
// `max_samples_per_thread * block_thread`
// dynamically calculated sampling work to be done per block:
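As an aside, the expected-unique-samples formula used in the builder change above is easy to sanity-check on the host; a small standalone sketch (the n and k values are made up for illustration):

#include <cmath>
#include <cstdio>

int main()
{
  // Drawing with replacement from n items, the number of draws needed so that
  // the expected count of distinct items is k:  log(1 - k/n) / log(1 - 1/n).
  double n = 100.0, k = 30.0;  // illustrative values only
  double draws = std::ceil(std::log(1.0 - k / n) / std::log(1.0 - 1.0 / n));
  std::printf("expect %g unique of %g after about %.0f draws\n", k, n, draws);  // ~36
  return 0;
}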
@@ -280,7 +280,7 @@ __global__ void algo_L_sample_kernel(int* colids,
IdxT int_uniform_val;
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
double W = raft::myExp(raft::myLog(fp_uniform_val) / k);
double W = raft::exp(raft::log(fp_uniform_val) / k);

size_t col(0);
// initially fill the reservoir array in increasing order of cols till k
@@ -295,14 +295,14 @@
while (col < n) {
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
col += static_cast<int>(raft::myLog(fp_uniform_val) / raft::myLog(1 - W)) + 1;
col += static_cast<int>(raft::log(fp_uniform_val) / raft::log(1 - W)) + 1;
if (col < n) {
// int_uniform_val will now have a random value between 0...k
raft::random::custom_next(gen, &int_uniform_val, uniform_int_dist_params, IdxT(0), IdxT(0));
colids[tid * k + int_uniform_val] = col; // the bad memory coalescing here is hidden
// fp_uniform_val will have a random value between 0 and 1
gen.next(fp_uniform_val);
W *= raft::myExp(raft::myLog(fp_uniform_val) / k);
W *= raft::exp(raft::log(fp_uniform_val) / k);
}
}
}
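The sampling kernel above is the classic reservoir-sampling scheme known as Algorithm L; a host-side sketch of the same skip logic, using the standard library in place of raft::random (illustrative only, not the kernel's actual code path):

#include <cmath>
#include <random>
#include <vector>

// Keep k of the column ids 0..n-1 with equal probability, jumping ahead by a
// geometrically distributed number of columns instead of visiting every one.
std::vector<int> sample_k_of_n(int k, int n, std::mt19937& rng)
{
  // draw strictly inside (0, 1) so every log() below stays finite
  std::uniform_real_distribution<double> uni01(std::nextafter(0.0, 1.0), 1.0);
  std::uniform_int_distribution<int> slot(0, k - 1);

  std::vector<int> reservoir(k);
  for (int i = 0; i < k; ++i) reservoir[i] = i;  // fill with the first k columns

  double W = std::exp(std::log(uni01(rng)) / k);
  long long col = k - 1;
  while (true) {
    col += static_cast<long long>(std::log(uni01(rng)) / std::log(1.0 - W)) + 1;
    if (col >= n) break;
    reservoir[slot(rng)] = static_cast<int>(col);  // replace a random reservoir entry
    W *= std::exp(std::log(uni01(rng)) / k);
  }
  return reservoir;
}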
18 changes: 9 additions & 9 deletions cpp/src/decisiontree/batched-levelalgo/objectives.cuh
@@ -143,21 +143,21 @@ class EntropyObjectiveFunction {
auto lval_i = hist[n_bins * c + i].x;
if (lval_i != 0) {
auto lval = DataT(lval_i);
gain += raft::myLog(lval * invLeft) / raft::myLog(DataT(2)) * lval * invLen;
gain += raft::log(lval * invLeft) / raft::log(DataT(2)) * lval * invLen;
}

val_i += lval_i;
auto total_sum = hist[n_bins * c + n_bins - 1].x;
auto rval_i = total_sum - lval_i;
if (rval_i != 0) {
auto rval = DataT(rval_i);
gain += raft::myLog(rval * invRight) / raft::myLog(DataT(2)) * rval * invLen;
gain += raft::log(rval * invRight) / raft::log(DataT(2)) * rval * invLen;
}

val_i += rval_i;
if (val_i != 0) {
auto val = DataT(val_i) * invLen;
gain -= val * raft::myLog(val) / raft::myLog(DataT(2));
gain -= val * raft::log(val) / raft::log(DataT(2));
}
}

@@ -313,9 +313,9 @@ class PoissonObjectiveFunction {
return -std::numeric_limits<DataT>::max();

// compute the gain to be
DataT parent_obj = -label_sum * raft::myLog(label_sum * invLen);
DataT left_obj = -left_label_sum * raft::myLog(left_label_sum / nLeft);
DataT right_obj = -right_label_sum * raft::myLog(right_label_sum / nRight);
DataT parent_obj = -label_sum * raft::log(label_sum * invLen);
DataT left_obj = -left_label_sum * raft::log(left_label_sum / nLeft);
DataT right_obj = -right_label_sum * raft::log(right_label_sum / nRight);
DataT gain = parent_obj - (left_obj + right_obj);
gain = gain * invLen;

@@ -392,9 +392,9 @@ class GammaObjectiveFunction {
return -std::numeric_limits<DataT>::max();

// compute the gain to be
DataT parent_obj = len * raft::myLog(label_sum * invLen);
DataT left_obj = nLeft * raft::myLog(left_label_sum / nLeft);
DataT right_obj = nRight * raft::myLog(right_label_sum / nRight);
DataT parent_obj = len * raft::log(label_sum * invLen);
DataT left_obj = nLeft * raft::log(left_label_sum / nLeft);
DataT right_obj = nRight * raft::log(right_label_sum / nRight);
DataT gain = parent_obj - (left_obj + right_obj);
gain = gain * invLen;

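In the entropy objective above, raft::log(x) / raft::log(DataT(2)) is simply a base-2 logarithm; a standalone host sketch of the same entropy-in-bits computation (the example counts are invented):

#include <cmath>
#include <vector>

// Entropy of a class histogram in bits, using the same natural-log ratio
// as the device code above instead of calling a log2 routine directly.
double entropy_bits(const std::vector<int>& counts)
{
  double total = 0.0;
  for (int c : counts) total += c;
  double h = 0.0;
  for (int c : counts) {
    if (c == 0) continue;  // 0 * log(0) contributes nothing
    double p = c / total;
    h -= p * std::log(p) / std::log(2.0);
  }
  return h;  // e.g. {8, 8} -> 1.0 bit, {16, 0} -> 0.0 bits
}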
12 changes: 5 additions & 7 deletions cpp/src/genetic/fitness.cuh
@@ -126,8 +126,8 @@ void weightedPearson(const raft::handle_t& h,
stream,
false,
[W] __device__(math_t v, int i) { return v * v * W[i]; },
raft::Sum<math_t>(),
[] __device__(math_t in) { return raft::mySqrt(in); });
raft::add_op(),
[] __device__(math_t in) { return raft::sqrt(in); });
math_t HYstd = y_std.element(0, stream);

// Find x_std
@@ -140,8 +140,8 @@
stream,
false,
[W] __device__(math_t v, int i) { return v * v * W[i]; },
raft::Sum<math_t>(),
[] __device__(math_t in) { return raft::mySqrt(in); });
raft::add_op(),
[] __device__(math_t in) { return raft::sqrt(in); });

// Cross covariance
raft::linalg::matrixVectorOp(
@@ -273,9 +273,7 @@ void meanAbsoluteError(const raft::handle_t& h,
n_samples,
false,
false,
[N, WS] __device__(math_t y_p, math_t y, math_t w) {
return N * w * raft::myAbs(y - y_p) / WS;
},
[N, WS] __device__(math_t y_p, math_t y, math_t w) { return N * w * raft::abs(y - y_p) / WS; },
stream);

// Average along rows
4 changes: 2 additions & 2 deletions cpp/src/glm/preprocess.cuh
@@ -77,7 +77,7 @@ void preProcessData(const raft::handle_t& handle,
norm2_input,
norm2_input,
n_cols,
[] __device__(math_t v) { return raft::mySqrt(v); },
[] __device__(math_t v) { return raft::sqrt(v); },
stream);
raft::matrix::linewiseOp(
input,
@@ -105,7 +105,7 @@ void preProcessData(const raft::handle_t& handle,
raft::linalg::L2Norm,
false,
stream,
[] __device__(math_t v) { return raft::mySqrt(v); });
[] __device__(math_t v) { return raft::sqrt(v); });
raft::matrix::matrixVectorBinaryDivSkipZero(
input, norm2_input, n_rows, n_cols, false, true, stream, true);
}
2 changes: 1 addition & 1 deletion cpp/src/glm/qn/glm_linear.cuh
@@ -57,7 +57,7 @@ struct AbsLoss : GLMBase<T, AbsLoss<T>> {
typedef GLMBase<T, AbsLoss<T>> Super;

const struct Lz {
inline __device__ T operator()(const T y, const T z) const { return raft::myAbs<T>(z - y); }
inline __device__ T operator()(const T y, const T z) const { return raft::abs<T>(z - y); }
} lz;

const struct Dlz {
4 changes: 2 additions & 2 deletions cpp/src/glm/qn/glm_logistic.cuh
@@ -33,7 +33,7 @@ struct LogisticLoss : GLMBase<T, LogisticLoss<T>> {
inline __device__ T log_sigmoid(const T x) const
{
// To avoid floating point overflow in the exp function
T temp = raft::myLog(1 + raft::myExp(x < 0 ? x : -x));
T temp = raft::log(1 + raft::exp(x < 0 ? x : -x));
return x < 0 ? x - temp : -temp;
}

@@ -48,7 +48,7 @@ struct LogisticLoss : GLMBase<T, LogisticLoss<T>> {
inline __device__ T operator()(const T y, const T z) const
{
// To avoid fp overflow with exp(z) when abs(z) is large
T ez = raft::myExp(z < 0 ? z : -z);
T ez = raft::exp(z < 0 ? z : -z);
T numerator = z < 0 ? ez : T(1.0);
return numerator / (T(1.0) + ez) - y;
}
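The log_sigmoid above keeps the argument of exp() non-positive so it can never overflow; the same trick restated as a plain scalar function (a sketch, not the templated device code):

#include <cmath>

// log(sigmoid(x)) without overflow: exp() is only ever called on a value <= 0.
double log_sigmoid(double x)
{
  double t = std::log(1.0 + std::exp(x < 0 ? x : -x));
  return x < 0 ? x - t : -t;
}
// log_sigmoid(-1000.0) returns -1000.0; the naive -log(1 + exp(1000)) would overflow to -inf.
// log_sigmoid(0.0) returns -log(2).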
26 changes: 11 additions & 15 deletions cpp/src/glm/qn/glm_softmax.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
* Copyright (c) 2018-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,10 +24,6 @@
namespace ML {
namespace GLM {
namespace detail {
using raft::ceildiv;
using raft::myExp;
using raft::myLog;
using raft::myMax;

// Input: matrix Z (dims: CxN)
// Computes softmax cross entropy loss across columns, i.e. normalization
@@ -84,7 +80,7 @@ __global__ void logSoftmaxKernel(
delta = true;
eta_y = myEta;
}
etaMax = myMax<T>(myEta, etaMax);
etaMax = raft::max<T>(myEta, etaMax);
}
}
T tmpMax = WarpRed(shm.warpStore[threadIdx.y]).Reduce(etaMax, cub::Max());
@@ -100,15 +96,15 @@
// TODO there must be a better way to do this...
if (C <= BX) { // this means one block covers a column and myEta is valid
int idx = threadIdx.x + y * C;
if (threadIdx.x < C && idx < len) { lse = myExp<T>(myEta - etaMax); }
if (threadIdx.x < C && idx < len) { lse = raft::exp<T>(myEta - etaMax); }
} else {
for (int x = threadIdx.x; x < C; x += BX) {
int idx = x + y * C;
if (x < C && idx < len) { lse += myExp<T>(in[idx] - etaMax); }
if (x < C && idx < len) { lse += raft::exp<T>(in[idx] - etaMax); }
}
}
T tmpLse = WarpRed(shm.warpStore[threadIdx.y]).Sum(lse);
if (threadIdx.x == 0) { shm.sh_val[threadIdx.y] = etaMax + myLog<T>(tmpLse); }
if (threadIdx.x == 0) { shm.sh_val[threadIdx.y] = etaMax + raft::log<T>(tmpLse); }
__syncthreads();
lse = shm.sh_val[threadIdx.y];
__syncthreads();
@@ -123,14 +119,14 @@
if (C <= BX) { // this means one block covers a column and myEta is valid
int idx = threadIdx.x + y * C;
if (threadIdx.x < C && idx < len) {
dZ[idx] = (myExp<T>(myEta - lse) - (getDerivative ? (threadIdx.x == label) : T(0)));
dZ[idx] = (raft::exp<T>(myEta - lse) - (getDerivative ? (threadIdx.x == label) : T(0)));
}
} else {
for (int x = threadIdx.x; x < C; x += BX) {
int idx = x + y * C;
if (x < C && idx < len) {
T logP = in[idx] - lse;
dZ[idx] = (myExp<T>(logP) - (getDerivative ? (x == label) : T(0)));
dZ[idx] = (raft::exp<T>(logP) - (getDerivative ? (x == label) : T(0)));
}
}
}
@@ -156,19 +152,19 @@ void launchLogsoftmax(
raft::interruptible::synchronize(stream);
if (C <= 4) {
dim3 bs(4, 64);
dim3 gs(ceildiv(N, 64));
dim3 gs(raft::ceildiv(N, 64));
logSoftmaxKernel<T, 4, 64><<<gs, bs, 0, stream>>>(loss_val, dldZ, Z, labels, C, N);
} else if (C <= 8) {
dim3 bs(8, 32);
dim3 gs(ceildiv(N, 32));
dim3 gs(raft::ceildiv(N, 32));
logSoftmaxKernel<T, 8, 32><<<gs, bs, 0, stream>>>(loss_val, dldZ, Z, labels, C, N);
} else if (C <= 16) {
dim3 bs(16, 16);
dim3 gs(ceildiv(N, 16));
dim3 gs(raft::ceildiv(N, 16));
logSoftmaxKernel<T, 16, 16><<<gs, bs, 0, stream>>>(loss_val, dldZ, Z, labels, C, N);
} else {
dim3 bs(32, 8);
dim3 gs(ceildiv(N, 8));
dim3 gs(raft::ceildiv(N, 8));
logSoftmaxKernel<T, 32, 8><<<gs, bs, 0, stream>>>(loss_val, dldZ, Z, labels, C, N);
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
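Per column, logSoftmaxKernel computes an ordinary max-shifted log-softmax; a scalar host sketch of the same log-sum-exp arithmetic (not the kernel itself):

#include <algorithm>
#include <cmath>
#include <vector>

// log-softmax of one column: subtract the max before exponentiating, then add
// it back inside the log-sum-exp so neither exp() nor log() can blow up.
std::vector<double> log_softmax(const std::vector<double>& z)
{
  double mx = *std::max_element(z.begin(), z.end());
  double sum = 0.0;
  for (double v : z) sum += std::exp(v - mx);
  double lse = mx + std::log(sum);  // log(sum_j exp(z_j))
  std::vector<double> out;
  out.reserve(z.size());
  for (double v : z) out.push_back(v - lse);  // log p_j
  return out;
}
// The gradient in the kernel is then exp(log p_j) minus the one-hot label.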
4 changes: 2 additions & 2 deletions cpp/src/glm/qn/glm_svm.cuh
@@ -33,7 +33,7 @@ struct SVCL1Loss : GLMBase<T, SVCL1Loss<T>> {
inline __device__ T operator()(const T y, const T z) const
{
T s = 2 * y - 1;
return raft::myMax<T>(0, 1 - s * z);
return raft::max<T>(0, 1 - s * z);
}
} lz;

@@ -64,7 +64,7 @@ struct SVCL2Loss : GLMBase<T, SVCL2Loss<T>> {
inline __device__ T operator()(const T y, const T z) const
{
T s = 2 * y - 1;
T t = raft::myMax<T>(0, 1 - s * z);
T t = raft::max<T>(0, 1 - s * z);
return t * t;
}
} lz;
4 changes: 2 additions & 2 deletions cpp/src/glm/qn/simple_mat/dense.hpp
@@ -303,8 +303,8 @@ inline T squaredNorm(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
template <typename T>
inline T nrmMax(const SimpleVec<T>& u, T* tmp_dev, cudaStream_t stream)
{
auto f = [] __device__(const T x) { return raft::myAbs<T>(x); };
auto r = [] __device__(const T x, const T y) { return raft::myMax<T>(x, y); };
auto f = [] __device__(const T x) { return raft::abs<T>(x); };
auto r = [] __device__(const T x, const T y) { return raft::max<T>(x, y); };
raft::linalg::mapThenReduce(tmp_dev, u.len, T(0), f, r, stream, u.data);
T tmp_host;
raft::update_host(&tmp_host, tmp_dev, 1, stream);
4 changes: 2 additions & 2 deletions cpp/src/solver/cd.cuh
@@ -74,9 +74,9 @@ __global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc,
auto r = coef > l1_alpha ? coef - l1_alpha : (coef < -l1_alpha ? coef + l1_alpha : 0);
auto squared = *squaredLoc;
r = squared > math_t(1e-5) ? r / squared : math_t(0);
auto diff = raft::myAbs(convStateLoc->coef - r);
auto diff = raft::abs(convStateLoc->coef - r);
if (convStateLoc->diffMax < diff) convStateLoc->diffMax = diff;
auto absv = raft::myAbs(r);
auto absv = raft::abs(r);
if (convStateLoc->coefMax < absv) convStateLoc->coefMax = absv;
convStateLoc->coef = -r;
*coefLoc = r;
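The coefficient update above is the standard L1 soft-thresholding step of coordinate descent; written as a scalar function it reduces to the following (a sketch mirroring the r computation only, with the kernel's convergence bookkeeping omitted):

// Shrink the coefficient toward zero by l1_alpha, then rescale by the
// precomputed column norm ("squared"), with the same 1e-5 guard as above.
double cd_update_coef(double coef, double l1_alpha, double squared)
{
  double r = coef > l1_alpha  ? coef - l1_alpha
           : coef < -l1_alpha ? coef + l1_alpha
                              : 0.0;
  return squared > 1e-5 ? r / squared : 0.0;
}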
10 changes: 5 additions & 5 deletions cpp/src/svm/linear.cu
@@ -144,12 +144,12 @@ __global__ void predictProba(T* out, const T* z, const int nRows, const int nCla
int j = threadIdx.x;
if constexpr (Binary) {
t = rowIn[0];
maxVal = raft::myMax<T>(t, 0);
maxVal = raft::max<T>(t, T{0});
t = T(j) * t; // set z[0] = 0, z[1] = t
} else {
for (; j < nClasses; j += BX) {
t = rowIn[j];
maxVal = raft::myMax<T>(maxVal, t);
maxVal = raft::max<T>(maxVal, t);
}
j -= BX;
maxVal = WarpRed(warpStore).Reduce(maxVal, cub::Max());
@@ -164,7 +164,7 @@ __global__ void predictProba(T* out, const T* z, const int nRows, const int nCla
T et; // Numerator of the softmax.
T smSum = 0; // Denominator of the softmax.
while (j >= 0) {
et = raft::myExp<T>(t - maxVal);
et = raft::exp<T>(t - maxVal);
smSum += et;
if (j < BX) break;
j -= BX;
@@ -178,13 +178,13 @@ __global__ void predictProba(T* out, const T* z, const int nRows, const int nCla
// Traverse in the forward direction again to save the results.
// Note, no extra memory reads when BX >= nClasses!
if (j < 0) return;
T d = Log ? -maxVal - raft::myLog<T>(smSum) : 1 / smSum;
T d = Log ? -maxVal - raft::log<T>(smSum) : 1 / smSum;
while (j < nClasses) {
rowOut[j] = Log ? t + d : et * d;
j += BX;
if (j >= nClasses) break;
t = rowIn[j];
if constexpr (!Log) et = raft::myExp<T>(t - maxVal);
if constexpr (!Log) et = raft::exp<T>(t - maxVal);
}
}

2 changes: 1 addition & 1 deletion cpp/src_prims/functions/log.cuh
@@ -25,7 +25,7 @@ template <typename T, typename IdxType = int>
void f_log(T* out, T* in, T scalar, IdxType len, cudaStream_t stream)
{
raft::linalg::unaryOp(
out, in, len, [scalar] __device__(T in) { return raft::myLog(in) * scalar; }, stream);
out, in, len, [scalar] __device__(T in) { return raft::log(in) * scalar; }, stream);
}

}; // end namespace Functions