Commit
vec scale kernel
sneaxiy committed Mar 1, 2022
1 parent 75280d3 commit 1907197
Showing 1 changed file with 39 additions and 10 deletions.

49 changes: 39 additions & 10 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
@@ -304,14 +304,30 @@ struct AndFunctor {
   HOSTDEVICE bool operator()(bool x, bool y) const { return x && y; }
 };

-template <typename T1, typename T2>
+template <typename T1, typename T2, int VecSize>
 static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x,
                                        const T2 *__restrict__ scale,
                                        T1 *__restrict__ y, int num) {
   static_assert(sizeof(T1) <= sizeof(T2),
                 "sizeof(T1) must be not greater than sizeof(T2).");
   T2 s = scale[0];
-  CUDA_KERNEL_LOOP(i, num) {
+
+  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  int stride = blockDim.x * gridDim.x * VecSize;
+
+  for (; i + VecSize <= num; i += stride) {
+    platform::AlignedVector<T1, VecSize> x_vec;
+    platform::AlignedVector<T1, VecSize> y_vec;
+
+    platform::Load(x + i, &x_vec);
+#pragma unroll
+    for (int j = 0; j < VecSize; ++j) {
+      y_vec[j] = static_cast<T1>(static_cast<T2>(x_vec[j]) * s);
+    }
+    platform::Store(y_vec, y + i);
+  }
+
+  for (; i < num; ++i) {
     y[i] = static_cast<T1>(static_cast<T2>(x[i]) * s);
   }
 }
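
The new kernel body pairs a vectorized main loop, in which each thread loads, scales, and stores VecSize contiguous elements per step through platform::AlignedVector, with a scalar tail loop that covers the remaining num % VecSize elements. Below is a minimal standalone sketch of the same pattern; the Vec struct, kernel name, and launch parameters are illustrative stand-ins, not Paddle APIs.

// scale_sketch.cu -- illustrative only; build with: nvcc scale_sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for platform::AlignedVector<T, VecSize>: an alignment-qualified
// array the compiler can move with a single vectorized load/store.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) Vec {
  T val[VecSize];
};

template <typename T, int VecSize>
__global__ void ScaleKernelSketch(const T *__restrict__ x, T scale,
                                  T *__restrict__ y, int num) {
  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  int stride = blockDim.x * gridDim.x * VecSize;

  // Vectorized main loop: VecSize contiguous elements per thread per step.
  for (; i + VecSize <= num; i += stride) {
    Vec<T, VecSize> x_vec = *reinterpret_cast<const Vec<T, VecSize> *>(x + i);
    Vec<T, VecSize> y_vec;
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      y_vec.val[j] = x_vec.val[j] * scale;
    }
    *reinterpret_cast<Vec<T, VecSize> *>(y + i) = y_vec;
  }

  // Scalar tail for the last num % VecSize elements.
  for (; i < num; ++i) {
    y[i] = x[i] * scale;
  }
}

int main() {
  const int n = 1000;
  float *x = nullptr, *y = nullptr;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = static_cast<float>(i);

  ScaleKernelSketch<float, 4><<<4, 64>>>(x, 2.0f, y, n);
  cudaDeviceSynchronize();
  printf("y[999] = %f\n", y[999]);  // expected: 1998.0

  cudaFree(x);
  cudaFree(y);
  return 0;
}

Because any two threads leave the main loop with offsets that differ by at least VecSize, and the trailing window is narrower than VecSize, at most one thread enters the tail loop, so each remainder element is written exactly once.
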
@@ -396,7 +412,6 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
   for (; i + VecSize <= num; i += stride) {
     platform::AlignedVector<T, VecSize> param_vec;
     platform::AlignedVector<GradT, VecSize> grad_vec;
-    platform::AlignedVector<T, VecSize> weight_decay_vec;
     platform::AlignedVector<T, VecSize> mom1_vec;
     platform::AlignedVector<T, VecSize> mom2_vec;
     platform::AlignedVector<T, VecSize> trust_ratio_div_vec;
@@ -760,6 +775,24 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
   return false;
 }

+template <typename T1, typename T2>
+static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx,
+                              const T1 *x, const T2 *scale, T1 *y, int n,
+                              gpuStream_t stream) {
+  int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0));
+  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
+
+#define PD_LAMB_VEC_SCALE_KERNEL_CASE                              \
+  do {                                                             \
+    ScaleCUDAKernel<T1, T2, kVecSize><<<config.block_per_grid,     \
+        config.thread_per_block, 0, stream>>>(                     \
+        x, scale, y, n);                                           \
+  } while (0)
+
+  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAMB_VEC_SCALE_KERNEL_CASE);
+#undef PD_LAMB_VEC_SCALE_KERNEL_CASE
+}
+
 template <typename T>
 static void NCCLReduceScatterWithScale(
     const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
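
LaunchScaleKernel turns the runtime choice of vector width into a compile-time one: GetChunkedVecSize picks a width compatible with each pointer's alignment, std::min takes the width both x and y can support, and PD_VEC_LAUNCH_KERNEL (defined elsewhere in this file) selects the matching ScaleCUDAKernel<T1, T2, kVecSize> instantiation. The sketch below shows the kind of switch such a dispatch amounts to; it assumes the ScaleCUDAKernel template from the hunk above and is not the literal macro expansion.

// Hypothetical dispatch helper; grid/block come from the launch config.
template <typename T1, typename T2>
void DispatchScaleKernelSketch(const T1 *x, const T2 *scale, T1 *y, int n,
                               int vec_size, int grid, int block,
                               cudaStream_t stream) {
  switch (vec_size) {
    case 8:
      ScaleCUDAKernel<T1, T2, 8><<<grid, block, 0, stream>>>(x, scale, y, n);
      break;
    case 4:
      ScaleCUDAKernel<T1, T2, 4><<<grid, block, 0, stream>>>(x, scale, y, n);
      break;
    case 2:
      ScaleCUDAKernel<T1, T2, 2><<<grid, block, 0, stream>>>(x, scale, y, n);
      break;
    default:  // vec_size == 1: fully scalar fallback
      ScaleCUDAKernel<T1, T2, 1><<<grid, block, 0, stream>>>(x, scale, y, n);
      break;
  }
}

Keeping VecSize a template parameter is what lets the #pragma unroll loop and the aligned loads and stores compile down to a handful of wide memory instructions per thread.
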
@@ -775,10 +808,8 @@ static void NCCLReduceScatterWithScale(
       PADDLE_ENFORCE_EQ(nranks, 1,
                         platform::errors::InvalidArgument(
                             "nranks must be 1 when scale != nullptr."));
-      auto numel = recvcount * nranks;
-      auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-      ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                        stream>>>(sendbuff, scale, recvbuff, numel);
+      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, recvcount * nranks,
+                        stream);
     }
     return;
   }
@@ -792,9 +823,7 @@ static void NCCLReduceScatterWithScale(
   if (scale && !should_destroy_op) {
     size_t numel = recvcount * nranks;
     T *new_sendbuff = buffer.Alloc<T>(numel);
-    auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-    ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                      stream>>>(sendbuff, scale, new_sendbuff, numel);
+    LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
     sendbuff = new_sendbuff;
   }
