Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ElementwiseTernary, Reduce, ReadDataStride #35075

Merged
merged 27 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7d58b91
Merge pull request #1 from PaddlePaddle/develop
AnnaTrainingG Mar 25, 2021
1021e08
Merge pull request #2 from PaddlePaddle/develop
AnnaTrainingG Mar 29, 2021
43f53fe
Merge pull request #3 from PaddlePaddle/develop
AnnaTrainingG Apr 19, 2021
d25ab26
Merge pull request #4 from PaddlePaddle/develop
AnnaTrainingG May 7, 2021
8c8717f
Merge pull request #5 from PaddlePaddle/develop
AnnaTrainingG May 25, 2021
9ddf5e8
Merge pull request #6 from PaddlePaddle/develop
AnnaTrainingG May 26, 2021
b0cbcca
Merge pull request #9 from PaddlePaddle/develop
AnnaTrainingG Jun 1, 2021
cdecaf0
Merge pull request #14 from PaddlePaddle/develop
AnnaTrainingG Jun 11, 2021
0da14c9
Merge pull request #16 from PaddlePaddle/develop
AnnaTrainingG Jun 15, 2021
ca95763
Merge pull request #17 from PaddlePaddle/develop
AnnaTrainingG Jun 22, 2021
25ba21c
Merge pull request #18 from PaddlePaddle/develop
AnnaTrainingG Jul 5, 2021
3ce9983
Merge pull request #19 from PaddlePaddle/develop
AnnaTrainingG Jul 6, 2021
61842ed
Merge pull request #20 from PaddlePaddle/develop
AnnaTrainingG Jul 12, 2021
0e2c73b
Merge pull request #21 from PaddlePaddle/develop
AnnaTrainingG Jul 28, 2021
c1e59cf
Merge pull request #22 from PaddlePaddle/develop
AnnaTrainingG Aug 2, 2021
3a54149
Merge pull request #23 from PaddlePaddle/develop
AnnaTrainingG Aug 4, 2021
7addd79
Merge pull request #24 from PaddlePaddle/develop
AnnaTrainingG Aug 11, 2021
1e843d1
Merge pull request #25 from PaddlePaddle/develop
AnnaTrainingG Aug 23, 2021
0ee3411
add ElementwiseTernary, Reduce, ReadDataStride
AnnaTrainingG Aug 23, 2021
f763e02
delete divFunctor
AnnaTrainingG Aug 23, 2021
b0c3dcd
add writedataBase
AnnaTrainingG Aug 23, 2021
3b74aaa
delete cast and remove cast
AnnaTrainingG Aug 23, 2021
cdcfcda
update
AnnaTrainingG Aug 23, 2021
c476bba
update
AnnaTrainingG Aug 25, 2021
4c28141
add notes
AnnaTrainingG Aug 27, 2021
acbe8e6
add notes and change the name of expFunctor
AnnaTrainingG Aug 31, 2021
6c6ea8c
update
AnnaTrainingG Aug 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 118 additions & 21 deletions paddle/fluid/operators/kernel_primitives/compute_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,24 @@
#endif

#include <algorithm>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace kernel_primitives {
namespace details {

#ifdef __HIPCC__
constexpr int kMaxThread = 256;
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
constexpr int kWarpSize = 64;
#else
constexpr int kMaxThread = 128;
constexpr int kWarpSize = 32;
#endif

enum ReduceMode { GlobalMode, LocalMode };
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved

template <typename T>
class MPTypeTrait {
public:
Expand All @@ -41,26 +52,75 @@ class MPTypeTrait<platform::float16> {
using Type = float;
};

} // namespace details
__device__ __forceinline__ int SharedMemoryIndex(int index) {
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
return (threadIdx.y + index) * blockDim.x + threadIdx.x;
}

/*************************** Compute Functor****************************/
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
template <typename T, typename Enable = void>
struct DivFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] / args[1];
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
};
return val;
}

