Commit

move code to distributed_fused_lamb_op.cu
sneaxiy committed Feb 24, 2022
1 parent 9a701b0 commit 0e32d82
Showing 3 changed files with 223 additions and 208 deletions.
200 changes: 170 additions & 30 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
@@ -14,6 +14,7 @@

#include <cmath>
#include "paddle/fluid/memory/buffer.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
@@ -43,6 +44,163 @@ namespace operators {
template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type;

template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
#endif
}

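// Computes the partial squared L2 norm of one chunk of one tensor. Each block
// loads its chunk with vectorized loads of width VecSize, accumulates the
// squared sum in MasterT<T> precision, block-reduces it, and stores the
// per-chunk result at y[tensor_id * max_chunk_num + chunk_id].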
template <typename T, int BlockDim, int VecSize>
struct L2NormFunctor {
DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
const T *x, MasterT<T> *y, int max_chunk_num) const {
using MT = MasterT<T>;
const T *ptr = x + offset;

using BlockReduce = cub::BlockReduce<MT, BlockDim>;
__shared__ typename BlockReduce::TempStorage storage;

MT square_sum = static_cast<MT>(0);
int i;
for (i = threadIdx.x * VecSize; i + VecSize <= size;
i += (BlockDim * VecSize)) {
platform::AlignedVector<T, VecSize> tmp_vec;
platform::Load(ptr + i, &tmp_vec);
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
auto tmp = static_cast<MT>(tmp_vec[j]);
square_sum += (tmp * tmp);
}
}

for (; i < size; ++i) {
auto tmp = static_cast<MT>(ptr[i]);
square_sum += (tmp * tmp);
}

square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
if (threadIdx.x == 0) {
y[tensor_id * max_chunk_num + chunk_id] = square_sum;
}
}
};

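// Second reduction stage: one block per tensor sums that tensor's per-chunk
// partial results produced by L2NormFunctor and, when NeedSqrt is true,
// applies sqrtf, yielding the (squared) L2 norm of each tensor in y.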
template <typename InT, typename OutT, int BlockDim, bool NeedSqrt>
static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
const InT *x, OutT *y, int max_chunk_num) {
int tensor_id = blockIdx.x;
x += (tensor_id * max_chunk_num);
using BlockReduce = cub::BlockReduce<InT, BlockDim>;
__shared__ typename BlockReduce::TempStorage storage;
InT sum = static_cast<InT>(0);
for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
sum += x[i];
}
sum = BlockReduce(storage).Reduce(sum, cub::Sum());
if (threadIdx.x == 0) {
if (NeedSqrt) {
y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
} else {
y[blockIdx.x] = static_cast<OutT>(sum);
}
}
}

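// Picks the widest vectorized load (8, 4, 2 or 1 elements, capped at 128 bits
// per load) whose alignment requirement is satisfied by both the pointer
// address and chunk_size, so that chunk boundaries stay vectorizable.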
template <typename T>
static int GetChunkedVecSize(const T *ptr, int chunk_size) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");

constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
auto address = reinterpret_cast<uintptr_t>(ptr);
constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
if (address % vec8 == 0 && chunk_size % vec8 == 0) {
return std::min(8, valid_vec_size);
} else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
return std::min(4, valid_vec_size);
} else if (address % vec2 == 0 && chunk_size % vec2 == 0) {
return std::min(2, valid_vec_size);
} else {
return 1;
}
}

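// Dispatches a runtime vec_size (1/2/4/8) to a compile-time kVecSize constant
// so the functor can be instantiated with the matching VecSize.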
#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \
case __vec_size: { \
constexpr int kVecSize = __vec_size; \
__VA_ARGS__; \
break; \
}

#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...) \
do { \
switch (__vec_size) { \
PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \
} \
} while (0)

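// Computes the (squared) L2 norm of each of the n tensors packed in x and
// described by the prefix-sum offsets array, in two stages: per-chunk partial
// sums via L2NormFunctor/MultiTensorApply, then a per-tensor reduction via
// MultiTensorL2NormReduceAgainCUDAKernel.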
// TODO(zengjinle): which chunk_size is better?
template <typename InT, typename OutT, bool NeedSqrt = false,
int MaxTensorNumPerLaunch = 50, int MaxChunkNumPerLaunch = 680,
int BlockDim = 512>
static void MultiTensorL2Norm(const platform::CUDAPlace &place,
gpuStream_t stream, const InT *x,
const int *offsets, int n, OutT *y,
int chunk_size = 65536) {
if (n <= 0) return;

constexpr int kNumTensor = MaxTensorNumPerLaunch;
constexpr int kNumChunk = MaxChunkNumPerLaunch;
constexpr int kBlockDim = BlockDim;

int max_chunk_num = -1;
int vec_size = 8;
int total_chunk_num = 0;
for (int i = 0; i < n; ++i) {
vec_size = std::min(
vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size));
int length = offsets[i + 1] - offsets[i];
auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
total_chunk_num += tmp_chunk_num;
}

VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
<< " , total_chunk_num = " << total_chunk_num
<< " , tensor_num = " << n;

using MT = MasterT<InT>;
memory::Buffer tmp_out(place);
auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);

#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL \
do { \
using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \
VLOG(10) << __func__ << " " << typeid(InT).name() \
<< " VecSize = " << kVecSize; \
MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>( \
FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr, \
max_chunk_num); \
} while (0)

PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL);
#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL

MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim,
NeedSqrt><<<n, kBlockDim, 0, stream>>>(
tmp_out_ptr, y, max_chunk_num);
}
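// Illustrative usage (hypothetical sizes, not part of this commit): with two
// tensors of 1024 and 3072 floats packed contiguously in x and
// offsets = {0, 1024, 4096}, calling
//   MultiTensorL2Norm(place, stream, x, offsets, /*n=*/2, y);
// writes the squared L2 norm of each tensor into y[0] and y[1]; instantiating
// it with NeedSqrt = true would produce the norms themselves.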

template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm(
const framework::ExecutionContext &ctx, const float *param_square_norm,
@@ -643,16 +801,6 @@ static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out,
d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
}

template <typename T>
struct AddConstantFunctor {
explicit AddConstantFunctor(T bias) : bias_(bias) {}

T operator()(T x) const { return x + bias_; }

private:
T bias_;
};

template <typename T>
struct OffsetWithBiasFunctor {
OffsetWithBiasFunctor(const T *offset, T bias)
@@ -670,12 +818,14 @@ struct OffsetWithBiasFunctor {
};

template <typename T, typename OffsetT>
static void DeviceSegmentedSquareNorm(
const platform::CUDADeviceContext &dev_ctx, const T *x, MasterT<T> *y,
int n, const OffsetT *offset, OffsetT init_offset, memory::Buffer *buffer) {
static void DeviceSegmentedSquareNorm(const platform::CUDAPlace &place,
gpuStream_t stream, const T *x,
MasterT<T> *y, int n,
const OffsetT *offset,
OffsetT init_offset,
memory::Buffer *buffer) {
if (!FLAGS_use_multi_tensor_apply) {
if (n <= 0) return;
auto stream = dev_ctx.stream();
cub::TransformInputIterator<MasterT<T>, SquareFunctor<T>, const T *> iter(
x, SquareFunctor<T>());
if (init_offset == static_cast<OffsetT>(0)) {
@@ -694,7 +844,7 @@ static void DeviceSegmentedSquareNorm(
return;
}

MultiTensorL2Norm(dev_ctx, x, offset, n, y);
MultiTensorL2Norm(place, stream, x, offset, n, y);
}

template <typename T>
@@ -869,16 +1019,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel,
}
}

template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
#endif
}

template <typename T>
class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
@@ -1217,29 +1357,29 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
}
}
DeviceSegmentedSquareNorm(dev_ctx, fp32_param, param_square_norm,
DeviceSegmentedSquareNorm(place, stream, fp32_param, param_square_norm,
fp32_global_param_num, fused_offsets, 0,
&cub_tmp_buffer);
if (use_master_param_norm) {
DeviceSegmentedSquareNorm(dev_ctx, master_param + fp16_offset,
DeviceSegmentedSquareNorm(place, stream, master_param + fp16_offset,
param_square_norm + fp16_local_start_idx,
fp16_local_param_num,
fp16_partial_fused_offsets, 0, &cub_tmp_buffer);
} else {
// NOTE: extra computation is performed. We can improve this performance
// if needed in the future.
DeviceSegmentedSquareNorm(
dev_ctx, fp16_param, param_square_norm + fp32_global_param_num,
place, stream, fp16_param, param_square_norm + fp32_global_param_num,
fp16_global_param_num, fused_offsets + fp32_global_param_num,
static_cast<int>(fp32_numel), &cub_tmp_buffer);
}

DeviceSegmentedSquareNorm(
dev_ctx, trust_ratio_div,
place, stream, trust_ratio_div,
trust_ratio_div_square_norm + fp32_local_start_idx,
fp32_local_param_num, fp32_partial_fused_offsets, 0, &cub_tmp_buffer);
DeviceSegmentedSquareNorm(
dev_ctx, trust_ratio_div + fp32_numel_each_device,
place, stream, trust_ratio_div + fp32_numel_each_device,
trust_ratio_div_square_norm + fp16_local_start_idx,
fp16_local_param_num, fp16_partial_fused_offsets, 0, &cub_tmp_buffer);

