Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization of elementwise CUDA kernel #30801

Merged
merged 8 commits into the base branch from the contributor's branch
Mar 10, 2021
21 changes: 14 additions & 7 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ inline void get_mid_dims(const framework::DDim &x_dims,
(*post) *= x_dims[i];
}
}

inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
const int *index_array) {
int index_ = 0;
Expand Down Expand Up @@ -202,12 +203,16 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename Functor, typename T, typename OutType>
// Grid-stride elementwise kernel with broadcasting on the second operand:
//   out_data[i] = func(x_data[i], y_data[(i / post) % n])
// where `x_data` is the larger (full-size) operand and `y_data` the
// broadcast operand; `n` is the broadcast-axis length and `post` the product
// of the trailing dimensions after that axis. The grid-stride loop makes the
// kernel correct for any <<<blocks, threads>>> configuration, and `size_t`
// indexing avoids overflow when `total` exceeds INT_MAX (`total` is already
// size_t, so int counters would wrap and the i < total comparison would mix
// signedness). `__restrict__` lets the compiler use the read-only data cache
// for the input operands.
__global__ void ElementwiseKernel(const T *__restrict__ x_data,
                                  const T *__restrict__ y_data,
                                  OutType *__restrict__ out_data, int n,
                                  int post, const size_t total, Functor func) {
  const size_t tid =
      threadIdx.x + static_cast<size_t>(blockDim.x) * blockIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t i = tid; i < total; i += stride) {
    // Index into the broadcast operand: position along the broadcast axis.
    const int idx = static_cast<int>(i / post % n);
    out_data[i] = func(x_data[i], y_data[idx]);
  }
}

Expand All @@ -224,14 +229,16 @@ void ComputeElementwiseCUDA(const framework::Tensor *x,
int numel = pre * n * post;
int threads = 256;
int blocks = (numel + threads - 1) / threads;

if (is_xsize_larger) {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
x_data, y_data, out_data, pre, n, post, numel, func);
x_data, y_data, out_data, n, post, numel, func);

} else {
ElementwiseKernel<Functor, T,
OutType><<<blocks, threads, 0, ctx.stream()>>>(
y_data, x_data, out_data, pre, n, post, numel, func);
y_data, x_data, out_data, n, post, numel, func);
}
}

Expand Down