[Cherry-pick] Optimize update_loss_scaling_op(#32554) #32606

Merged
63 changes: 35 additions & 28 deletions paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
@@ -39,33 +39,36 @@ __global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
__syncthreads();

const int64_t num = s_starts[size];
- int pre_xs_index = 0;
- bool t_found_inf = false;
- const MT t_scale = *scale;
+ int xs_index = 0;
+ bool local_found_inf = false;
+ const MT local_scale = *scale;
for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
- // get the xs's index of thread
- int xs_index = pre_xs_index;
- while (idx < s_starts[xs_index]) xs_index++;
- // avoid some tensor's numel is zero
- while (idx >= s_starts[xs_index]) xs_index++;
- pre_xs_index = xs_index - 1;
+ // get the "out" index of "id"
+ // For example:
+ // idx = 15, starts = [0, 10, 10, 20, 30]
+ // because 10 <= idx < 20 ==>
+ // the idx element locate in the 3rd tensor (notice the 2nd tensor size is
+ // 0)
+ int next_xs_index = xs_index;
+ while (idx >= s_starts[next_xs_index]) next_xs_index++;
+ xs_index = next_xs_index - 1;

// get in data and out data
- const T* in = xs[pre_xs_index];
- T* out = outs[pre_xs_index];
- int64_t in_idx = idx - s_starts[pre_xs_index];
+ const T* in = xs[xs_index];
+ T* out = outs[xs_index];
+ int64_t in_idx = idx - s_starts[xs_index];

// Unscale
- MT val = static_cast<MT>(in[in_idx]) * t_scale;
+ MT val = static_cast<MT>(in[in_idx]) * local_scale;
T narrow_val = static_cast<T>(val);
out[in_idx] = narrow_val;

// CheckFinite
if (!isfinite(narrow_val)) {
- t_found_inf = true;
+ local_found_inf = true;
}
}
- if (t_found_inf) {
+ if (local_found_inf) {
*found_inf = true;
}
}
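
Note: the kernel finds which tensor owns a flat element index by scanning the shared s_starts array, the exclusive prefix sum of the tensors' sizes, so zero-sized tensors are skipped naturally. A minimal host-side C++ sketch of that lookup (standalone and illustrative only, not part of the PR):

#include <cassert>
#include <cstdint>
#include <vector>

// starts is the exclusive prefix sum of tensor sizes,
// e.g. sizes [10, 0, 10, 10] -> starts [0, 10, 10, 20, 30].
int FindTensorIndex(const std::vector<std::int64_t>& starts, std::int64_t idx) {
  int next = 0;
  // advance past every start that is <= idx; consecutive equal starts
  // (zero-sized tensors) are passed over in the same loop
  while (idx >= starts[next]) ++next;
  return next - 1;  // idx lives in the previous tensor
}

int main() {
  std::vector<std::int64_t> starts = {0, 10, 10, 20, 30};
  assert(FindTensorIndex(starts, 15) == 2);  // 10 <= 15 < 20: the 3rd tensor
  assert(FindTensorIndex(starts, 0) == 0);
  assert(FindTensorIndex(starts, 29) == 3);
  return 0;
}
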
@@ -94,28 +97,30 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
scale_data, inverse_scale_v, found_inf_data);

size_t xs_size = xs.size();
+ const auto& cpu_place = platform::CPUPlace();
// calculate each tensor's start index and copy to device
auto h_starts_tensor =
- memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+ memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());

auto d_starts_tensor =
memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());

+ // the start index value of each tensor is
+ // the sum of previous tensor's size. For example:
+ // xs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
h_starts[0] = 0;
for (int i = 1; i <= xs_size; i++) {
- // the start index value of each tensor is
- // the sum of previous tensor's size
h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
}
int64_t total_num = h_starts[xs_size];
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
- d_starts, platform::CPUPlace(), h_starts,
- (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+ d_starts, cpu_place, h_starts, (xs_size + 1) * sizeof(int64_t),
+ dev_ctx.stream());

// copy each tensor's data address to device
- auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+ auto h_mem = memory::Alloc(cpu_place, 2 * xs_size * sizeof(T*));
const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;

@@ -128,16 +133,18 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
}
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
- platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
- dev_ctx.stream());
+ cpu_place, h_xs, 2 * xs_size * sizeof(T*), dev_ctx.stream());

// Launch Kernel
- int block = 1024;
- int block_num = block * 20; // each thread deal with 20 number
- int grid = (total_num + block_num - 1) / block_num;
+ int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+ int elements_per_block =
+ threads_per_block * 20; // each thread deal with 20 number
+ int blocks_per_grid =
+ (total_num + elements_per_block - 1) / elements_per_block;
VLOG(3) << "launch kernel";
- CheckFiniteAndUnscale<T, MPDType><<<
- grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+ CheckFiniteAndUnscale<
+ T, MPDType><<<blocks_per_grid, threads_per_block,
+ (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
VLOG(3) << "finish kernel";
}
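
Note: the rewritten launch configuration caps the block at 1024 threads (or fewer for small inputs) and sizes the grid so that each thread covers roughly 20 elements through the grid-stride loop. A small standalone C++ sketch of that arithmetic with an assumed element count:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const std::int64_t total_num = 100000;  // assumed total element count
  const std::int64_t threads_per_block =
      std::min(static_cast<std::int64_t>(1024), total_num);
  const std::int64_t elements_per_block = threads_per_block * 20;
  // ceiling division: enough blocks so every element is visited at least once
  const std::int64_t blocks_per_grid =
      (total_num + elements_per_block - 1) / elements_per_block;
  // 100000 elements -> 1024 threads/block, 20480 elements/block, 5 blocks
  std::cout << threads_per_block << " " << blocks_per_grid << "\n";
  return 0;
}
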
93 changes: 78 additions & 15 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cu
@@ -34,13 +34,39 @@ __global__ void GpuUpdateLossScaling(
}

template <typename T>
- __global__ void FillIf(T* data, const int64_t num, const T value,
- const bool* has_inf) {
- if (*has_inf) {
- int tid = threadIdx.x + blockIdx.x * blockDim.x;
- for (int i = tid; i < num; i += blockDim.x * gridDim.x) {
- data[i] = value;
- }
+ __global__ void FusedFillIf(T** outs, const size_t xs_size,
+ const int64_t* starts, const T value,
+ const bool* has_inf) {
+ if (!(*has_inf)) return;
+
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+ // copy starts array from global memory to shared memory
+ extern __shared__ int64_t s_starts[];
+ for (int i = threadIdx.x; i <= xs_size; i += blockDim.x) {
+ s_starts[i] = starts[i];
+ }
+ __syncthreads();
+
+ const int64_t total_num = s_starts[xs_size];
+ int out_index = 0;
+
+ for (int64_t id = tid; id < total_num; id += blockDim.x * gridDim.x) {
+ // get the "out" index of "id"
+ // For example:
+ // id = 15, starts = [0, 10, 10, 20, 30]
+ // because 10 <= id < 20 ==>
+ // the id element locate in the 3rd tensor (notice the 2nd tensor size is 0)
+ int next_out_index = out_index;
+ while (id >= s_starts[next_out_index]) next_out_index++;
+ out_index = next_out_index - 1;
+
+ // get data pointer and index
+ T* out_data = outs[out_index];
+ int64_t idx = id - s_starts[out_index];
+
+ // set value
+ out_data[idx] = value;
+ }
}
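
Note: FusedFillIf collapses what used to be one FillIf launch per output tensor into a single kernel over the concatenated index range, and writes only when *has_inf is set. A simplified host-side C++ simulation of that behaviour (illustrative names, std::vector in place of device tensors):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Fill every element of every output with `value`, but only if has_inf is set.
void FusedFillIfSim(std::vector<std::vector<float>>& outs, float value,
                    bool has_inf) {
  if (!has_inf) return;
  // exclusive prefix sum of the output sizes, like the host code builds
  std::vector<std::int64_t> starts(outs.size() + 1, 0);
  for (std::size_t i = 0; i < outs.size(); ++i) {
    starts[i + 1] = starts[i] + static_cast<std::int64_t>(outs[i].size());
  }
  int out_index = 0;
  for (std::int64_t id = 0; id < starts.back(); ++id) {
    // find the tensor that owns flat index `id` (empty tensors are skipped)
    while (id >= starts[out_index + 1]) ++out_index;
    outs[out_index][id - starts[out_index]] = value;
  }
}

int main() {
  std::vector<std::vector<float>> outs = {{1, 2, 3}, {}, {4, 5}};
  FusedFillIfSim(outs, 0.0f, /*has_inf=*/true);
  std::cout << outs[0][0] << " " << outs[2][1] << "\n";  // prints "0 0"
  return 0;
}
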

@@ -68,15 +94,52 @@ class LazyZeros<platform::CUDADeviceContext, T> {
const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
- for (size_t i = 0; i < xs.size(); ++i) {
- auto* out = outs[i];
- T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
- int64_t num = out->numel();
- int block = 1024;
- int grid = (block - 1 + num) / block;
- FillIf<<<grid, block, 0, dev_ctx.stream()>>>(
- out_data, num, static_cast<T>(0), found_inf_data);
+ size_t xs_size = xs.size();
+ const auto& cpu_place = platform::CPUPlace();
+ // alloc each tensor's start index and copy to device
+ auto h_in_starts_mem =
+ memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
+ int64_t* h_starts = reinterpret_cast<int64_t*>(h_in_starts_mem->ptr());
+
+ auto d_in_starts_mem =
+ memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+ int64_t* d_starts = reinterpret_cast<int64_t*>(d_in_starts_mem->ptr());
+
+ // the start index value of each tensor is
+ // the sum of previous tensor's size. For example:
+ // outs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
+ h_starts[0] = 0;
+ for (int i = 0; i < xs_size; i++) {
+ h_starts[i + 1] = h_starts[i] + outs[i]->numel();
+ }
+ memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+ d_starts, cpu_place, h_starts, (xs_size + 1) * sizeof(int64_t),
+ dev_ctx.stream());
+
+ // copy each tensor of "outs" data address array to device
+ auto h_out_addrs_mem = memory::Alloc(cpu_place, xs_size * sizeof(T*));
+ T** h_out_addrs = reinterpret_cast<T**>(h_out_addrs_mem->ptr());
+
+ auto d_out_addrs_mem = memory::Alloc(dev_ctx, xs_size * sizeof(T*));
+ T** d_out_addrs = reinterpret_cast<T**>(d_out_addrs_mem->ptr());
+
+ for (size_t i = 0; i < xs_size; ++i) {
+ h_out_addrs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
+ }
+ memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+ d_out_addrs, cpu_place, h_out_addrs, xs_size * sizeof(T*),
+ dev_ctx.stream());
+
+ // launch cuda kernel
+ int64_t total_num = h_starts[xs_size];
+ int64_t threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+ int64_t elements_per_block =
+ threads_per_block * 50; // each thread deal with 50 data
+ int64_t blocks_per_grid =
+ (total_num + elements_per_block - 1) / elements_per_block;
+ FusedFillIf<T><<<blocks_per_grid, threads_per_block,
+ (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+ d_out_addrs, xs_size, d_starts, static_cast<T>(0), found_inf_data);
}
};
