Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more scenes for fused_fast_ln #42282

Merged
merged 2 commits into from
Apr 28, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 109 additions & 41 deletions paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,25 @@ __global__ void FusedLayernormResidualDropoutBias(
}

/*
* @brief layernorm(residual + dropout(x));
* @brief layernorm(residual + dropout(x));
* Conditions:
* (1) The number of cols is 1024;
* (1) The number of cols is 768/1024/4096;
* (2) layer_norm scale and bias is not null;
* (3) linear bias can be null;
* @param
* rows: batch_size * seq_len
* cols: 768/1024/4096
* x_: [rows, cols], inputs
* residual_:[rows, cols]
* bias_: [cols], linear bias, can be null
* gamma_: [cols]: layernorm scale, not null
* beta_: [cols], layernorm bias, not null
* mask_out_: [rows, cols], dropout result
* residual_out_: [rows, cols], residual + dropout(src)
* y_: [rows, cols], layernorm result
* mean_out_: [rows]: layernorm means
* var_out_: [rows]: layernorm vars
*/
*/
template <
typename T, typename U, typename ScaleT = U, typename MaskType = uint8_t,
int VecSize = 8, int WARPS_M = 4, int WARPS_N = 1, int BYTES_PER_LDG = 16,
Expand All @@ -182,14 +183,16 @@ template <
int THREADS_PER_CTA = WARPS_M *THREADS_PER_ROW, int ROWS_PER_CTA = WARPS_M,
int ELTS_PER_ROW_PER_CTA = THREADS_PER_ROW *VecSize,
int LDGS = ELTS_PER_ROW / ELTS_PER_ROW_PER_CTA>
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
__global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
int rows, int cols, uint64_t seed, const float dropout_prob,
const bool is_upscale_in_train, const bool is_test,
const uint64_t increment, const float epsilon, const T *__restrict__ x_ptr,
const T *__restrict__ residual_ptr, const ScaleT *__restrict__ gamma_ptr,
const ScaleT *__restrict__ beta_ptr, MaskType *__restrict__ mask_out_ptr,
U *__restrict__ mean_out_ptr, U *__restrict__ var_out_ptr,
T *__restrict__ residual_out_ptr, T *__restrict__ y_ptr) {
const T *__restrict__ residual_ptr, const T *__restrict__ bias_ptr,
const ScaleT *__restrict__ gamma_ptr, const ScaleT *__restrict__ beta_ptr,
MaskType *__restrict__ mask_out_ptr, U *__restrict__ mean_out_ptr,
U *__restrict__ var_out_ptr, T *__restrict__ residual_out_ptr,
T *__restrict__ y_ptr) {
__shared__ U smem[WARPS_M * WARPS_N];
using Vec = phi::AlignedVector<T, VecSize>;
using Vec_scale = phi::AlignedVector<ScaleT, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
Expand All @@ -204,12 +207,30 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
const int c = warp_n * THREADS_PER_WARP + lane; // lane
const int r = bidx * ROWS_PER_CTA + warp_m; // row id

int idx = r * LN_NUM_COLS + c;
int idx = r * ELTS_PER_ROW + c;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);

T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

// bias
Vec bias[LDGS];
if (bias_ptr != nullptr) {
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(bias_ptr + col * VecSize, &bias[it]);
col += THREADS_PER_ROW;
}
} else {
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
bias[it][jt] = static_cast<T>(0.0f);
}
}
}

Vec_scale gamma[LDGS];
Vec_scale beta[LDGS];
#pragma unroll
Expand All @@ -219,14 +240,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
col += THREADS_PER_ROW;
}

constexpr U rn = 1.f / U(LN_NUM_COLS);
constexpr U rn = 1.f / U(ELTS_PER_ROW);
for (int row = r; row < rows; row += gridDim.x * ROWS_PER_CTA) {
Vec x[LDGS];
Vec residual[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(x_ptr + row * LN_NUM_COLS + col * VecSize, &x[it]);
phi::Load<T, VecSize>(residual_ptr + row * LN_NUM_COLS + col * VecSize,
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
phi::Load<T, VecSize>(residual_ptr + row * ELTS_PER_ROW + col * VecSize,
&residual[it]);
col += THREADS_PER_ROW;
}
Expand Down Expand Up @@ -260,7 +281,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
#pragma unroll
for (int jt = 0; jt < VecSize; jt++) {
// dropout(x) + residual
x[it][jt] = x[it][jt] * static_cast<T>(mask_vec[it][jt]) * factor +
x[it][jt] = (x[it][jt] + bias[it][jt]) *
static_cast<T>(mask_vec[it][jt]) * factor +
residual[it][jt];
xf[it * VecSize + jt] = U(x[it][jt]);
}
Expand All @@ -270,9 +292,9 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>(
x[it], residual_out_ptr + row * LN_NUM_COLS + col * VecSize);
x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize);
phi::Store<MaskType, VecSize>(
mask_vec[it], mask_out_ptr + row * LN_NUM_COLS + col * VecSize);
mask_vec[it], mask_out_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW;
}

