Optimize check_finite_and_unscale_op #31954

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 84 additions & 21 deletions paddle/fluid/operators/amp/check_finite_and_unscale_op.cu

@@ -26,18 +26,48 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
}

template <typename T, typename MT>
-__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num,
-                                      bool* found_inf, T* out) {
-  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
-
-  if (idx < num) {
-    MT val = static_cast<MT>(in[idx]) * (*scale);
+__global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
+                                      int64_t size, int64_t* starts,
+                                      bool* found_inf, T** outs) {
+  const int64_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // copy the starts array from global memory to shared memory
+  extern __shared__ int64_t s_starts[];
+  for (int i = threadIdx.x; i <= size; i += blockDim.x) {
+    s_starts[i] = starts[i];
+  }
+  __syncthreads();
+
+  const int64_t num = s_starts[size];
+  int pre_xs_index = 0;
+  bool t_found_inf = false;
+  const MT t_scale = *scale;
+  for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
+    // find which tensor in xs this element belongs to
+    int xs_index = pre_xs_index;
+    while (idx < s_starts[xs_index]) xs_index++;
Reviewer comment (Contributor):
The code on line 48 (`while (idx < s_starts[xs_index]) xs_index++;`) may never be triggered, since `idx` only grows across the grid-stride iterations and the search resumes from `pre_xs_index`.

Author reply (Contributor):
This line was removed in PR 32554.

+    // skip tensors whose numel is zero
+    while (idx >= s_starts[xs_index]) xs_index++;
+    pre_xs_index = xs_index - 1;
+
+    // get the input and output data pointers
+    const T* in = xs[pre_xs_index];
+    T* out = outs[pre_xs_index];
+    int64_t in_idx = idx - s_starts[pre_xs_index];
+
+    // Unscale
+    MT val = static_cast<MT>(in[in_idx]) * t_scale;
    T narrow_val = static_cast<T>(val);
-    out[idx] = narrow_val;
+    out[in_idx] = narrow_val;
+
+    // CheckFinite
    if (!isfinite(narrow_val)) {
-      *found_inf = true;
+      t_found_inf = true;
    }
  }
+  if (t_found_inf) {
+    *found_inf = true;
+  }
}

template <typename T>
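
Note (reading aid, not part of the PR): the sketch below mirrors the index arithmetic of the fused kernel on the host, so the mapping can be checked without a GPU. `starts` is a prefix sum of the tensors' element counts, and the search loop turns a flat element index into a tensor index plus an offset inside that tensor. The function name FlatIndexToTensor and the sample sizes are illustrative only; the kernel additionally caches pre_xs_index across grid-stride iterations, while this sketch restarts the search from 0 for simplicity.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Map a flat element index to (tensor index, offset inside that tensor)
// using the prefix-sum starts array, the way the kernel's search loop does.
std::pair<int, int64_t> FlatIndexToTensor(const std::vector<int64_t>& starts,
                                          int64_t idx) {
  int xs_index = 0;
  // Advance past every tensor whose range ends at or before idx; this also
  // skips tensors with zero elements, whose start equals the next start.
  while (idx >= starts[xs_index]) xs_index++;
  int tensor_index = xs_index - 1;
  return {tensor_index, idx - starts[tensor_index]};
}

int main() {
  // Three tensors with 4, 0 and 3 elements -> starts = {0, 4, 4, 7}.
  std::vector<int64_t> starts = {0, 4, 4, 7};
  assert((FlatIndexToTensor(starts, 0) == std::make_pair(0, int64_t{0})));
  assert((FlatIndexToTensor(starts, 3) == std::make_pair(0, int64_t{3})));
  // Index 4 lands in the third tensor because the empty second one is skipped.
  assert((FlatIndexToTensor(starts, 4) == std::make_pair(2, int64_t{0})));
  assert((FlatIndexToTensor(starts, 6) == std::make_pair(2, int64_t{2})));
  return 0;
}
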
@@ -63,20 +93,53 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
    InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
        scale_data, inverse_scale_v, found_inf_data);

-    for (size_t i = 0; i < xs.size(); ++i) {
-      const auto* x = xs[i];
-      auto* out = outs[i];
-      const T* x_data = x->data<T>();
-      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-
-      int num = x->numel();
-      int block = 1024;
-      int grid = (num + block - 1) / block;
-      VLOG(3) << "launch kernel";
-      CheckFiniteAndUnscale<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
-          x_data, inverse_scale_v, num, found_inf_data, out_data);
-      VLOG(3) << "finish kernel";
-    }
+    size_t xs_size = xs.size();
+    // calculate each tensor's start index and copy the array to the device
+    auto h_starts_tensor =
+        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+    int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
+
+    auto d_starts_tensor =
+        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+    int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
+
+    h_starts[0] = 0;
+    for (int i = 1; i <= xs_size; i++) {
+      // the start index of each tensor is
+      // the sum of the previous tensors' sizes
+      h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
+    }
+    int64_t total_num = h_starts[xs_size];
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_starts, platform::CPUPlace(), h_starts,
+                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+
+    // copy each tensor's data address to the device
+    auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+    const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
+    T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
+
+    auto d_mem = memory::Alloc(dev_ctx, 2 * xs_size * sizeof(T*));
+    const T** d_xs = reinterpret_cast<const T**>(d_mem->ptr());
+    T** d_outs = reinterpret_cast<T**>(d_mem->ptr()) + xs_size;
+
+    for (size_t i = 0; i < xs_size; ++i) {
+      h_xs[i] = xs[i]->data<T>();
+      h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
+    }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
+                 platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
+                 dev_ctx.stream());
+
+    // Launch Kernel
+    int block = 1024;
+    int block_num = block * 20;  // each thread deals with 20 elements
+    int grid = (total_num + block_num - 1) / block_num;
+    VLOG(3) << "launch kernel";
+    CheckFiniteAndUnscale<T, MPDType><<<
+        grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+        d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
+    VLOG(3) << "finish kernel";
}
};
} // namespace operators
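
Note (hedged sketch, not Paddle code): the host-side bookkeeping of the hunk above, expressed with std::vector in place of Paddle's memory::Alloc/memory::Copy. It builds the starts prefix sum over the tensors' numels and packs the input and output data pointers back to back so one host-to-device copy can publish both; BuildStarts and PackPointers are illustrative names, not Paddle APIs.

#include <cstdint>
#include <vector>

// starts[0] = 0 and starts[i] = starts[i - 1] + numel of tensor i - 1, so
// starts.back() is the total number of elements across all tensors.
std::vector<int64_t> BuildStarts(const std::vector<int64_t>& numels) {
  std::vector<int64_t> starts(numels.size() + 1, 0);
  for (size_t i = 1; i < starts.size(); ++i) {
    starts[i] = starts[i - 1] + numels[i - 1];
  }
  return starts;
}

// Pack input and output pointers into one contiguous buffer, mirroring the
// single 2 * xs_size * sizeof(T*) allocation and single copy in the PR.
template <typename T>
std::vector<const void*> PackPointers(const std::vector<const T*>& xs,
                                      const std::vector<T*>& outs) {
  std::vector<const void*> packed;
  packed.reserve(xs.size() + outs.size());
  for (const T* p : xs) packed.push_back(p);
  for (T* p : outs) packed.push_back(p);
  return packed;
}

int main() {
  std::vector<int64_t> numels = {4, 0, 3};
  std::vector<int64_t> starts = BuildStarts(numels);  // {0, 4, 4, 7}
  return starts.back() == 7 ? 0 : 1;
}
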
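Note (launch-configuration arithmetic, shown as a minimal standalone sketch with assumed example sizes): 1024 threads per block, a grid sized so each thread covers roughly 20 elements through the kernel's grid-stride loop, and dynamic shared memory sized to hold the (xs_size + 1) entries of s_starts.

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative sizes, not taken from the PR.
  int64_t total_num = 1000000;  // total elements across all tensors
  int64_t xs_size = 32;         // number of tensors
  int block = 1024;             // threads per block
  int64_t block_num = static_cast<int64_t>(block) * 20;  // elements per block
  int grid = static_cast<int>((total_num + block_num - 1) / block_num);
  // Dynamic shared memory holds the (xs_size + 1) entries of s_starts.
  size_t shared_bytes = (xs_size + 1) * sizeof(int64_t);
  std::printf("grid=%d block=%d dynamic_shared=%zu bytes\n", grid, block,
              shared_bytes);
  return 0;
}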