
Optimize performance of softmax_fwd when axis!=-1 #38602

Merged Feb 9, 2022 · 7 commits · diff shown from 2 of the 7 commits
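Background for the diff below: when softmax runs over an axis other than the last, the tensor can be viewed as [high_dim, mid_dim, low_dim], with the reduction over mid_dim for every (high, low) pair. A minimal host-side reference of that computation, for orientation only (a sketch, not code from this PR):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Serial reference for softmax over the middle axis of a [high, mid, low]
// view; NormalSoftmaxForward in this PR parallelizes exactly this loop nest
// (gridDim.y over h, blockDim.x * gridDim.x over l, blockDim.y over m).
void SoftmaxMidAxisRef(const std::vector<float>& x, std::vector<float>* y,
                       int high, int mid, int low) {
  for (int h = 0; h < high; ++h) {
    for (int l = 0; l < low; ++l) {
      const int base = h * mid * low + l;  // stride over m is `low`
      float max_v = -std::numeric_limits<float>::infinity();
      for (int m = 0; m < mid; ++m)
        max_v = std::max(max_v, x[base + m * low]);
      float sum = 0.0f;
      for (int m = 0; m < mid; ++m)
        sum += std::exp(x[base + m * low] - max_v);
      for (int m = 0; m < mid; ++m)
        (*y)[base + m * low] = std::exp(x[base + m * low] - max_v) / sum;
    }
  }
}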
5 changes: 3 additions & 2 deletions paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -119,7 +119,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
 */
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
-  __shared__ T shared_memory[details::kReduceMaxThread];
+  __shared__ T shared_memory[1024];
  shared_memory[SharedMemoryIndex(0)] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
@@ -129,7 +129,8 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
    }
    shared_memory[SharedMemoryIndex(0)] = val;
  }
-  return val;
+  __syncthreads();
+  return shared_memory[threadIdx.x];
}

} // namespace details
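Why the BlockYReduce change above matters: before this patch, only the threads that performed the final combine (threadIdx.y == 0) were guaranteed to hold the fully reduced value; the new trailing __syncthreads() plus shared-memory read broadcasts the result to every thread in the block, which the softmax kernel below relies on. A standalone sketch of the same store-then-broadcast pattern (demo code, not Paddle's):

#include <cstdio>
#include <cuda_runtime.h>

// Max-reduce along threadIdx.y, then broadcast: after the loop only row 0 of
// shared memory is guaranteed complete, so every thread re-reads row 0 after
// a final __syncthreads() -- the same idea as the patched BlockYReduce.
__global__ void BlockYMaxDemo(const float* in, float* out) {
  __shared__ float smem[32][32];
  float val = in[threadIdx.y * blockDim.x + threadIdx.x];
  smem[threadIdx.y][threadIdx.x] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride) {
      val = fmaxf(val, smem[threadIdx.y + stride][threadIdx.x]);
      smem[threadIdx.y][threadIdx.x] = val;
    }
  }
  __syncthreads();
  out[threadIdx.y * blockDim.x + threadIdx.x] = smem[0][threadIdx.x];
}

int main() {
  const int n = 32 * 32;
  float in_h[n], out_h[n];
  for (int i = 0; i < n; ++i) in_h[i] = static_cast<float>(i % 97);
  float *in_d, *out_d;
  cudaMalloc(&in_d, n * sizeof(float));
  cudaMalloc(&out_d, n * sizeof(float));
  cudaMemcpy(in_d, in_h, n * sizeof(float), cudaMemcpyHostToDevice);
  BlockYMaxDemo<<<1, dim3(32, 32)>>>(in_d, out_d);
  cudaMemcpy(out_h, out_d, n * sizeof(float), cudaMemcpyDeviceToHost);
  // every thread in a column now sees the same max
  printf("col 0: %f == %f\n", out_h[0], out_h[31 * 32]);
  cudaFree(in_d);
  cudaFree(out_d);
  return 0;
}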
161 changes: 160 additions & 1 deletion paddle/fluid/operators/softmax_cudnn_op.cu.h
@@ -187,6 +187,58 @@ struct UnaryDivFunctor {
  Tx n_inv;
};

template <typename Tx, typename Ty = Tx>
struct SoftmaxForwardFunctor {
  HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
      : max(max), sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(std::exp(x - max) / sum);
  }

 private:
  Tx max;
  Tx sum;
};

template <typename Tx, typename Ty = Tx>
struct SoftmaxBackwardFunctor {
  HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    return static_cast<Ty>(grad_out - out * sum);
  }

 private:
  Tx sum;
};

template <typename Tx, typename Ty = Tx>
struct LogSoftmaxForwardFunctor {
  HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
      : max(max), log_sum(std::log(sum)) {}

  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x - max - log_sum);
  }

 private:
  Tx max;
  Tx log_sum;
};

template <typename Tx, typename Ty = Tx>
struct LogSoftmaxBackwardFunctor {
  HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

  HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
    return static_cast<Ty>(grad_out - std::exp(out) * sum);
  }

 private:
  Tx sum;
};
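For reference, the two forward functors implement the numerically stable identities (with the max m and the sum taken over the softmax axis):

\[
\mathrm{softmax}(x)_i = \frac{e^{\,x_i - m}}{\sum_j e^{\,x_j - m}}, \qquad
\mathrm{logsoftmax}(x)_i = (x_i - m) - \log\sum_j e^{\,x_j - m}, \qquad
m = \max_j x_j,
\]

which is why LogSoftmaxForwardFunctor caches log_sum = std::log(sum) once per reduction instead of dividing per element.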

/*
Core function of computing softmax forward for axis=-1.
The computation includes
@@ -256,7 +308,8 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
                                 ReduceMaxFunctor<AccT>(), true);
  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);

-  // compute sum
+// compute sum
+#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
        &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
@@ -444,6 +497,109 @@ void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE

/**
 * <NormalSoftmaxKernel>
 * Better performance when axis != -1
 */

static void GetGridDim(int high_dim, int mid_dim, int low_dim,
                       const dim3& block, dim3* grid) {
  int device_id = paddle::platform::GetCurrentDeviceId();
  int max_mp = paddle::platform::GetGPUMultiProcessors(device_id);
  int max_threads_per_mp =
      paddle::platform::GetGPUMaxThreadsPerMultiProcessor(device_id);
  int max_threads = max_threads_per_mp * max_mp;
  int num_threads = block.x * block.y;
  int max_num_blocks = max_threads / num_threads;

  int grid_x = (low_dim + block.x - 1) / block.x;
  grid_x = std::min(grid_x, max_num_blocks);
  int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
  grid_y = std::min(grid_y, high_dim);
  grid->x = grid_x;
  grid->y = grid_y;
}

static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
  constexpr int max_num_threads = 1024;
  int block_x = 1 << log2_ceil(low_dim);
  int block_y = 1 << log2_ceil(mid_dim);
  block->x = std::min(block_x, 32);
  block->y = std::min(block_y, static_cast<int>(max_num_threads / block->x));
  block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y));
}

static void GetLaunchConfig(int high_dim, int mid_dim, int low_dim, dim3* grid,
                            dim3* block) {
  GetBlockDim(mid_dim, low_dim, block);
  GetGridDim(high_dim, mid_dim, low_dim, *block, grid);
}
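A worked example of the launch configuration (hypothetical device with 20 SMs and 2048 threads per SM, so max_threads = 40960), for (high_dim, mid_dim, low_dim) = (4, 1000, 16):

  block.x = min(1 << log2_ceil(16), 32)          = 16
  block.y = min(1 << log2_ceil(1000), 1024 / 16) = 64   // 16 * 64 = 1024 threads
  max_num_blocks = 40960 / 1024                  = 40
  grid.x  = min(ceil(16 / 16), 40)               = 1
  grid.y  = min(ceil(40 / 1), 4)                 = 4    // capped by high_dim

Each block thus covers all 16 low_dim elements along block.x and splits the 1000-element reduction axis across block.y = 64 threads.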

template <typename T, typename AccT,
          template <typename, typename> class Functor>
__global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
                                     int mid_dim, int low_dim) {
  using kMode = kps::details::ReduceMode;
  const int high_stride = mid_dim * low_dim;
  const int mid_stride = low_dim;
  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
         low_id += blockDim.x * gridDim.x) {
      const int input_offset = high_id * high_stride + low_id;

      // 1. reduce max
      AccT max_value = -std::numeric_limits<AccT>::infinity();
      AccT value = -std::numeric_limits<AccT>::infinity();
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        max_value = kps::MaxFunctor<AccT>()(max_value, value);
      }

      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, 1, kps::MaxFunctor<AccT>, kMode::kGlobalMode>(
            &max_value, &max_value, kps::MaxFunctor<AccT>(), false);
      }

      // 2. reduce sum
      AccT sum = 0;
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
        sum += std::exp(value - max_value);
      }
      if (blockDim.y > 1) {
        kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
            &sum, &sum, kps::AddFunctor<AccT>(), false);
      }

      // 3. (log)softmax
      Functor<AccT, T> functor(max_value, sum);
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        int data_offset = input_offset + mid_id * mid_stride;
        output[data_offset] = functor(static_cast<AccT>(input[data_offset]));
      }
    }
  }
}

template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
                                T* output_data, const T* input_data,
                                int high_dim, int mid_dim, int low_dim) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  dim3 grid, block;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
  if (LogMode) {
    NormalSoftmaxForward<
        T, AccT,
        LogSoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
        output_data, input_data, high_dim, mid_dim, low_dim);
  } else {
    NormalSoftmaxForward<
        T, AccT, SoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
        output_data, input_data, high_dim, mid_dim, low_dim);
  }
}
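For orientation: the sizes N, dim, and D that the driver below passes as high_dim, mid_dim, and low_dim come from flattening the input shape around the softmax axis. A hedged sketch (hypothetical helper for illustration, not necessarily matching Paddle's own utilities):

#include <cstdint>
#include <vector>

// Hypothetical helper: for softmax over `axis` of `shape`, N is the product
// of dimensions before the axis, dim the axis extent, and D the product of
// dimensions after it. D > 1 exactly when the (canonicalized) axis is not
// the last one.
void FlattenForSoftmax(const std::vector<int64_t>& shape, int axis, int* N,
                       int* dim, int* D) {
  if (axis < 0) axis += static_cast<int>(shape.size());
  int64_t n = 1, d = 1;
  for (int i = 0; i < axis; ++i) n *= shape[i];
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i)
    d *= shape[i];
  *N = static_cast<int>(n);
  *dim = static_cast<int>(shape[axis]);
  *D = static_cast<int>(d);
}
// e.g. shape {8, 64, 128} with axis = 1 gives N = 8, dim = 64, D = 128,
// which takes the new `D > 1` branch in the driver below.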

template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                    const Tensor& x, const int input_axis,
@@ -491,6 +647,9 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                           out_data, x.data<T>(), N, dim,
                                           dim, kDimLog2);
    }
+  } else if (D > 1) {
+    LaunchNormalSoftmaxForward<T, LogMode>(dev_ctx, out_data, x.data<T>(), N,
+                                           dim, D);
  } else {
    ScopedTensorDescriptor desc;
    std::vector<int> tensor_dims = {N, dim, D, 1};