diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 42477232e7ca1..3e7023bd1260f 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -14,7 +14,29 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/math/math_cuda_utils.h" #include "paddle/fluid/operators/optimizers/lars_momentum_op.h" +#include "paddle/fluid/platform/fast_divmod.h" + +#if defined(__NVCC__) && CUDA_VERSION >= 11000 +/* Once CUDA_VERSION is beyond 11.0, cooperative_groups can be involved in + without adding --rdc=true compile flag, then L2_norm cuda kernel can be + set as a __device__ kernel rather than global kernel. On the contrary, + the compile flag shall be set in old version, which may affect the cuda + kernel performance in paddle, consequently, L2_norm kernel shall be set + as a __global__ kernel. +*/ +#include +#define LARS_FUNCTION_FLAG __device__ +#else +#define LARS_FUNCTION_FLAG __global__ +#endif + +#ifdef __HIPCC__ +#define LARS_BLOCK_SIZE 256 +#else +#define LARS_BLOCK_SIZE 512 +#endif namespace paddle { namespace operators { @@ -22,55 +44,207 @@ namespace operators { template using MultiPrecisionType = typename details::MPTypeTrait::Type; +__device__ __forceinline__ float Sqrt(float x) { return sqrtf(x); } +__device__ __forceinline__ double Sqrt(double x) { return sqrt(x); } +__device__ __forceinline__ float Fma(float x, float y, float z) { + return fmaf(x, y, z); +} +__device__ __forceinline__ double Fma(double x, double y, double z) { + return fma(x, y, z); +} + +template +__device__ inline void VectorizeLarsUpdate( + const T* __restrict__ grad, const MT* __restrict__ param, + const MT* __restrict__ velocity, T* __restrict__ param_out, + MT* __restrict__ velocity_out, const MT mu, MT local_lr, + const MT lars_weight_decay, const MT rescale_grad, const int tid, + const int grid_stride, const int numel, + MT* __restrict__ master_param_out = nullptr) { + using VecType = paddle::platform::AlignedVector; + using VecMType = paddle::platform::AlignedVector; + int main = numel >> (VecSize >> 1); + int tail_offset = main * VecSize; + + const VecType* __restrict__ grad_vec = reinterpret_cast(grad); + const VecMType* __restrict__ param_vec = + reinterpret_cast(param); + const VecMType* __restrict__ velocity_vec = + reinterpret_cast(velocity); + VecType* param_out_vec = reinterpret_cast(param_out); + VecMType* velocity_out_vec = reinterpret_cast(velocity_out); + + VecMType* master_param_out_vec; + if (IsAmp) { + master_param_out_vec = reinterpret_cast(master_param_out); + } + + for (int i = tid; i < main; i += grid_stride) { + VecType param_out_tmp; + VecMType velocity_tmp, param_tmp; + VecType grad_data = grad_vec[i]; + VecMType param_data = param_vec[i]; + VecMType velocity_data = velocity_vec[i]; + +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + MT grad_val = static_cast(grad_data[j]) * rescale_grad; + velocity_tmp[j] = + Fma(velocity_data[j], mu, + local_lr * Fma(lars_weight_decay, param_data[j], grad_val)); + param_tmp[j] = param_data[j] - velocity_tmp[j]; + param_out_tmp[j] = static_cast(param_tmp[j]); + } + param_out_vec[i] = param_out_tmp; + velocity_out_vec[i] = velocity_tmp; + if (IsAmp) { + master_param_out_vec[i] = param_tmp; + } + } + + for (int i = tid + tail_offset; i < numel; i += grid_stride) { + MT grad_val = static_cast(grad[i]) * rescale_grad; + MT param_val = param[i]; + MT velocity_tmp = Fma(velocity[i], mu, local_lr * Fma(lars_weight_decay, + param_val, grad_val)); + MT param_tmp = param_val - velocity_tmp; + param_out[i] = static_cast(param_tmp); + velocity_out[i] = velocity_tmp; + if (IsAmp) { + master_param_out[i] = param_tmp; + } + } +} + template -__global__ void MomentumLarsKernel( - const T* p, const T* g, const MT* v, - const MultiPrecisionType* learning_rate, const MT mu, const int64_t num, - const MT lars_coeff, const MT lars_weight_decay, - const MultiPrecisionType* p_norm, const MultiPrecisionType* g_norm, - T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out, - const MultiPrecisionType rescale_grad) { - const MT lr = static_cast(learning_rate[0]); - MT local_lr = lr; - const MT p_n = static_cast(p_norm[0]); - const MT g_n = static_cast(g_norm[0]); +LARS_FUNCTION_FLAG void L2NormKernel( + const T* __restrict__ p_data, const T* __restrict__ g_data, + MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, + const int repeat_times, const int64_t numel, const MT rescale_grad, + MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + int grid_stride = LARS_BLOCK_SIZE * gridDim.x; + const MT rescale_grad_pow = rescale_grad * rescale_grad; + __shared__ MT s_buffer[2]; + s_buffer[0] = static_cast(0); + s_buffer[1] = static_cast(0); + MT p_tmp_val = static_cast(0); + MT g_tmp_val = static_cast(0); - if (lars_weight_decay > static_cast(0) && p_n > static_cast(0) && - g_n > static_cast(0)) { - local_lr = - lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon); + if (repeat_times == 0) { + if (tid < numel) { + p_tmp_val = static_cast(p_data[tid]); + g_tmp_val = static_cast(g_data[tid]); + } + s_buffer[0] += math::blockReduceSum(p_tmp_val * p_tmp_val, FINAL_MASK); + s_buffer[1] += math::blockReduceSum(g_tmp_val * g_tmp_val, FINAL_MASK); + } else { + /* To avoid occupy too much temp buffer. Hence, slice the whole data into 2 + parts, the front of them whose quantity is excatly multiple of grid-thread + number, and this part of data is delt in for loop, the rest of data is delt + with another step to avoid visiting data address beyond bound. */ + for (int i = 0; i < repeat_times; ++i) { + p_tmp_val = static_cast(p_data[tid]); + g_tmp_val = static_cast(g_data[tid]); + tid += grid_stride; + s_buffer[0] += + math::blockReduceSum(p_tmp_val * p_tmp_val, FINAL_MASK); + s_buffer[1] += + math::blockReduceSum(g_tmp_val * g_tmp_val, FINAL_MASK); + __syncthreads(); + } + MT p_val = 0; + MT g_val = 0; + if (tid < numel) { + p_val = static_cast(p_data[tid]); + g_val = static_cast(g_data[tid]); + } + s_buffer[0] += math::blockReduceSum(p_val * p_val, FINAL_MASK); + s_buffer[1] += math::blockReduceSum(g_val * g_val, FINAL_MASK); } - CUDA_KERNEL_LOOP(i, num) { - MT grad = static_cast(g[i]) * static_cast(rescale_grad); - MT param = master_p ? master_p[i] : static_cast(p[i]); + __syncthreads(); + + if (threadIdx.x == 0) { + p_buffer[blockIdx.x] = s_buffer[0]; + g_buffer[blockIdx.x] = s_buffer[1]; + } + +#if CUDA_VERSION >= 11000 + // Grid sync for completely writring partial result back to gloabl memory + const cooperative_groups::grid_group cg = cooperative_groups::this_grid(); + cg.sync(); + MT p_partial_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0; + MT g_partial_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0; + *p_n = Sqrt(math::blockReduceSum(p_partial_sum, FINAL_MASK)); + *g_n = Sqrt(rescale_grad_pow * + math::blockReduceSum(g_partial_sum, FINAL_MASK)); +#endif +} - MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param); - MT p_new = param - v_new; +template +__global__ void MomentumLarsKernel( + const T* __restrict__ param, const T* __restrict__ grad, + const MT* __restrict__ velocity, T* param_out, MT* velocity_out, + const MT* __restrict__ master_param, MT* __restrict__ master_param_out, + const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer, + MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff, + const MT lars_weight_decay, const MT epsilon, const MT rescale_grad, + const int repeat_times, const int thresh, const int64_t numel) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int grid_stride = gridDim.x * LARS_BLOCK_SIZE; +#if CUDA_VERSION >= 11000 + MT param_norm = static_cast(0); + MT grad_norm = static_cast(0); + L2NormKernel(param, grad, p_buffer, g_buffer, repeat_times, numel, + rescale_grad, ¶m_norm, &grad_norm); +#else + const MT rescale_grad_pow = rescale_grad * rescale_grad; + MT param_parital_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0; + MT grad_parital_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0; + __syncthreads(); + MT param_norm = + Sqrt(math::blockReduceSum(param_parital_norm, FINAL_MASK)); + MT grad_norm = Sqrt(rescale_grad_pow * + math::blockReduceSum(grad_parital_norm, FINAL_MASK)); +#endif - v_out[i] = v_new; - p_out[i] = static_cast(p_new); - if (master_p_out) master_p_out[i] = p_new; + const MT lr = learning_rate[0]; + MT local_lr = lr; + if (lars_weight_decay > static_cast(0)) { + local_lr = lr * lars_coeff * param_norm / + (Fma(lars_weight_decay, param_norm, grad_norm) + epsilon); + } + + if (master_param_out) { + VectorizeLarsUpdate(grad, master_param, velocity, param_out, + velocity_out, mu, local_lr, + lars_weight_decay, rescale_grad, tid, + grid_stride, numel, master_param_out); + } else { + if (std::is_same::value || + std::is_same::value) { + // As for multiple-precision, type T and MT cannot be more than fp16 or + // fp32, Then, the maximum data IO size could be set to 4. + VectorizeLarsUpdate( + grad, reinterpret_cast(param), velocity, param_out, + velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, + grid_stride, numel); + } else { + VectorizeLarsUpdate( + grad, reinterpret_cast(param), velocity, param_out, + velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid, + grid_stride, numel); + } } } template class LarsMomentumOpCUDAKernel : public framework::OpKernel { - using MPDType = MultiPrecisionType; + using MT = MultiPrecisionType; public: void Compute(const framework::ExecutionContext& ctx) const override { const bool multi_precision = ctx.Attr("multi_precision"); - if (multi_precision) { - InnerCompute(ctx, multi_precision); - } else { - InnerCompute(ctx, multi_precision); - } - } - - private: - template - void InnerCompute(const framework::ExecutionContext& ctx, - const bool multi_precision) const { auto param_out = ctx.Output("ParamOut"); auto velocity_out = ctx.Output("VelocityOut"); auto param = ctx.Input("Param"); @@ -78,8 +252,13 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { auto grad = ctx.Input("Grad"); auto learning_rate = ctx.Input("LearningRate"); + int64_t numel = param->numel(); + int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE; const framework::Tensor* master_param = nullptr; framework::Tensor* master_param_out = nullptr; + const MT* master_param_data = nullptr; + MT* master_param_out_data = nullptr; + if (multi_precision) { bool has_master = ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); @@ -90,56 +269,114 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel { "the attr `multi_precision` is true")); master_param = ctx.Input("MasterParam"); master_param_out = ctx.Output("MasterParamOut"); + master_param_data = master_param->data(); + master_param_out_data = + master_param_out->mutable_data(ctx.GetPlace()); } - - const MT* master_p = multi_precision ? master_param->data() : nullptr; - MT* master_p_out = multi_precision - ? master_param_out->mutable_data(ctx.GetPlace()) - : nullptr; - - T* p_out = param_out->mutable_data(ctx.GetPlace()); - MT* v_out = velocity_out->mutable_data(ctx.GetPlace()); - MT mu = static_cast(ctx.Attr("mu")); MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); MT lars_weight_decay = static_cast(ctx.Attr("lars_weight_decay")); MT epsilon = static_cast(ctx.Attr("epsilon")); - MPDType rescale_grad = - static_cast(ctx.Attr("rescale_grad")); - - auto* p = param->data(); - auto* g = grad->data(); - auto* v = velocity->data(); - auto* lr = learning_rate->data(); - - int block = 512; - int grid = (param->numel() + block - 1) / block; - - auto eigen_p = framework::EigenVector::Flatten(*param); - auto eigen_g = framework::EigenVector::Flatten(*grad); - // calculate norms using eigein and launch the kernel. - framework::Tensor p_norm_t, g_norm_t; - p_norm_t.Resize({1}); - g_norm_t.Resize({1}); - auto* p_norm_data = p_norm_t.mutable_data(ctx.GetPlace()); - auto* g_norm_data = g_norm_t.mutable_data(ctx.GetPlace()); - auto ep_norm = framework::EigenScalar::From(p_norm_t); - auto eg_norm = framework::EigenScalar::From(g_norm_t); - - auto* place = ctx.template device_context().eigen_device(); - - // eigen unsupport fp16 l2-norm - ep_norm.device(*place) = - eigen_p.template cast().square().sum().sqrt(); - eg_norm.device(*place) = - (eigen_g.template cast() * rescale_grad).square().sum().sqrt(); + MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); - MomentumLarsKernel< - T, MT><<>>( - p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay, - p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out, + auto* param_data = param->data(); + auto* grad_data = grad->data(); + auto* velocity_data = velocity->data(); + auto* lr = learning_rate->data(); + auto& cuda_ctx = ctx.template device_context(); + T* param_out_data = param_out->mutable_data(ctx.GetPlace()); + MT* velocity_out_data = velocity_out->mutable_data(ctx.GetPlace()); + +#if CUDA_VERSION >= 11000 + /* + Once model trainning with lars optimizer, whose principal implementation + is achieved by following two steps: + 1. Figure out the L2 norm statistic result of grad data and param data. + 2. Update param and velocity data with usage of L2 norm statistic result. + + Orignally, these two steps were fulfilled by respective eigen function and + cuda kernel, however the overhead of eigen function occupied much ratio in + total, consequently affect the performance of lars op, make it necessary + to combine 2 steps into one cuda kernel. + Since the step1 is l2 norm statistic, grid level reduce is needed. To + achieve this and continuous calculation of step 2 in only one global + lanuch, essential basis is to control all grid-threads while running. Apart + from normal lanuch form, cuda9.0 provides `cudaLaunchCooperativeKernel` + api : + - The thread quantity shall less than pyhsical SM limited threads + - Launches a device function where thread blocks can cooperate and + synchronize as they execute. + */ + // Figure out how many blocks can be active in each sm. + int num_blocks_per_sm = 0; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, + MomentumLarsKernel, + LARS_BLOCK_SIZE, sizeof(MT)); + int sm_num = cuda_ctx.GetSMCount(); + int grid_real = + std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE); + framework::Tensor tmp_buffer_t = + ctx.AllocateTmpTensor( + {LARS_BLOCK_SIZE << 1}, cuda_ctx); + auto* p_buffer = tmp_buffer_t.mutable_data(ctx.GetPlace()); + auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; + int grid_stride = LARS_BLOCK_SIZE * grid; + int repeat_times = (numel + grid_stride - 1) / grid_stride - 1; + int thresh = 0; + + // Uniform kernel parameter for cudaLaunchCooperativeKernel + void* cuda_param[] = { + reinterpret_cast(¶m_data), + reinterpret_cast(&grad_data), + reinterpret_cast(&velocity_data), + reinterpret_cast(¶m_out_data), + reinterpret_cast(&velocity_out_data), + reinterpret_cast(&master_param_data), + reinterpret_cast(&master_param_out_data), + reinterpret_cast(&lr), + reinterpret_cast(&p_buffer), + reinterpret_cast(&g_buffer), + reinterpret_cast(&mu), + reinterpret_cast(&lars_coeff), + reinterpret_cast(&lars_weight_decay), + reinterpret_cast(&epsilon), + reinterpret_cast(&rescale_grad), + reinterpret_cast(&repeat_times), + reinterpret_cast(&thresh), // Just a placeholder + reinterpret_cast(&numel)}; + // Lanuch all sm theads. + cudaLaunchCooperativeKernel( + reinterpret_cast(MomentumLarsKernel), grid_real, + LARS_BLOCK_SIZE, cuda_param, 0, cuda_ctx.stream()); +#else + // Determine to read 4 fp16 or float data once, but 2 double data once. + int grid_lars = + sizeof(T) < sizeof(double) + ? (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2) + : (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1); + + int grid_norm = std::min(grid, LARS_BLOCK_SIZE); + framework::Tensor p_buffer_t = + ctx.AllocateTmpTensor( + {LARS_BLOCK_SIZE << 1}, cuda_ctx); + auto* p_buffer = p_buffer_t.mutable_data(ctx.GetPlace()); + auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; + + const int grid_stride = LARS_BLOCK_SIZE * grid_norm; + const int repeat_times = (numel + grid_stride - 1) / grid_stride - 1; + + L2NormKernel<<>>( + param_data, grad_data, p_buffer, g_buffer, repeat_times, numel, rescale_grad); + + MomentumLarsKernel< + T, MT><<>>( + param_data, grad_data, velocity_data, param_out_data, velocity_out_data, + master_param_data, master_param_out_data, lr, p_buffer, g_buffer, mu, + lars_coeff, lars_weight_decay, epsilon, rescale_grad, 0, grid_norm, + numel); // 0 is just a placeholder. +#endif } };