
Commit

optimize update_loss_scaling_op by fused for loop to one kernel, test=develop
thisjiang committed Apr 26, 2021
1 parent fd85a4a commit ad79dff
Showing 1 changed file with 73 additions and 15 deletions.
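In short, the old code launched one FillIf kernel per output tensor, while the new code builds a prefix-sum table of start offsets, copies the output pointers to the device, and launches a single FusedFillIf kernel that zero-fills every output in one grid-stride pass. The standalone sketch below illustrates that fusion pattern under simplified assumptions: plain cudaMalloc/cudaMemcpy, no shared-memory caching of the offsets, and illustrative names such as fused_fill. It is not the Paddle implementation itself.

```cuda
// Sketch of fusing many per-tensor fill launches into one kernel by packing
// all output pointers plus a prefix-sum offset table. Names and sizes here
// are illustrative only.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void fused_fill(float** outs, const int64_t* starts, int n_tensors,
                           float value, const bool* trigger) {
  if (!(*trigger)) return;  // do nothing unless the flag is set
  const int64_t total = starts[n_tensors];
  for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < total;
       id += (int64_t)blockDim.x * gridDim.x) {
    // find the tensor that owns flat element "id"
    int t = 0;
    while (id >= starts[t + 1]) ++t;
    outs[t][id - starts[t]] = value;
  }
}

int main() {
  const int n_tensors = 3;
  const int64_t sizes[n_tensors] = {5, 0, 7};  // an empty tensor is fine

  // starts[i] .. starts[i+1] is the flat-index range of tensor i
  std::vector<int64_t> starts_h(n_tensors + 1, 0);
  std::vector<float*> outs_h(n_tensors, nullptr);
  for (int i = 0; i < n_tensors; ++i) {
    starts_h[i + 1] = starts_h[i] + sizes[i];
    if (sizes[i] > 0) cudaMalloc(&outs_h[i], sizes[i] * sizeof(float));
  }

  float** outs_d;
  int64_t* starts_d;
  bool* trigger_d;
  cudaMalloc(&outs_d, n_tensors * sizeof(float*));
  cudaMalloc(&starts_d, (n_tensors + 1) * sizeof(int64_t));
  cudaMalloc(&trigger_d, sizeof(bool));
  cudaMemcpy(outs_d, outs_h.data(), n_tensors * sizeof(float*),
             cudaMemcpyHostToDevice);
  cudaMemcpy(starts_d, starts_h.data(), (n_tensors + 1) * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  const bool trigger = true;
  cudaMemcpy(trigger_d, &trigger, sizeof(bool), cudaMemcpyHostToDevice);

  // one launch covers all tensors
  const int64_t total = starts_h[n_tensors];
  const int block = 256;
  const int grid = (int)((total + block - 1) / block);
  fused_fill<<<grid, block>>>(outs_d, starts_d, n_tensors, 0.f, trigger_d);
  cudaDeviceSynchronize();
  printf("filled %lld elements across %d tensors in one launch\n",
         (long long)total, n_tensors);
  return 0;
}
```

The sketch omits the real kernel's shared-memory copy of the offset table and the 50-elements-per-thread grid sizing; those refinements appear in the diff below.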
88 changes: 73 additions & 15 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cu
@@ -34,13 +34,37 @@ __global__ void GpuUpdateLossScaling(
 }
 
 template <typename T>
-__global__ void FillIf(T* data, const int64_t num, const T value,
-                       const bool* has_inf) {
-  if (*has_inf) {
-    int tid = threadIdx.x + blockIdx.x * blockDim.x;
-    for (int i = tid; i < num; i += blockDim.x * gridDim.x) {
-      data[i] = value;
-    }
+__global__ void FusedFillIf(T** outs, const size_t xs_size,
+                            const int64_t* starts, const T value,
+                            const bool* has_inf) {
+  if (!(*has_inf)) return;
+
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // copy the starts array from global memory to shared memory
+  extern __shared__ int64_t starts_s[];
+  for (int i = threadIdx.x; i <= xs_size; i += blockDim.x) {
+    starts_s[i] = starts[i];
+  }
+  __syncthreads();
+
+  const int64_t total_num = starts_s[xs_size];
+  int out_index = 0;
+
+  for (int64_t id = tid; id < total_num; id += blockDim.x * gridDim.x) {
+    // find the index of the "outs" tensor that element "id" falls into
+    int next_out_index = out_index;
+    while (id < starts_s[next_out_index]) next_out_index++;
+    // the strict ">=" comparison also skips tensors whose numel is zero
+    while (id >= starts_s[next_out_index]) next_out_index++;
+    out_index = next_out_index - 1;
+
+    // get the output data pointer and the index within that tensor
+    T* out_data = outs[out_index];
+    int64_t idx = id - starts_s[out_index];
+
+    // set value
+    out_data[idx] = value;
   }
 }
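To make the new kernel's index search concrete: with three outputs of sizes 3, 0, and 4, the offset table is starts_s = {0, 3, 3, 7} and total_num = 7. A thread handling id = 3 advances next_out_index while 3 >= starts_s[next_out_index], stopping at 3 because starts_s[3] = 7, so out_index = 2 and idx = 3 - starts_s[2] = 0; the strict comparison is what steps past the empty second tensor. Since id only grows along the grid-stride loop and out_index always trails it, the preceding "while (id < starts_s[next_out_index])" guard never actually fires.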

@@ -68,15 +92,49 @@ class LazyZeros<platform::CUDADeviceContext, T> {
                   const bool* found_inf_data,
                   const std::vector<const framework::Tensor*>& xs,
                   const std::vector<framework::Tensor*>& outs) const {
-    for (size_t i = 0; i < xs.size(); ++i) {
-      auto* out = outs[i];
-      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-      int64_t num = out->numel();
-      int block = 1024;
-      int grid = (block - 1 + num) / block;
-      FillIf<<<grid, block, 0, dev_ctx.stream()>>>(
-          out_data, num, static_cast<T>(0), found_inf_data);
-    }
+    size_t xs_size = xs.size();
+    // compute each tensor's start offset on the host and copy the table to the device
+    auto starts_h_tensor =
+        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+    int64_t* starts_h = reinterpret_cast<int64_t*>(starts_h_tensor->ptr());
+
+    auto starts_d_tensor =
+        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+    int64_t* starts_d = reinterpret_cast<int64_t*>(starts_d_tensor->ptr());
+
+    starts_h[0] = 0;
+    for (int i = 0; i < xs_size; i++) {
+      // each tensor's start offset is the
+      // running sum of the previous tensors' sizes
+      starts_h[i + 1] = starts_h[i] + outs[i]->numel();
+    }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 starts_d, platform::CPUPlace(), starts_h,
+                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+
+    // copy the data address of each tensor in "outs" to the device
+    auto outs_addr_h_tensor =
+        memory::Alloc(platform::CPUPlace(), xs_size * sizeof(T*));
+    T** outs_addr_h = reinterpret_cast<T**>(outs_addr_h_tensor->ptr());
+
+    auto outs_addr_d_tensor = memory::Alloc(dev_ctx, xs_size * sizeof(T*));
+    T** outs_addr_d = reinterpret_cast<T**>(outs_addr_d_tensor->ptr());
+
+    for (size_t i = 0; i < xs_size; ++i) {
+      outs_addr_h[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
+    }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 outs_addr_d, platform::CPUPlace(), outs_addr_h,
+                 xs_size * sizeof(T*), dev_ctx.stream());
+
+    // launch the CUDA kernel
+    int64_t total_num = starts_h[xs_size];
+    int64_t block = std::min(static_cast<int64_t>(1024), total_num);
+    int64_t block_num = block * 50;  // each thread handles 50 elements
+    int64_t grid = (total_num + block_num - 1) / block_num;
+    FusedFillIf<
+        T><<<grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+        outs_addr_d, xs_size, starts_d, static_cast<T>(0), found_inf_data);
   }
 };
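As a worked example of the new launch configuration (numbers illustrative): for 10 outputs totaling 2,000,000 elements, block = min(1024, 2000000) = 1024, block_num = 1024 * 50 = 51200, and grid = (2000000 + 51199) / 51200 = 40, so 40 * 1024 = 40,960 threads each cover roughly 49 elements via the grid-stride loop, and the kernel requests (10 + 1) * 8 = 88 bytes of dynamic shared memory for the offset table.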

