Optimize performance of softmax_fwd when axis!=-1 #38602

Merged · 7 commits · Feb 9, 2022
165 changes: 164 additions & 1 deletion paddle/fluid/operators/softmax_cudnn_op.cu.h
@@ -186,6 +186,58 @@ struct UnaryDivFunctor {
Tx n_inv;
};

template <typename Tx, typename Ty = Tx>
struct SoftmaxForwardFunctor {
HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
: max(max), sum(sum) {}

HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(std::exp(x - max) / sum);
}

private:
Tx max;
Tx sum;
};

template <typename Tx, typename Ty = Tx>
struct SoftmaxBackwardFunctor {
HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
return static_cast<Ty>(out * (grad_out - sum));
}

private:
Tx sum;
};

template <typename Tx, typename Ty = Tx>
struct LogSoftmaxForwardFunctor {
HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
: max(max), log_sum(std::log(sum)) {}

HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x - max - log_sum);
}

private:
Tx max;
Tx log_sum;
};

template <typename Tx, typename Ty = Tx>
struct LogSoftmaxBackwardFunctor {
HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : sum(sum) {}

HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
return static_cast<Ty>(grad_out - std::exp(out) * sum);
}

private:
Tx sum;
};
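
For reference (not part of the diff), the four functors above implement the numerically stable softmax family. Writing `x` for the input, `y` for the forward output, and `dy` for the incoming gradient, the `max`/`sum` constructor arguments are the precomputed per-row reductions:

```latex
\begin{aligned}
\text{SoftmaxForward}:\;     & y_i = \frac{e^{x_i - m}}{s}, && m = \max_j x_j,\quad s = \sum_j e^{x_j - m} \\
\text{LogSoftmaxForward}:\;  & y_i = x_i - m - \log s \\
\text{SoftmaxBackward}:\;    & dx_i = y_i\,(dy_i - c), && c = \sum_j dy_j\, y_j \\
\text{LogSoftmaxBackward}:\; & dx_i = dy_i - e^{y_i}\, c', && c' = \sum_j dy_j
\end{aligned}
```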

/*
Core function of computing softmax forward for axis=-1.
The computation includes
@@ -255,7 +307,8 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
ReduceMaxFunctor<AccT>(), true);
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);

// compute sum
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
@@ -443,6 +496,113 @@ void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE

/**
* <NormalSoftmaxKernel>
* Better performance when axis != -1
*/

static void GetGridDim(int high_dim, int mid_dim, int low_dim,
const dim3& block, dim3* grid) {
int device_id = paddle::platform::GetCurrentDeviceId();
int max_mp = paddle::platform::GetGPUMultiProcessors(device_id);
int max_threads_per_mp =
paddle::platform::GetGPUMaxThreadsPerMultiProcessor(device_id);
int max_threads = max_threads_per_mp * max_mp;
int num_threads = block.x * block.y;
int max_num_blocks = max_threads / num_threads;

int grid_x = (low_dim + block.x - 1) / block.x;
grid_x = std::min(grid_x, max_num_blocks);
int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
grid_y = std::min(grid_y, high_dim);
grid->x = grid_x;
grid->y = grid_y;
}

static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
#ifdef __HIPCC__
constexpr int max_num_threads = 256;
#else
constexpr int max_num_threads = 1024;
#endif
int block_x = 1 << log2_ceil(low_dim);
int block_y = 1 << log2_ceil(mid_dim);
block->x = std::min(block_x, 32);
block->y = std::min(block_y, static_cast<int>(max_num_threads / block->x));
block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y));
}

static void GetLaunchConfig(int high_dim, int mid_dim, int low_dim, dim3* grid,
dim3* block) {
GetBlockDim(mid_dim, low_dim, block);
GetGridDim(high_dim, mid_dim, low_dim, *block, grid);
}
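
As a worked example under assumed device limits (say 80 SMs with 2048 resident threads each, so `max_threads = 163840`): for `high_dim = 64`, `mid_dim = 100`, `low_dim = 7`, `GetBlockDim` rounds both reduced dims up to powers of two — `block_x = 8` (capped at 32), `block_y = min(128, 1024 / 8) = 128`, then `block_x = min(8, 1024 / 128) = 8` — giving `block = (8, 128)`, i.e. 1024 threads. `GetGridDim` then permits `max_num_blocks = 163840 / 1024 = 160` and sets `grid_x = min(ceil(7 / 8), 160) = 1` and `grid_y = min(160, 64) = 64`, so the kernel launches with `grid = (1, 64)`.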

template <typename T, typename AccT,
template <typename, typename> class Functor>
__global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
int mid_dim, int low_dim) {
using kMode = kps::details::ReduceMode;
const int high_stride = mid_dim * low_dim;
const int mid_stride = low_dim;
for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
low_id += blockDim.x * gridDim.x) {
const int input_offset = high_id * high_stride + low_id;

// 1. reduce max
AccT max_value = -std::numeric_limits<AccT>::infinity();
AccT value = -std::numeric_limits<AccT>::infinity();
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
max_value = kps::MaxFunctor<AccT>()(max_value, value);
}

if (blockDim.y > 1) {
kps::Reduce<AccT, 1, 1, 1, kps::MaxFunctor<AccT>, kMode::kGlobalMode>(
&max_value, &max_value, kps::MaxFunctor<AccT>(), false);
}

// 2. reduce sum
AccT sum = 0;
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
sum += std::exp(value - max_value);
}
if (blockDim.y > 1) {
kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
&sum, &sum, kps::AddFunctor<AccT>(), false);
}

// 3. (log)softmax
Functor<AccT, T> functor(max_value, sum);
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
int data_offset = input_offset + mid_id * mid_stride;
output[data_offset] = functor(static_cast<AccT>(input[data_offset]));
}
}
}
}
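
Since each block performs the same three passes (max, sum of exponentials, normalize) over the mid axis, a straight-line CPU reference is easy to write and handy for checking the kernel on small shapes. A minimal sketch (not part of the PR), assuming the same contiguous row-major `[high_dim, mid_dim, low_dim]` layout:

```cpp
#include <algorithm>
#include <cmath>

// CPU reference for NormalSoftmaxForward (softmax mode): softmax is taken
// over the mid axis of a contiguous [high_dim, mid_dim, low_dim] buffer.
template <typename T>
void NormalSoftmaxForwardCPU(T* output, const T* input, int high_dim,
                             int mid_dim, int low_dim) {
  const int high_stride = mid_dim * low_dim;
  const int mid_stride = low_dim;
  for (int h = 0; h < high_dim; ++h) {
    for (int l = 0; l < low_dim; ++l) {
      const int base = h * high_stride + l;
      // 1. reduce max over the softmax (mid) axis
      T max_value = input[base];
      for (int m = 1; m < mid_dim; ++m) {
        max_value = std::max(max_value, input[base + m * mid_stride]);
      }
      // 2. reduce sum of exp(x - max)
      T sum = 0;
      for (int m = 0; m < mid_dim; ++m) {
        sum += std::exp(input[base + m * mid_stride] - max_value);
      }
      // 3. normalize
      for (int m = 0; m < mid_dim; ++m) {
        const int idx = base + m * mid_stride;
        output[idx] = std::exp(input[idx] - max_value) / sum;
      }
    }
  }
}
```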

template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
T* output_data, const T* input_data,
int high_dim, int mid_dim, int low_dim) {
using AccT = typename details::MPTypeTrait<T>::Type;
dim3 grid, block;
GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
if (LogMode) {
NormalSoftmaxForward<
T, AccT,
LogSoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
output_data, input_data, high_dim, mid_dim, low_dim);
} else {
NormalSoftmaxForward<
T, AccT, SoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
output_data, input_data, high_dim, mid_dim, low_dim);
}
}
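
Note that `LogMode` is a compile-time template parameter, so each instantiation of `LaunchNormalSoftmaxForward` fixes its functor at compile time; the runtime `if (LogMode)` folds to a constant and only one kernel is ever launched per instantiation.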

template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const Tensor& x, const int input_axis,
@@ -490,6 +650,9 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
out_data, x.data<T>(), N, dim,
dim, kDimLog2);
}
} else if (D > 1) {
LaunchNormalSoftmaxForward<T, LogMode>(dev_ctx, out_data, x.data<T>(), N,
dim, D);
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
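
For context on the `D > 1` condition above: `N`, `dim`, and `D` are the usual flattening of the input shape around the softmax axis, so `D > 1` holds exactly when the axis is not the last one. A minimal sketch of that derivation — the helper name below is hypothetical, not Paddle's actual utility:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only): flatten a shape around the
// softmax axis into N (dims before axis), dim (the axis), D (dims after).
struct SoftmaxDims {
  int64_t N = 1;    // product of dimensions before the axis
  int64_t dim = 1;  // size of the softmax axis
  int64_t D = 1;    // product of dimensions after the axis
};

static SoftmaxDims FlattenToSoftmaxDims(const std::vector<int64_t>& shape,
                                        int axis) {
  SoftmaxDims out;
  out.dim = shape[axis];
  for (int i = 0; i < axis; ++i) out.N *= shape[i];
  for (size_t i = axis + 1; i < shape.size(); ++i) out.D *= shape[i];
  return out;  // e.g. shape = {8, 100, 7}, axis = 1  ->  N = 8, dim = 100, D = 7
}
```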