template <typename T>
struct DivFunctor<T, typename std::enable_if_t<std::is_integral<T>::value>> {
inline HOSTDEVICE T operator()(const T* args) const {
PADDLE_ENFORCE(args[1] != 0,
platform::errors::InvalidArgument(
"Invalid Argument Error: Integer division by zero "
"encountered in divide. Please check the input value."));
return args[0] / args[1];
/* e.g.
* |---------block---------|
* |warp0|warp1|warp2|warp3|
* |0~31|32~63|64~95|96~127| ---->blockDim.x = 128
* \|/ \|/ \|/ \|/ ---->1. First WarpReduce in each warp
* res0 res1 res2 res3 ---->2. Store result of each warp to shared memory
* \ \ / / ---->3. Load the result above from shared memory
* res to warp0 and process the second WarpReduce
*/
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
__syncthreads();
using details::kWarpSize;
__shared__ T shared[2 * kWarpSize];
int block_dim_x = blockDim.x;
if (blockDim.x > kWarpSize) {
block_dim_x = blockDim.x / kWarpSize;
int lane = threadIdx.x % kWarpSize;
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / kWarpSize;
int bid = threadIdx.y;
val = WarpReduce(val, reducer);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
val = shared[bid * block_dim_x + lane];
}
};

unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = 1; stride < block_dim_x; stride <<= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
return val;
}

template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
__shared__ T shared_memory[details::kMaxThread];
shared_memory[SharedMemoryIndex(0)] = val;
for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
__syncthreads();
if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
T temp = shared_memory[SharedMemoryIndex(stride)];
val = reducer(val, temp);
}
shared_memory[SharedMemoryIndex(0)] = val;
}
return val;
}

} // namespace details

/*************************** Compute Function****************************/

Expand Down Expand Up @@ -88,7 +148,7 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
}

/**
* @brief fma eg: a * b + c, in1 in2, in3 and out has the same shape
* @brief eg: a * b + c, in1 in2, in3 and out has the same shape
* @param:
* T : the type of in1 and in2, in3
* NX: the row of in1, in2 and in3
Expand All @@ -97,12 +157,16 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
class OpFunc>
__device__ __forceinline__ void ElementwiseFma(OutT* out, const T* in1,
const T* in2, const T* in3,
OpFunc compute) {
__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
const T* in2, const T* in3,
OpFunc compute) {
T args[3];
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
out[idx] = static_cast<OutT>(compute(in1[idx], in2[idx], in3[idx]));
args[0] = in1[idx];
args[1] = in2[idx];
args[2] = in3[idx];
out[idx] = static_cast<OutT>(compute(args));
}
}

Expand Down Expand Up @@ -148,6 +212,39 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
}
}

// in[NY][NX] -> in[NY]
template <typename T, int NX, int NY, int BlockSize, class OpFunc,
int ReduceMode>
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
bool reduce_lastDim) {
AnnaTrainingG marked this conversation as resolved.
Show resolved Hide resolved
if (ReduceMode == details::ReduceMode::GlobalMode) {
bool block_reduce_y = (!reduce_lastDim) && (blockDim.y > 1);
// blockYReduce
if (block_reduce_y) {
#pragma unroll
for (int i = 0; i < NY; i++) {
out[i] = details::BlockYReduce<T, OpFunc>(out[i], reducer);
}
}

// blockXReduce
if (reduce_lastDim) {
#pragma unroll
for (int i = 0; i < NY; i++) {
out[i] = details::BlockXReduce<T, OpFunc>(out[i], reducer);
}
}
} else { // else LocalMode
#pragma unroll
for (int i = 0; i < NY; ++i) {
#pragma unroll
for (int j = 0; j < NX; ++j) {
out[i] = reducer(out[i], in[i * NX + j]);
}
}
}
}

} // namespace kernel_primitives
} // namespace operators
} // namespace paddle
Loading