Skip to content

Commit

Permalink
Add MultiTensorApply to calculate L2-Norm in DistributedFusedLamb opt…
Browse files Browse the repository at this point in the history
…imizer (#39900)

* add multi tensor apply l2 norm

* add multi_tensor_apply code

* make sizeof(TensorMeta) smalller

* move code to distributed_fused_lamb_op.cu

* remove useless FLAGS
  • Loading branch information
sneaxiy authored Feb 25, 2022
1 parent 639675d commit d32a010
Show file tree
Hide file tree
Showing 5 changed files with 355 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ static void CopyVectorToTensor(const std::vector<T> &src,
memory::Copy(place, dst_ptr, platform::CPUPlace(), src_ptr, nbytes, stream);
}

template <typename T>
static void CopyVectorToCPUTensor(const std::vector<T> &src,
framework::Tensor *dst) {
dst->Resize({static_cast<int64_t>(src.size())});
T *dst_ptr = dst->mutable_data<T>(platform::CPUPlace());
const T *src_ptr = src.data();
auto nbytes = src.size() * sizeof(T);
std::memcpy(dst_ptr, src_ptr, nbytes);
}

template <typename T>
class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
Expand Down Expand Up @@ -677,14 +687,14 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
lengths.back());
}

CopyVectorToTensor(
CopyVectorToCPUTensor(numel_offsets,
ctx.Output<framework::Tensor>("FusedParamOffsets"));
CopyVectorToCPUTensor(
fp32_partial_numel_offsets,
ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"), place,
stream);
CopyVectorToTensor(
ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"));
CopyVectorToCPUTensor(
fp16_partial_numel_offsets,
ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"), place,
stream);
ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));

// Fill the weight decay tensor
PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@ class DistributedFusedLambOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ParamInfo") {
return expected_kernel_type;
} else {
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
return expected_kernel_type;
}
};

Expand Down
282 changes: 179 additions & 103 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

#include <cmath>
#include "paddle/fluid/memory/buffer.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/collective_helper.h"
Expand All @@ -40,6 +42,163 @@ namespace operators {
template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type;

template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
#endif
}

template <typename T, int BlockDim, int VecSize>
struct L2NormFunctor {
DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
const T *x, MasterT<T> *y, int max_chunk_num) const {
using MT = MasterT<T>;
const T *ptr = x + offset;

using BlockReduce = cub::BlockReduce<MT, BlockDim>;
__shared__ typename BlockReduce::TempStorage storage;

MT square_sum = static_cast<MT>(0);
int i;
for (i = threadIdx.x * VecSize; i + VecSize <= size;
i += (BlockDim * VecSize)) {
platform::AlignedVector<T, VecSize> tmp_vec;
platform::Load(ptr + i, &tmp_vec);
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
auto tmp = static_cast<MT>(tmp_vec[j]);
square_sum += (tmp * tmp);
}
}

for (; i < size; ++i) {
auto tmp = static_cast<MT>(ptr[i]);
square_sum += (tmp * tmp);
}

square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
if (threadIdx.x == 0) {
y[tensor_id * max_chunk_num + chunk_id] = square_sum;
}
}
};

template <typename InT, typename OutT, int BlockDim, bool NeedSqrt>
static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
const InT *x, OutT *y, int max_chunk_num) {
int tensor_id = blockIdx.x;
x += (tensor_id * max_chunk_num);
using BlockReduce = cub::BlockReduce<InT, BlockDim>;
__shared__ typename BlockReduce::TempStorage storage;
InT sum = static_cast<InT>(0);
for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
sum += x[i];
}
sum = BlockReduce(storage).Reduce(sum, cub::Sum());
if (threadIdx.x == 0) {
if (NeedSqrt) {
y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
} else {
y[blockIdx.x] = static_cast<OutT>(sum);
}
}
}

template <typename T>
static int GetChunkedVecSize(const T *ptr, int chunk_size) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");

constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
auto address = reinterpret_cast<uintptr_t>(ptr);
constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
if (address % vec8 == 0 && chunk_size % vec8 == 0) {
return std::min(8, valid_vec_size);
} else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
return std::min(4, valid_vec_size);
} else if (address % vec2 == 0 && chunk_size % vec2 == 0) {
return std::min(2, valid_vec_size);
} else {
return 1;
}
}

#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \
case __vec_size: { \
constexpr int kVecSize = __vec_size; \
__VA_ARGS__; \
break; \
}

#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...) \
do { \
switch (__vec_size) { \
PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \
PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \
} \
} while (0)

// TODO(zengjinle): which chunk_size is better?
template <typename InT, typename OutT, bool NeedSqrt = false,
int MaxTensorNumPerLaunch = 50, int MaxChunkNumPerLaunch = 680,
int BlockDim = 512>
static void MultiTensorL2Norm(const platform::CUDAPlace &place,
gpuStream_t stream, const InT *x,
const int *offsets, int n, OutT *y,
int chunk_size = 65536) {
if (n <= 0) return;

constexpr int kNumTensor = MaxTensorNumPerLaunch;
constexpr int kNumChunk = MaxChunkNumPerLaunch;
constexpr int kBlockDim = BlockDim;

int max_chunk_num = -1;
int vec_size = 8;
int total_chunk_num = 0;
for (int i = 0; i < n; ++i) {
vec_size = std::min(
vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size));
int length = offsets[i + 1] - offsets[i];
auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
total_chunk_num += tmp_chunk_num;
}

VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
<< " , total_chunk_num = " << total_chunk_num
<< " , tensor_num = " << n;

using MT = MasterT<InT>;
memory::Buffer tmp_out(place);
auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);

#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL \
do { \
using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>; \
VLOG(10) << __func__ << " " << typeid(InT).name() \
<< " VecSize = " << kVecSize; \
MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>( \
FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr, \
max_chunk_num); \
} while (0)

PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL);
#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL

MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim,
NeedSqrt><<<n, kBlockDim, 0, stream>>>(
tmp_out_ptr, y, max_chunk_num);
}

template <int LogLevel>
static void LogParamAndTrustRatioDivSquareNorm(
const framework::ExecutionContext &ctx, const float *param_square_norm,
Expand Down Expand Up @@ -620,76 +779,6 @@ static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out,
num_items, reduction_op, init, stream));
}

template <typename InputIteratorT, typename OutputIteratorT,
typename OffsetIteratorT, typename ReductionOp, typename T>
static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out,
int num_segments,
OffsetIteratorT d_begin_offsets,
OffsetIteratorT d_end_offsets,
ReductionOp reduction_op, T initial_value,
gpuStream_t stream,
memory::Buffer *buffer) {
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
}

template <typename T>
struct AddConstantFunctor {
explicit AddConstantFunctor(T bias) : bias_(bias) {}

T operator()(T x) const { return x + bias_; }

private:
T bias_;
};

template <typename T>
struct OffsetWithBiasFunctor {
OffsetWithBiasFunctor(const T *offset, T bias)
: offset_(offset), bias_(bias) {}

HOSTDEVICE T operator()(T idx) const { return offset_[idx] - bias_; }

HOSTDEVICE constexpr bool operator==(const OffsetWithBiasFunctor<T> &) const {
return true;
}

private:
const T *offset_;
const T bias_;
};

template <typename T, typename OffsetT>
static void CubDeviceSegmentedSquareNorm(const T *x, MasterT<T> *y, int n,
const OffsetT *offset,
OffsetT init_offset,
gpuStream_t stream,
memory::Buffer *buffer) {
if (n <= 0) return;
cub::TransformInputIterator<MasterT<T>, SquareFunctor<T>, const T *> iter(
x, SquareFunctor<T>());
if (init_offset == static_cast<OffsetT>(0)) {
CubDeviceSegmentedReduce(iter, y, n, offset, offset + 1, cub::Sum(),
static_cast<MasterT<T>>(0), stream, buffer);
} else {
cub::CountingInputIterator<OffsetT> cnt_iter(0);
OffsetWithBiasFunctor<OffsetT> functor(offset, init_offset);
cub::TransformInputIterator<OffsetT, OffsetWithBiasFunctor<OffsetT>,
cub::CountingInputIterator<OffsetT>>
offset_iter(cnt_iter, functor);
CubDeviceSegmentedReduce(iter, y, n, offset_iter, offset_iter + 1,
cub::Sum(), static_cast<MasterT<T>>(0), stream,
buffer);
}
}

