fix grid dim.y should be less than 65535 bug #4

Merged · 1 commit · Sep 25, 2023
12 changes: 7 additions & 5 deletions paddle/fluid/operators/fused/fused_dropout_act_bias.h
@@ -86,18 +86,20 @@ __global__ void FusedDropoutActBias(
     const int quant_round_type = 1,
     const float quant_max_bound = 127.0,
     const float quant_min_bound = -127.0) {
-  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
-  int row_id = blockIdx.y;
+  int col_id = threadIdx.x;
+  int row_id = gridDim.y * blockIdx.x + blockIdx.y;
   int idx = row_id * cols + col_id;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, idx, increment, &state);

   const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);

-  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
-    for (int i = col_id * VecSize; i < cols;
-         i += blockDim.x * gridDim.x * VecSize) {
+  int i = col_id * VecSize;
+  int r = row_id;
+  int stride = blockDim.x * VecSize;
+  for (; r < rows; r += blockDim.y * gridDim.y * gridDim.x) {
+    for (; i < cols; i += stride) {
       FusedResidualDropoutBiasOneThread<T,
                                         MaskType,
                                         VecSize,
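Note on the kernel change above: CUDA caps gridDim.y at 65535, so addressing rows through blockIdx.y alone breaks once rows exceeds that limit. The patched kernels instead flatten (blockIdx.x, blockIdx.y) into a single row index and stride over rows by the full grid size. A minimal toy kernel sketching the same mapping (RowColSketch is a hypothetical stand-in for the fused op; VecSize vectorization is omitted for clarity):

#include <cstdint>
#include <cuda_runtime.h>

// Toy kernel illustrating the patched index mapping: flatten the 2D
// grid into one row index, then stride over rows by the whole grid
// and over columns by the block width (blockDim.y == 1 is assumed,
// matching the launch config in fused_dropout_common.h).
__global__ void RowColSketch(float* data, int rows, int cols) {
  int col_id = threadIdx.x;
  int row_id = gridDim.y * blockIdx.x + blockIdx.y;    // flattened block id
  for (int r = row_id; r < rows; r += gridDim.x * gridDim.y) {
    for (int i = col_id; i < cols; i += blockDim.x) {  // resets per row
      data[(int64_t)r * cols + i] *= 2.0f;             // stand-in for the fused op
    }
  }
}

Since each block starts from a distinct flattened id in [0, gridDim.x * gridDim.y), every row below rows is still visited exactly once, even after blocks_y is halved by the launch-config change below.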
8 changes: 6 additions & 2 deletions paddle/fluid/operators/fused/fused_dropout_common.h
@@ -57,10 +57,14 @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
       std::min(tmp_cols,
                static_cast<uint32_t>(std::min(
                    ctx.GetMaxThreadsPerBlock(), 512))));
-  const auto blocks_x =
+  auto blocks_x =
       std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
-  const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
+  auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
   platform::GpuLaunchConfig config;
+  while (blocks_y > 65535) {
+    blocks_x *= 2;
+    blocks_y /= 2;
+  }
   config.block_per_grid.x = blocks_x;
   config.block_per_grid.y = blocks_y;
   config.thread_per_block.x = threads;
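CUDA limits gridDim.y to 65535 (gridDim.x may go up to 2^31 - 1 on compute capability 3.0+), so launching with blocks_y = rows fails for inputs with more than 65535 rows. The while loop above trades y-extent for x-extent until the limit holds. A self-contained sketch of that reshaping, using a hypothetical helper name (MakeGrid), not the actual Paddle API:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical helper mirroring the patched Get1DBlocksAnd2DGrids logic:
// double blocks_x and halve blocks_y until blocks_y fits under the
// 65535 gridDim.y limit, keeping roughly `rows` blocks in total.
dim3 MakeGrid(unsigned rows, unsigned cols, unsigned threads) {
  unsigned blocks_x = cols > 0 ? (cols + threads - 1) / threads : 1;
  unsigned blocks_y = rows > 0 ? rows : 1;
  while (blocks_y > 65535) {
    blocks_x *= 2;  // grow x ...
    blocks_y /= 2;  // ... shrink y (integer division)
  }
  return dim3(blocks_x, blocks_y, 1);
}

int main() {
  dim3 g = MakeGrid(/*rows=*/200000, /*cols=*/1024, /*threads=*/512);
  // blocks_y: 200000 -> 100000 -> 50000; blocks_x: 2 -> 4 -> 8
  printf("grid = (%u, %u)\n", g.x, g.y);  // prints: grid = (8, 50000)
  return 0;
}

Because blocks_y is halved with integer division, blocks_x * blocks_y can land slightly below rows when an intermediate blocks_y is odd; the kernels tolerate this because their row loops now stride by the full gridDim.x * gridDim.y.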
12 changes: 7 additions & 5 deletions paddle/fluid/operators/fused/fused_residual_dropout_bias.h
@@ -174,16 +174,18 @@ __global__ void FusedResidualDropoutBias(
     const float *dequant_out_scale_data = nullptr,
     const int quant_out_scale_offset = 0,
     const float quant_next_in_scale = 1.0) {
-  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
-  int row_id = blockIdx.y;
+  int col_id = threadIdx.x;
+  int row_id = gridDim.y * blockIdx.x + blockIdx.y;
   int idx = row_id * cols + col_id;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, idx, increment, &state);
   const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
   phi::funcs::ReluFunctor<T> relu;
-  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
-    for (int i = col_id * VecSize; i < cols;
-         i += blockDim.x * gridDim.x * VecSize) {
+  int i = col_id * VecSize;
+  int r = row_id;
+  int stride = blockDim.x * VecSize;
+  for (; r < rows; r += blockDim.y * gridDim.y * gridDim.x) {
+    for (; i < cols; i += stride) {
       FusedResidualDropoutBiasOneThread<T,
                                         MaskType,
                                         VecSize,
8 changes: 5 additions & 3 deletions paddle/fluid/operators/fused/fused_softmax_mask.cu.h
@@ -62,14 +62,16 @@ __global__ void FusedSoftmaxMaskVecKernel(T* dst,
   // gridDim/blockIdx = (DIV_UP(seq_len, warps_per_block), batch_size, head_num)
   // every block processes 4(warps_per_block) sequences
   // seq_id = seq_id * 4 + warp_id, eg.seq_len=128, 127=31*4+3
-  int seq_id = blockIdx.x * warps_per_block + threadIdx.y;
+  // int seq_id = blockIdx.x * warps_per_block + threadIdx.y;
+  // FIX
+  int64_t seq_id = blockIdx.x * warps_per_block + threadIdx.y;
   if (seq_id >= seq_len) return;

   // ((bid*head_num + hid)*seq_len + seq_id) * seq_len
-  int offset =
+  int64_t offset =
       ((blockIdx.y * gridDim.z + blockIdx.z) * seq_len + seq_id) * seq_len;
   // (bid * seq_len + seq_id) * seq_len
-  int mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len;
+  int64_t mask_offset = (blockIdx.y * seq_len + seq_id) * seq_len;
   src += offset;
   dst += offset;
   mask += mask_offset;
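The type change in this last file is an overflow fix rather than a grid-shape fix: offset multiplies batch, head, and two factors of seq_len, which exceeds INT_MAX for realistic shapes, and promoting seq_id to int64_t pulls the whole offset expression into 64-bit arithmetic. A worked example with illustrative numbers (not taken from the PR):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shape chosen to show the overflow: 16 batches,
  // 16 heads, seq_len = 4096.
  int64_t head_num = 16, seq_len = 4096;
  int64_t bid = 15, hid = 15, seq_id = seq_len - 1;  // last element
  // Same formula as the kernel comment:
  // ((bid * head_num + hid) * seq_len + seq_id) * seq_len
  int64_t offset = ((bid * head_num + hid) * seq_len + seq_id) * seq_len;
  printf("64-bit offset: %lld\n", (long long)offset);   // 4294963200
  printf("truncated to int32: %d\n", (int32_t)offset);  // wraps to -4096
  return 0;
}

With 32-bit arithmetic, the pointer adjustments src += offset and dst += offset would move backwards for such shapes, reading and writing far outside the tensor.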