remove useless FLAGS
sneaxiy committed Feb 24, 2022
1 parent 0e32d82 commit b6ea3dc
Showing 3 changed files with 22 additions and 110 deletions.
@@ -21,8 +21,6 @@
 #include "paddle/phi/kernels/funcs/algorithm.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

-DECLARE_bool(use_multi_tensor_apply);
-
 namespace paddle {
 namespace operators {

@@ -689,25 +687,14 @@ class DistributedFusedLambInitOpKernel<platform::CUDADeviceContext, T>
                          lengths.back());
     }

-    if (!FLAGS_use_multi_tensor_apply) {
-      CopyVectorToTensor(
-          fp32_partial_numel_offsets,
-          ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"), place,
-          stream);
-      CopyVectorToTensor(
-          fp16_partial_numel_offsets,
-          ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"), place,
-          stream);
-    } else {
-      CopyVectorToCPUTensor(numel_offsets,
-                            ctx.Output<framework::Tensor>("FusedParamOffsets"));
-      CopyVectorToCPUTensor(
-          fp32_partial_numel_offsets,
-          ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"));
-      CopyVectorToCPUTensor(
-          fp16_partial_numel_offsets,
-          ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));
-    }
+    CopyVectorToCPUTensor(numel_offsets,
+                          ctx.Output<framework::Tensor>("FusedParamOffsets"));
+    CopyVectorToCPUTensor(
+        fp32_partial_numel_offsets,
+        ctx.Output<framework::Tensor>("FP32ShardFusedParamOffsets"));
+    CopyVectorToCPUTensor(
+        fp16_partial_numel_offsets,
+        ctx.Output<framework::Tensor>("FP16ShardFusedParamOffsets"));

     // Fill the weight decay tensor
     PADDLE_ENFORCE_EQ(lengths.size(), shard_weight_decay.size(),
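With FLAGS_use_multi_tensor_apply gone, the init kernel now always copies the fused/shard offset vectors into CPU tensors via CopyVectorToCPUTensor, instead of optionally writing them to device tensors with CopyVectorToTensor. For orientation only, here is a minimal sketch of what a vector-to-CPU-tensor helper of this kind presumably looks like; the helper name, the header paths, and the Resize/mutable_data/memcpy pattern are assumptions for illustration, not code taken from this commit:

#include <cstring>
#include <vector>

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"

// Hypothetical sketch: materialize a host std::vector as a 1-D CPU tensor so
// downstream code can read the offsets without a device round trip.
template <typename T>
static void CopyVectorToCPUTensorSketch(const std::vector<T> &src,
                                        paddle::framework::Tensor *dst) {
  dst->Resize(
      paddle::framework::make_ddim({static_cast<int64_t>(src.size())}));
  T *dst_ptr = dst->mutable_data<T>(paddle::platform::CPUPlace());
  std::memcpy(dst_ptr, src.data(), src.size() * sizeof(T));  // host-to-host copy
}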
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu (101 changes: 14 additions & 87 deletions)
@@ -36,8 +36,6 @@
 namespace cub = hipcub;
 #endif

-DECLARE_bool(use_multi_tensor_apply);
-
 namespace paddle {
 namespace operators {

@@ -781,72 +779,6 @@ static void CubDeviceReduce(InputIteratorT d_in, OutputIteratorT d_out,
                             num_items, reduction_op, init, stream));
 }

-template <typename InputIteratorT, typename OutputIteratorT,
-          typename OffsetIteratorT, typename ReductionOp, typename T>
-static void CubDeviceSegmentedReduce(InputIteratorT d_in, OutputIteratorT d_out,
-                                     int num_segments,
-                                     OffsetIteratorT d_begin_offsets,
-                                     OffsetIteratorT d_end_offsets,
-                                     ReductionOp reduction_op, T initial_value,
-                                     gpuStream_t stream,
-                                     memory::Buffer *buffer) {
-  void *d_temp_storage = nullptr;
-  size_t temp_storage_bytes = 0;
-  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
-      d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
-      d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
-  d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
-  PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedReduce::Reduce(
-      d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments,
-      d_begin_offsets, d_end_offsets, reduction_op, initial_value, stream));
-}
-
-template <typename T>
-struct OffsetWithBiasFunctor {
-  OffsetWithBiasFunctor(const T *offset, T bias)
-      : offset_(offset), bias_(bias) {}
-
-  HOSTDEVICE T operator()(T idx) const { return offset_[idx] - bias_; }
-
-  HOSTDEVICE constexpr bool operator==(const OffsetWithBiasFunctor<T> &) const {
-    return true;
-  }
-
- private:
-  const T *offset_;
-  const T bias_;
-};
-
-template <typename T, typename OffsetT>
-static void DeviceSegmentedSquareNorm(const platform::CUDAPlace &place,
-                                      gpuStream_t stream, const T *x,
-                                      MasterT<T> *y, int n,
-                                      const OffsetT *offset,
-                                      OffsetT init_offset,
-                                      memory::Buffer *buffer) {
-  if (!FLAGS_use_multi_tensor_apply) {
-    if (n <= 0) return;
-    cub::TransformInputIterator<MasterT<T>, SquareFunctor<T>, const T *> iter(
-        x, SquareFunctor<T>());
-    if (init_offset == static_cast<OffsetT>(0)) {
-      CubDeviceSegmentedReduce(iter, y, n, offset, offset + 1, cub::Sum(),
-                               static_cast<MasterT<T>>(0), stream, buffer);
-    } else {
-      cub::CountingInputIterator<OffsetT> cnt_iter(0);
-      OffsetWithBiasFunctor<OffsetT> functor(offset, init_offset);
-      cub::TransformInputIterator<OffsetT, OffsetWithBiasFunctor<OffsetT>,
-                                  cub::CountingInputIterator<OffsetT>>
-          offset_iter(cnt_iter, functor);
-      CubDeviceSegmentedReduce(iter, y, n, offset_iter, offset_iter + 1,
-                               cub::Sum(), static_cast<MasterT<T>>(0), stream,
-                               buffer);
-    }
-    return;
-  }
-
-  MultiTensorL2Norm(place, stream, x, offset, n, y);
-}
-
 template <typename T>
 static void GetSquareGradNormImpl(const T *grad, int n, float *square_norm,
                                   gpuStream_t stream,
@@ -1357,31 +1289,26 @@ class DistributedFusedLambOpKernel<platform::CUDADeviceContext, T>
         FillZeroWithPtr(trust_ratio_div_square_norm, param_num, stream);
       }
     }
-    DeviceSegmentedSquareNorm(place, stream, fp32_param, param_square_norm,
-                              fp32_global_param_num, fused_offsets, 0,
-                              &cub_tmp_buffer);
+    MultiTensorL2Norm(place, stream, fp32_param, fused_offsets,
+                      fp32_global_param_num, param_square_norm);
     if (use_master_param_norm) {
-      DeviceSegmentedSquareNorm(place, stream, master_param + fp16_offset,
-                                param_square_norm + fp16_local_start_idx,
-                                fp16_local_param_num,
-                                fp16_partial_fused_offsets, 0, &cub_tmp_buffer);
+      MultiTensorL2Norm(place, stream, master_param + fp16_offset,
+                        fp16_partial_fused_offsets, fp16_local_param_num,
+                        param_square_norm + fp16_local_start_idx);
     } else {
       // NOTE: extra computation is performed. We can improve this performance
       // if needed in the future.
-      DeviceSegmentedSquareNorm(
-          place, stream, fp16_param, param_square_norm + fp32_global_param_num,
-          fp16_global_param_num, fused_offsets + fp32_global_param_num,
-          static_cast<int>(fp32_numel), &cub_tmp_buffer);
+      MultiTensorL2Norm(
+          place, stream, fp16_param, fused_offsets + fp32_global_param_num,
+          fp16_global_param_num, param_square_norm + fp32_global_param_num);
     }

-    DeviceSegmentedSquareNorm(
-        place, stream, trust_ratio_div,
-        trust_ratio_div_square_norm + fp32_local_start_idx,
-        fp32_local_param_num, fp32_partial_fused_offsets, 0, &cub_tmp_buffer);
-    DeviceSegmentedSquareNorm(
-        place, stream, trust_ratio_div + fp32_numel_each_device,
-        trust_ratio_div_square_norm + fp16_local_start_idx,
-        fp16_local_param_num, fp16_partial_fused_offsets, 0, &cub_tmp_buffer);
+    MultiTensorL2Norm(place, stream, trust_ratio_div,
+                      fp32_partial_fused_offsets, fp32_local_param_num,
+                      trust_ratio_div_square_norm + fp32_local_start_idx);
+    MultiTensorL2Norm(place, stream, trust_ratio_div + fp32_numel_each_device,
+                      fp16_partial_fused_offsets, fp16_local_param_num,
+                      trust_ratio_div_square_norm + fp16_local_start_idx);

     VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
             << FlattenToString(trust_ratio_div_square_norm, param_num, place);
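Both the removed DeviceSegmentedSquareNorm path and the MultiTensorL2Norm calls that replace it compute, for each of n segments delimited by offset, the squared L2 norm of that slice of x, accumulated in the master (wider) precision. A reference-only sketch of that contract follows; the function name and the plain int/float types are illustrative assumptions, and the real kernels of course run on the GPU:

// Reference semantics shared by the old CUB path and MultiTensorL2Norm:
//   y[i] = sum over j in [offset[i], offset[i+1]) of x[j] * x[j],  for 0 <= i < n
template <typename T>
void SegmentedSquareNormReference(const T *x, const int *offset, int n,
                                  float *y) {
  for (int i = 0; i < n; ++i) {
    float acc = 0.0f;  // accumulate in float, mirroring MasterT<T> for fp16 inputs
    for (int j = offset[i]; j < offset[i + 1]; ++j) {
      const float v = static_cast<float>(x[j]);
      acc += v * v;
    }
    y[i] = acc;  // squared L2 norm of segment i
  }
}

The rewritten call sites are then a mechanical permutation of arguments: the removed helper was invoked as DeviceSegmentedSquareNorm(place, stream, x, y, n, offset, init_offset, &cub_tmp_buffer), while MultiTensorL2Norm takes (place, stream, x, offset, n, y), with the init_offset bias and the temporary CUB buffer no longer needed.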
paddle/fluid/platform/flags.cc (2 changes: 0 additions & 2 deletions)
@@ -773,5 +773,3 @@ DEFINE_bool(enable_ins_parser_file, false,
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
 #endif
-
-PADDLE_DEFINE_EXPORTED_bool(use_multi_tensor_apply, false, "");
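For background on the pairing being deleted: PADDLE_DEFINE_EXPORTED_bool in flags.cc defines the gflags-style boolean in one translation unit, and each .cu file that reads it pulls in the FLAGS_use_multi_tensor_apply symbol with DECLARE_bool. Removing the definition above is what makes the two DECLARE_bool lines and every FLAGS_use_multi_tensor_apply branch in the other files dead code. Below is a schematic of the removed pattern, assembled from the pieces in this diff; the layout and the Dispatch function are illustrative assumptions, and the actual calls are left as comments:

// In paddle/fluid/platform/flags.cc: the single, exported definition
// (the empty string is the flag's help text).
PADDLE_DEFINE_EXPORTED_bool(use_multi_tensor_apply, false, "");

// In each .cu that reads the flag: the declaration of the same symbol.
DECLARE_bool(use_multi_tensor_apply);

// Typical dispatch in the LAMB kernels before this commit: choose the fused
// multi-tensor path or the CUB segmented-reduce fallback at runtime.
static void DispatchSquareNormSketch() {
  if (FLAGS_use_multi_tensor_apply) {
    // MultiTensorL2Norm(place, stream, x, offset, n, y);  // fused path
  } else {
    // DeviceSegmentedSquareNorm(...);                     // CUB fallback
  }
}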

1 comment on commit b6ea3dc

@paddle-bot-old
Congratulations! Your pull request passed all required CI checks. You can ask the reviewer(s) to approve and merge. 🎉