Expand All @@ -289,6 +311,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
}
if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = mu_local;
}
__syncthreads();
if (tidx == 0) {
mu_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
mu_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = mu_local;
}
__syncthreads();
mu_local = smem[warp_m];
}
mu_local *= rn;
if (lane == 0) {
mean_out_ptr[row] = mu_local;
Expand All @@ -308,6 +346,22 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
}
if (WARPS_N > 1) {
if (lane == 0) {
smem[warp_m * WARPS_N + warp_n] = var_local;
}
__syncthreads();
if (tidx == 0) {
var_local = 0.f;
#pragma unroll
for (int it = 0; it < WARPS_N; ++it) {
var_local += smem[warp_m * WARPS_N + it];
}
smem[warp_m] = var_local;
}
__syncthreads();
var_local = smem[warp_m];
}
U rsigma = rsqrtf(var_local * rn + epsilon);
if (lane == 0) {
// Note: the stored var is different for paddle(ln) and apex (fast ln).
Expand All @@ -332,7 +386,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_fwd_1024_kernel(

#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it], y_ptr + row * LN_NUM_COLS + col * VecSize);
phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW;
}
}
Expand Down Expand Up @@ -390,12 +444,37 @@ void LaunchLayernormResidualDropoutBias(
return;
}

bool can_call_1024_kernel = false;
if (cols == 1024 && scale != nullptr && layernorm_bias != nullptr &&
bias == nullptr) {
can_call_1024_kernel = true;
// Emits a single `case (cols): { ... } break` arm that launches
// fused_fast_ln_fwd_kernel for one fixed column count. It is meant to be
// expanded inside a `switch` over the runtime column count.
//
// `cols` must be an integer literal: it is used in constexpr expressions and
// as the template argument that follows BYTES_PER_LDG (presumably the kernel's
// per-row element count -- the full template parameter list is not visible
// here, TODO confirm). Because the macro parameter is itself named `cols`,
// every `cols` token in the body, including the runtime `cols` argument
// passed to the kernel, is replaced by that literal.
//
// Launch configuration derived from `cols`:
//   WARPS_N         = 1 when cols < 1024, otherwise cols / 1024
//   WARPS_M         = 4 / WARPS_N (also used as ROWS_PER_CTA)
//   THREADS_PER_CTA = WARPS_N * 32 * WARPS_M
//   VecSize         = 16 bytes per load / sizeof(T)
//   grid            = ceil(rows / ROWS_PER_CTA)
// NOTE(review): a column count with cols / 1024 > 4 would make WARPS_M zero;
// only use this macro with sizes for which 4 / WARPS_N is a positive integer.
//
// The expansion references names that must exist at the use site: T, U,
// ScaleBiasWithSameTypeX, ctx, rows, seed, dropout_prob, is_upscale_in_train,
// is_test, increment, epsilon, src, residual, bias, scale, layernorm_bias,
// mask_data, mean, var, dst and layernorm_dst.
// The macro ends with `break` and no semicolon; the caller supplies it.
#define LAUNCH_FUSED_FAST_LN_KERNEL_BASE(cols) \
case (cols): { \
constexpr int WARPS_N = cols < 1024 ? 1 : (cols / 1024); \
constexpr int WARPS_M = 4 / WARPS_N; \
const int THREADS_PER_WARP = 32; \
const int BYTES_PER_LDG = 16; \
const int VecSize = BYTES_PER_LDG / sizeof(T); \
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M; \
const int ROWS_PER_CTA = WARPS_M; \
const int grid = \
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA))); \
fused_fast_ln_fwd_kernel< \
T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t, \
VecSize, WARPS_M, WARPS_N, BYTES_PER_LDG, \
cols><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>( \
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, \
increment, epsilon, src, residual, bias, scale, layernorm_bias, \
mask_data, mean, var, dst, layernorm_dst); \
} break

// Expands to the `case` arms for every column count handled by the fast
// layernorm path: 768, 1024 and 4096. Intended to be placed inside a
// `switch` body; the last arm carries no trailing semicolon, so the use site
// writes `LAUNCH_FUSED_FAST_LN_KERNEL;`.
#define LAUNCH_FUSED_FAST_LN_KERNEL \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(768); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1024); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)

bool can_call_fast_ln_kernel = false;
if ((cols == 768 || cols == 1024 || cols == 4096) && scale != nullptr &&
layernorm_bias != nullptr) {
can_call_fast_ln_kernel = true;
}
VLOG(6) << "can_call_1024_kernel = " << can_call_1024_kernel;
VLOG(6) << "can_call_fast_ln_kernel = " << can_call_fast_ln_kernel;

const int VecSize = MAX_CACHE_BYTES / sizeof(T);
if (cols % VecSize != 0) {
Expand All @@ -407,26 +486,15 @@ void LaunchLayernormResidualDropoutBias(
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var);
} else {
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);

const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;

// Note: the grid can not exceed max_grid of the gpu.
const int grid =
static_cast<int>(std::ceil(rows / static_cast<float>(ROWS_PER_CTA)));
fused_ln_fwd_1024_kernel<
T, U, LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX>, uint8_t,
VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test,
increment, epsilon, src, residual, scale, layernorm_bias, mask_data,
mean, var, dst, layernorm_dst);
if (can_call_fast_ln_kernel) {
switch (cols) {
LAUNCH_FUSED_FAST_LN_KERNEL;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only when column is equal to 768/1024/4096 is supported for "
"now"));
break;
}
} else {
int blockDim = GetDesiredBlockDim(cols / VecSize);
FusedLayernormResidualDropoutBias<
Expand Down