diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
index b0e6d7eeaa933..7db761aa4bb8c 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
@@ -14,6 +14,7 @@
 
 #include
 #include "paddle/fluid/memory/buffer.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
 #include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
 #include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
@@ -43,6 +44,163 @@ namespace operators {
 template <typename T>
 using MasterT = typename details::MPTypeTrait<T>::Type;
 
+template <typename T>
+static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
+  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
+#else
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
+#endif
+}
+
+template <typename T, int BlockDim, int VecSize>
+struct L2NormFunctor {
+  DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
+                         const T *x, MasterT<T> *y, int max_chunk_num) const {
+    using MT = MasterT<T>;
+    const T *ptr = x + offset;
+
+    using BlockReduce = cub::BlockReduce<MT, BlockDim>;
+    __shared__ typename BlockReduce::TempStorage storage;
+
+    MT square_sum = static_cast<MT>(0);
+    int i;
+    for (i = threadIdx.x * VecSize; i + VecSize <= size;
+         i += (BlockDim * VecSize)) {
+      platform::AlignedVector<T, VecSize> tmp_vec;
+      platform::Load(ptr + i, &tmp_vec);
+#pragma unroll
+      for (int j = 0; j < VecSize; ++j) {
+        auto tmp = static_cast<MT>(tmp_vec[j]);
+        square_sum += (tmp * tmp);
+      }
+    }
+
+    for (; i < size; ++i) {
+      auto tmp = static_cast<MT>(ptr[i]);
+      square_sum += (tmp * tmp);
+    }
+
+    square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
+    if (threadIdx.x == 0) {
+      y[tensor_id * max_chunk_num + chunk_id] = square_sum;
+    }
+  }
+};
+
+template <typename InT, typename OutT, int BlockDim, bool NeedSqrt>
+static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
+    const InT *x, OutT *y, int max_chunk_num) {
+  int tensor_id = blockIdx.x;
+  x += (tensor_id * max_chunk_num);
+  using BlockReduce = cub::BlockReduce<InT, BlockDim>;
+  __shared__ typename BlockReduce::TempStorage storage;
+  InT sum = static_cast<InT>(0);
+  for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
+    sum += x[i];
+  }
+  sum = BlockReduce(storage).Reduce(sum, cub::Sum());
+  if (threadIdx.x == 0) {
+    if (NeedSqrt) {
+      y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
+    } else {
+      y[blockIdx.x] = static_cast<OutT>(sum);
+    }
+  }
+}
+
+template <typename T>
+static int GetChunkedVecSize(const T *ptr, int chunk_size) {
+  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
+
+  constexpr int max_load_bits = 128;
+  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
+  auto address = reinterpret_cast<uintptr_t>(ptr);
+  constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
+  constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
+  constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
+  if (address % vec8 == 0 && chunk_size % vec8 == 0) {
+    return std::min(8, valid_vec_size);
+  } else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
+    return std::min(4, valid_vec_size);
+  } else if (address % vec2 == 0 && chunk_size % vec2 == 0) {
+    return std::min(2, valid_vec_size);
+  } else {
+    return 1;
+  }
+}
+
+#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \
+  case __vec_size: {                                    \
+    constexpr int kVecSize = __vec_size;                \
+    __VA_ARGS__;                                        \
+    break;                                              \
+  }
+
+#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...)    \
+  do {                                                \
+    switch (__vec_size) {                             \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \
+      PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \
+    }                                                 \
+  } while (0)
+
+// TODO(zengjinle): which chunk_size is better?
+template <typename InT, typename OutT, int MaxTensorNumPerLaunch = 110,
+          int MaxChunkNumPerLaunch = 320, int BlockDim = 512>
+static void MultiTensorL2Norm(const platform::CUDAPlace &place,
+                              gpuStream_t stream, const InT *x,
+                              const int *offsets, int n, OutT *y,
+                              int chunk_size = 65536) {
+  if (n <= 0) return;
+
+  constexpr int kNumTensor = MaxTensorNumPerLaunch;
+  constexpr int kNumChunk = MaxChunkNumPerLaunch;
+  constexpr int kBlockDim = BlockDim;
+
+  int max_chunk_num = -1;
+  int vec_size = 8;
+  int total_chunk_num = 0;
+  for (int i = 0; i < n; ++i) {
+    vec_size = std::min(
+        vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size));
+    int length = offsets[i + 1] - offsets[i];
+    auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
+    max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
+    total_chunk_num += tmp_chunk_num;
+  }
+
+  VLOG(1) << "MultiTensorL2Norm max_chunk_num = " << max_chunk_num
+          << " , total_chunk_num = " << total_chunk_num
+          << " , tensor_num = " << n;
+
+  using MT = MasterT<InT>;
+  memory::Buffer tmp_out(place);
+  auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
+  FillZeroWithPtr(tmp_out_ptr, n * max_chunk_num, stream);
+
+#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL                          \
+  do {                                                               \
+    using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>;        \
+    VLOG(10) << __func__ << " " << typeid(InT).name()                \
+             << " VecSize = " << kVecSize;                           \
+    MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>(    \
+        FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr,  \
+        max_chunk_num);                                              \
+  } while (0)
+
+  PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL);
+#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL
+
+  MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim,
+                                         true><<<n, kBlockDim, 0, stream>>>(
+      tmp_out_ptr, y, max_chunk_num);
+}
+
 template <typename T>
 static void LogParamAndTrustRatioDivSquareNorm(
     const framework::ExecutionContext &ctx, const float *param_square_norm,
@@ -643,16 +801,6 @@ static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out,
       d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
 }
 
-template <typename T>
-struct AddConstantFunctor {
-  explicit AddConstantFunctor(T bias) : bias_(bias) {}
-
-  T operator()(T x) const { return x + bias_; }
-
- private:
-  T bias_;
-};
-
 template <typename T>
 struct OffsetWithBiasFunctor {
   OffsetWithBiasFunctor(const T *offset, T bias)
@@ -670,12 +818,14 @@ struct OffsetWithBiasFunctor {
 };
 
 template <typename T, typename OffsetT>
-static void DeviceSegmentedSquareNorm(
-    const platform::CUDADeviceContext &dev_ctx, const T *x, MasterT<T> *y,
-    int n, const OffsetT *offset, OffsetT init_offset, memory::Buffer *buffer) {
+static void DeviceSegmentedSquareNorm(const platform::CUDAPlace &place,
+                                      gpuStream_t stream, const T *x,
+                                      MasterT<T> *y, int n,
+                                      const OffsetT *offset,
+                                      OffsetT init_offset,
+                                      memory::Buffer *buffer) {
   if (!FLAGS_use_multi_tensor_apply) {
     if (n <= 0) return;
-    auto stream = dev_ctx.stream();
     cub::TransformInputIterator<MasterT<T>, SquareFunctor<T>, const T *> iter(
         x, SquareFunctor<T>());
     if (init_offset == static_cast<OffsetT>(0)) {
@@ -694,7 +844,7 @@ static void DeviceSegmentedSquareNorm(
     return;
   }
 
-  MultiTensorL2Norm(dev_ctx, x, offset, n, y);
+  MultiTensorL2Norm(place, stream, x, offset, n, y);
 }
 
 template <typename T>
@@ -869,16 +1019,6 @@ static void CheckHasNanInfGrad(const float *fp32_grad, int fp32_numel,
   }
 }
 
-template <typename T>
-static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
-  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
-#ifdef PADDLE_WITH_HIP
-  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(x, 0, n * sizeof(T), stream));
-#else
-  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(x, 0, n * sizeof(T), stream));
-#endif
-}
-
 template <typename T>
 class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
@@ -1217,11 +1357,11 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
       }
     }
-    DeviceSegmentedSquareNorm(dev_ctx, fp32_param, param_square_norm,
+    DeviceSegmentedSquareNorm(place, stream, fp32_param, param_square_norm,
                               fp32_global_param_num, fused_offsets, 0,
                               &cub_tmp_buffer);
     if (use_master_param_norm) {
-      DeviceSegmentedSquareNorm(dev_ctx, master_param + fp16_offset,
+      DeviceSegmentedSquareNorm(place, stream, master_param + fp16_offset,
                                 param_square_norm + fp16_local_start_idx,
                                 fp16_local_param_num, fp16_partial_fused_offsets,
                                 0, &cub_tmp_buffer);
@@ -1229,17 +1369,17 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
       // NOTE: extra computation is performed. We can improve this performance
       // if needed in the future.
       DeviceSegmentedSquareNorm(
-          dev_ctx, fp16_param, param_square_norm + fp32_global_param_num,
+          place, stream, fp16_param, param_square_norm + fp32_global_param_num,
           fp16_global_param_num, fused_offsets + fp32_global_param_num,
          static_cast<int>(fp32_numel), &cub_tmp_buffer);
     }
 
     DeviceSegmentedSquareNorm(
-        dev_ctx, trust_ratio_div,
+        place, stream, trust_ratio_div,
        trust_ratio_div_square_norm + fp32_local_start_idx,
        fp32_local_param_num, fp32_partial_fused_offsets, 0, &cub_tmp_buffer);
     DeviceSegmentedSquareNorm(
-        dev_ctx, trust_ratio_div + fp32_numel_each_device,
+        place, stream, trust_ratio_div + fp32_numel_each_device,
        trust_ratio_div_square_norm + fp16_local_start_idx,
        fp16_local_param_num, fp16_partial_fused_offsets, 0, &cub_tmp_buffer);
diff --git a/paddle/fluid/operators/optimizers/multi_tensor_apply.h b/paddle/fluid/operators/optimizers/multi_tensor_apply.h
index cbf1e1381211b..5d8d03c733dae 100644
--- a/paddle/fluid/operators/optimizers/multi_tensor_apply.h
+++ b/paddle/fluid/operators/optimizers/multi_tensor_apply.h
@@ -14,45 +14,73 @@
 #pragma once
 
-#ifdef __NVCC__
-#include "cub/cub.cuh"
+#include
 #include "math.h"  // NOLINT
-#endif
-
-#ifdef __HIPCC__
-#include <hipcub/hipcub.hpp>
-#include "math.h"  // NOLINT
-namespace cub = hipcub;
-#endif
-
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/platform/aligned_vector.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/string/string_helper.h"
 
 namespace paddle {
 namespace operators {
 
-template <int NumTensor, int NumChunk>
+template <int MaxTensorNumPerLaunch, int MaxChunkNumPerLaunch>
 struct TensorMetaList {
-  static constexpr int kTensorNum = NumTensor;
-  static constexpr int kChunkNum = NumChunk;
+  static constexpr int kTensorNum = MaxTensorNumPerLaunch;
+  static constexpr int kChunkNum = MaxChunkNumPerLaunch;
 
   static_assert(kTensorNum > 0 && kTensorNum < 256,
                 "kTensorNum must be inside (0, 256).");
   static_assert(kChunkNum > 0 && kChunkNum < 65536,
                 "kChunkNum must be inside (0, 65536).");
 
+  /**
+   * The numel offset of each tensor.
+   * offsets[0] is always 0 in the first launch, and offsets[0] >= 0
+   * in the following launches.
+   * The numel of the i-th tensor is offsets[i + 1] - offsets[i].
+   */
   int offsets[kTensorNum + 1];
+
+  /**
+   * The tensor id of each chunk. tensor_ids[0] is always 0.
+   * Note that tensor_ids is always in ascending order.
+   * The actual tensor id is start_tensor_id + tensor_ids[i].
+   *
+   * The reason we store the actual tensor id as
+   * start_tensor_id + tensor_ids[i] is so that tensor_ids can be
+   * a uint8_t array instead of an int array, making sizeof(TensorMetaList)
+   * smaller, so that kChunkNum can be larger.
+   */
   uint8_t tensor_ids[kChunkNum];
+
+  /**
+   * The chunk id of the chunk inside each tensor. It would be
+   * something like chunk_ids = [0, 1, 2, 0, 0, 1, 2, 3], meaning
+   * that there are 3 tensors and the tensors contain 3, 1 and 4
+   * chunks respectively. Note that chunk_ids[0] is always 0 and the
+   * actual chunk id of the first tensor is start_chunk_id + chunk_ids[i].
+   *
+   * The reason we store the actual chunk id of the first tensor as
+   * start_chunk_id + chunk_ids[i] is so that chunk_ids can be
+   * a uint16_t array instead of an int array, making
+   * sizeof(TensorMetaList) smaller, so that kChunkNum can be larger.
+   */
   uint16_t chunk_ids[kChunkNum];
+
+  /**
+   * The tensor_ids offset.
+   */
   int start_tensor_id;
+
+  /**
+   * The chunk_ids offset.
+   */
   int start_chunk_id;
 };
 
-template <typename Functor, int BlockDim, int NumTensor, int NumChunk,
-          typename... Args>
+template <typename Functor, int BlockDim, int MaxTensorNumPerLaunch,
+          int MaxChunkNumPerLaunch, typename... Args>
 static __global__ void MultiTensorApplyCUDAKernel(
-    Functor functor, TensorMetaList<NumTensor, NumChunk> meta, int chunk_size,
-    Args... args) {
+    Functor functor,
+    TensorMetaList<MaxTensorNumPerLaunch, MaxChunkNumPerLaunch> meta,
+    int chunk_size, Args... args) {
   const int block_id = blockIdx.x;
   const int tensor_id = meta.tensor_ids[block_id];
   const int chunk_id = static_cast<int>(meta.chunk_ids[block_id]) +
@@ -66,19 +94,15 @@ static __global__ void MultiTensorApplyCUDAKernel(
       args...);
 }
 
-template <typename T>
-static std::string ToString(const T *x, int n) {
-  std::vector<T> vec(x, x + n);
-  return "[" + string::join_strings(vec, ", ") + "]";
-}
-
-template <typename Functor, int BlockDim, int NumTensor, int NumChunk,
-          typename... Args>
+template <typename Functor, int BlockDim, int MaxTensorNumPerLaunch,
+          int MaxChunkNumPerLaunch, typename... Args>
 static void MultiTensorApply(Functor functor, gpuStream_t stream,
                              const int *offsets, int n, int chunk_size,
                             Args... args) {
   if (n == 0) return;
 
+  constexpr auto NumTensor = MaxTensorNumPerLaunch;
+  constexpr auto NumChunk = MaxChunkNumPerLaunch;
   TensorMetaList<NumTensor, NumChunk> metas;
 
   int tensor_id = 0;
@@ -128,154 +152,5 @@ static void MultiTensorApply(Functor functor, gpuStream_t stream,
   }
 }
 
-template <typename T, int BlockDim, int VecSize>
-struct L2NormFunctor {
-  using MT = typename details::MPTypeTrait<T>::Type;
-
-  DEVICE void operator()(int tensor_id, int chunk_id, int offset, int size,
-                         const T *x, MT *y, int max_chunk_num) const {
-    const T *ptr = x + offset;
-
-    using BlockReduce = cub::BlockReduce<MT, BlockDim>;
-    __shared__ typename BlockReduce::TempStorage storage;
-
-    MT square_sum = static_cast<MT>(0);
-    int i;
-    for (i = threadIdx.x * VecSize; i + VecSize <= size;
-         i += (BlockDim * VecSize)) {
-      platform::AlignedVector<T, VecSize> tmp_vec;
-      platform::Load(ptr + i, &tmp_vec);
-#pragma unroll
-      for (int j = 0; j < VecSize; ++j) {
-        auto tmp = static_cast<MT>(tmp_vec[j]);
-        square_sum += (tmp * tmp);
-      }
-    }
-
-    for (; i < size; ++i) {
-      auto tmp = static_cast<MT>(ptr[i]);
-      square_sum += (tmp * tmp);
-    }
-
-    square_sum = BlockReduce(storage).Reduce(square_sum, cub::Sum());
-    if (threadIdx.x == 0) {
-      y[tensor_id * max_chunk_num + chunk_id] = square_sum;
-    }
-  }
-};
-
-template <typename InT, typename OutT, int BlockDim, bool NeedSqrt>
-static __global__ void MultiTensorL2NormReduceAgainCUDAKernel(
-    const InT *x, OutT *y, int max_chunk_num) {
-  int tensor_id = blockIdx.x;
-  x += (tensor_id * max_chunk_num);
-  using BlockReduce = cub::BlockReduce<InT, BlockDim>;
-  __shared__ typename BlockReduce::TempStorage storage;
-  InT sum = static_cast<InT>(0);
-  for (int i = threadIdx.x; i < max_chunk_num; i += BlockDim) {
-    sum += x[i];
-  }
-  sum = BlockReduce(storage).Reduce(sum, cub::Sum());
-  if (threadIdx.x == 0) {
-    if (NeedSqrt) {
-      y[blockIdx.x] = static_cast<OutT>(sqrtf(sum));
-    } else {
-      y[blockIdx.x] = static_cast<OutT>(sum);
-    }
-  }
-}
-
-template <typename T>
-static int GetChunkedVecSize(const T *ptr, int chunk_size) {
-  static_assert(!std::is_same<T, void>::value, "T cannot be void.");
-
-  constexpr int max_load_bits = 128;
-  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
-  auto address = reinterpret_cast<uintptr_t>(ptr);
-  constexpr int vec8 = alignof(platform::AlignedVector<T, 8>);
-  constexpr int vec4 = alignof(platform::AlignedVector<T, 4>);
-  constexpr int vec2 = alignof(platform::AlignedVector<T, 2>);
-  if (address % vec8 == 0 && chunk_size % vec8 == 0) {
-    return std::min(8, valid_vec_size);
-  } else if (address % vec4 == 0 && chunk_size % vec4 == 0) {
-    return std::min(4, valid_vec_size);
-  } else if (address % vec2 == 0 && chunk_size % vec2 == 0) {
-    return std::min(2, valid_vec_size);
-  } else {
-    return 1;
-  }
-}
-
-#define PD_VEC_MULTI_TENSOR_APPLY_CASE(__vec_size, ...) \
-  case __vec_size: {                                    \
-    constexpr int kVecSize = __vec_size;                \
-    __VA_ARGS__;                                        \
-    break;                                              \
-  }
-
-#define PD_VEC_MULTI_TENSOR_APPLY(__vec_size, ...)    \
-  do {                                                \
-    switch (__vec_size) {                             \
-      PD_VEC_MULTI_TENSOR_APPLY_CASE(8, __VA_ARGS__); \
-      PD_VEC_MULTI_TENSOR_APPLY_CASE(4, __VA_ARGS__); \
-      PD_VEC_MULTI_TENSOR_APPLY_CASE(2, __VA_ARGS__); \
-      PD_VEC_MULTI_TENSOR_APPLY_CASE(1, __VA_ARGS__); \
-    }                                                 \
-  } while (0)
-
-template <typename InT, typename OutT>
-static void MultiTensorL2Norm(const platform::CUDADeviceContext &dev_ctx,
-                              const InT *x, const int *offsets, int n,
-                              OutT *y) {
-  if (n == 0) return;
-
-  constexpr int kNumTensor = 110;
-  constexpr int kNumChunk = 320;
-  constexpr int kBlockDim = 512;
-
-  // TODO(zengjinle): which chunk_size is better?
-  constexpr int chunk_size = 2048 * 32;
-
-  int max_chunk_num = -1;
-  int vec_size = 8;
-
-  for (int i = 0; i < n; ++i) {
-    vec_size = std::min(
-        vec_size, GetChunkedVecSize(x + offsets[i] - offsets[0], chunk_size));
-    int length = offsets[i + 1] - offsets[i];
-    auto tmp_chunk_num = (length + chunk_size - 1) / chunk_size;
-    max_chunk_num = std::max(max_chunk_num, tmp_chunk_num);
-  }
-
-  using MT = typename details::MPTypeTrait<InT>::Type;
-  auto place = dev_ctx.GetPlace();
-  auto stream = dev_ctx.stream();
-  memory::Buffer tmp_out(place);
-  auto *tmp_out_ptr = tmp_out.Alloc<MT>(n * max_chunk_num);
-  auto nbytes = n * max_chunk_num * sizeof(MT);
-#ifdef PADDLE_WITH_CUDA
-  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(tmp_out_ptr, 0, nbytes, stream));
-#else
-  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(tmp_out_ptr, 0, nbytes, stream));
-#endif
-
-#define PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL                          \
-  do {                                                               \
-    using FunctorT = L2NormFunctor<InT, kBlockDim, kVecSize>;        \
-    VLOG(10) << __func__ << " " << typeid(InT).name()                \
-             << " VecSize = " << kVecSize;                           \
-    MultiTensorApply<FunctorT, kBlockDim, kNumTensor, kNumChunk>(    \
-        FunctorT(), stream, offsets, n, chunk_size, x, tmp_out_ptr,  \
-        max_chunk_num);                                              \
-  } while (0)
-
-  PD_VEC_MULTI_TENSOR_APPLY(vec_size, PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL);
-#undef PD_LAUNCH_MULTI_TENSOR_APPLY_KERNEL
-
-  MultiTensorL2NormReduceAgainCUDAKernel<MT, OutT, kBlockDim,
-                                         true><<<n, kBlockDim, 0, stream>>>(
-      tmp_out_ptr, y, max_chunk_num);
-}
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index 319c274bd51b0..fc9b4a1134a47 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -762,8 +762,6 @@ DEFINE_bool(enable_slotrecord_reset_shrink, false,
 DEFINE_bool(enable_ins_parser_file, false,
             "enable parser ins file , default false");
 
-PADDLE_DEFINE_EXPORTED_bool(use_multi_tensor_apply, false, "");
-
 /**
  * ProcessGroupNCCL related FLAG
  * Name: nccl_blocking_wait
@@ -775,3 +773,5 @@ PADDLE_DEFINE_EXPORTED_bool(use_multi_tensor_apply, false, "");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
 #endif
+
+PADDLE_DEFINE_EXPORTED_bool(use_multi_tensor_apply, false, "");
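
Note for reviewers unfamiliar with the multi-tensor-apply pattern: MultiTensorL2Norm above runs two passes. Pass 1 (L2NormFunctor, driven by MultiTensorApply) writes one partial sum of squares per (tensor, chunk) slot into a zero-initialized n * max_chunk_num buffer; pass 2 (MultiTensorL2NormReduceAgainCUDAKernel) reduces each tensor's row of partials and applies sqrt. The standalone sketch below mirrors that two-pass scheme without the cub and Paddle dependencies. All identifiers in it (PartialSquareSumKernel, ReducePartialsKernel, kBlock, offsets_dev, partials, norms) are hypothetical and are not part of this patch or of Paddle; it is an illustration of the scheme, not the implementation.

// Illustrative sketch only: plain two-pass chunked L2 norm over a fused
// float buffer `x` holding num_tensors tensors, where the numel range of
// tensor i is [offsets_dev[i], offsets_dev[i + 1]).
#include <cuda_runtime.h>

constexpr int kBlock = 256;

// Pass 1: one block per (chunk, tensor); each block writes one partial
// sum of squares into partials[tensor_id * max_chunk_num + chunk_id].
__global__ void PartialSquareSumKernel(const float *x, const int *offsets_dev,
                                       int chunk_size, int max_chunk_num,
                                       float *partials) {
  int chunk_id = blockIdx.x;
  int tensor_id = blockIdx.y;
  int begin = offsets_dev[tensor_id] + chunk_id * chunk_size;
  int end = min(offsets_dev[tensor_id] + (chunk_id + 1) * chunk_size,
                offsets_dev[tensor_id + 1]);

  __shared__ float smem[kBlock];
  float sum = 0.f;
  for (int i = begin + threadIdx.x; i < end; i += kBlock) {
    sum += x[i] * x[i];
  }
  smem[threadIdx.x] = sum;
  __syncthreads();
  // Standard shared-memory tree reduction (kBlock is a power of two).
  for (int stride = kBlock / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) smem[threadIdx.x] += smem[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    partials[tensor_id * max_chunk_num + chunk_id] = smem[0];
  }
}

// Pass 2: one block per tensor; reduce its row of chunk partials and take
// the square root, as MultiTensorL2NormReduceAgainCUDAKernel does above.
__global__ void ReducePartialsKernel(const float *partials, int max_chunk_num,
                                     float *norms) {
  const float *row = partials + blockIdx.x * max_chunk_num;
  __shared__ float smem[kBlock];
  float sum = 0.f;
  for (int i = threadIdx.x; i < max_chunk_num; i += kBlock) sum += row[i];
  smem[threadIdx.x] = sum;
  __syncthreads();
  for (int stride = kBlock / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) smem[threadIdx.x] += smem[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0) norms[blockIdx.x] = sqrtf(smem[0]);
}

// Host side (illustrative). Unlike the patch, this sketch launches a block
// for every (tensor, chunk) slot up to max_chunk_num, so every slot of
// `partials` is written and no explicit zero-fill is needed; the patch only
// launches real chunks, which is why it zero-fills with FillZeroWithPtr.
//   dim3 grid(max_chunk_num, num_tensors);
//   PartialSquareSumKernel<<<grid, kBlock, 0, stream>>>(
//       x, offsets_dev, chunk_size, max_chunk_num, partials);
//   ReducePartialsKernel<<<num_tensors, kBlock, 0, stream>>>(
//       partials, max_chunk_num, norms);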