template <typename T>
static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm,
gpuStream_t stream,
Expand Down Expand Up @@ -862,16 +951,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel,
}
}

template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
static_assert(!std::is_same<T, void>::value, "T cannot be void.");
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
#endif
}

template <typename T>
class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
Expand Down Expand Up @@ -1191,13 +1270,16 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
fp16_partial_fused_offsets_t->data<int>();

VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(fused_offsets, fused_offsets_t->numel(), place);
<< FlattenToString(fused_offsets, fused_offsets_t->numel(),
fused_offsets_t->place());
VLOG(1) << "FP32ShardFusedParamOffsets: "
<< FlattenToString(fp32_partial_fused_offsets,
fp32_partial_fused_offsets_t->numel(), place);
fp32_partial_fused_offsets_t->numel(),
fp32_partial_fused_offsets_t->place());
VLOG(1) << "FP16ShardFusedParamOffsets: "
<< FlattenToString(fp16_partial_fused_offsets,
fp16_partial_fused_offsets_t->numel(), place);
fp16_partial_fused_offsets_t->numel(),
fp16_partial_fused_offsets_t->place());

if (num_devices > 1) {
if (use_master_param_norm) {
Expand All @@ -1207,32 +1289,26 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
}
}
CubDeviceSegmentedSquareNorm(fp32_param, param_square_norm,
fp32_global_param_num, fused_offsets, 0,
stream, &cub_tmp_buffer);
MultiTensorL2Norm(place, stream, fp32_param, fused_offsets,
fp32_global_param_num, param_square_norm);
if (use_master_param_norm) {
CubDeviceSegmentedSquareNorm(
master_param + fp16_offset, param_square_norm + fp16_local_start_idx,
fp16_local_param_num, fp16_partial_fused_offsets, 0, stream,
&cub_tmp_buffer);
MultiTensorL2Norm(place, stream, master_param + fp16_offset,
fp16_partial_fused_offsets, fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
} else {
// NOTE: extra computation is performed. We can improve this performance
// if needed in the future.
CubDeviceSegmentedSquareNorm(
fp16_param, param_square_norm + fp32_global_param_num,
fp16_global_param_num, fused_offsets + fp32_global_param_num,
static_cast<int>(fp32_numel), stream, &cub_tmp_buffer);
MultiTensorL2Norm(
place, stream, fp16_param, fused_offsets + fp32_global_param_num,
fp16_global_param_num, param_square_norm + fp32_global_param_num);
}

CubDeviceSegmentedSquareNorm(
trust_ratio_div, trust_ratio_div_square_norm + fp32_local_start_idx,
fp32_local_param_num, fp32_partial_fused_offsets, 0, stream,
&cub_tmp_buffer);
CubDeviceSegmentedSquareNorm(
trust_ratio_div + fp32_numel_each_device,
trust_ratio_div_square_norm + fp16_local_start_idx,
fp16_local_param_num, fp16_partial_fused_offsets, 0, stream,
&cub_tmp_buffer);
MultiTensorL2Norm(place, stream, trust_ratio_div,
fp32_partial_fused_offsets, fp32_local_param_num,
trust_ratio_div_square_norm + fp32_local_start_idx);
MultiTensorL2Norm(place, stream, trust_ratio_div + fp32_numel_each_device,
fp16_partial_fused_offsets, fp16_local_param_num,
trust_ratio_div_square_norm + fp16_local_start_idx);

VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
<< FlattenToString(trust_ratio_div_square_norm, param_num, place);
Expand Down
Loading

0 comments on commit d32a010

Please sign in to comment.