From f3c7518e678946418b7c83db5878f4ab5226d51e Mon Sep 17 00:00:00 2001
From: houj04 <35131887+houj04@users.noreply.github.com>
Date: Mon, 27 May 2024 10:50:57 +0800
Subject: [PATCH] [XPU] update xdnn adamw_v2 (#63108)

* [XPU] update xdnn adamw_v2

* revert #48626
---
 paddle/phi/kernels/xpu/adamw_kernel.cc | 80 ++++++++++++--------------
 python/paddle/optimizer/optimizer.py   |  6 +-
 2 files changed, 39 insertions(+), 47 deletions(-)

diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc
index f60e02c61a323..72c1c5d578eaf 100644
--- a/paddle/phi/kernels/xpu/adamw_kernel.cc
+++ b/paddle/phi/kernels/xpu/adamw_kernel.cc
@@ -245,20 +245,18 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
   // template DLL_EXPORT int
   // adamw_v2(Context* ctx, MT beta1, MT beta2, MT epsilon, MT coeff, MT
-  // lr_ratio, const MT* beta1_pow, MT* beta1_pow_out, const MT* beta2_pow, MT*
-  // beta2_pow_out, const MT* moment1, MT* moment1_out, const MT* moment2, MT*
-  // moment2_out, const MT* lr, const TG* grad, const T* param, T* param_out,
-  // const MT* master_param, MT* master_param_out, int64_t n);
+  // lr_ratio, const MT* beta1_pow, MT beta1_pow_scalar, const MT* beta2_pow, MT
+  // beta2_pow_scalar, const MT* moment1, MT* moment1_out, const MT* moment2,
+  // MT* moment2_out, const MT* lr, const TG* grad, const T* param, T*
+  // param_out, const MT* master_param, MT* master_param_out, int64_t n, bool
+  // round_bf16_output);
+  bool round_bf16_output = false;
+  if (std::getenv("XPU_PADDLE_ADAMW_ROUND_BF16_OUTPUT") != nullptr) {
+    round_bf16_output = true;
+  }
   if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
-    DenseTensor xpu_beta1_pow;
-    DenseTensor xpu_beta2_pow;
-    phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, &xpu_beta1_pow);
-    phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, &xpu_beta2_pow);
-    dev_ctx.Wait();
-    const MPDType* beta1_pow_ptr = xpu_beta1_pow.data();
-    const MPDType* beta2_pow_ptr = xpu_beta2_pow.data();
-
+    // Compute with betapow in REG
     if (grad_type == phi::DataType::FLOAT32) {
       int r = xpu::adamw_v2(
           dev_ctx.x_context(),
@@ -267,10 +265,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           epsilon_,
           coeff_,
           lr_ratio_,
-          beta1_pow_ptr,
-          nullptr,
-          beta2_pow_ptr,
-          nullptr,
+          nullptr,  // beta1_pow
+          *beta1_pow.data(),  // beta1_pow_scalar
+          nullptr,  // beta2_pow
+          *beta2_pow.data(),  // beta2_pow_scalar
           moment_in_fp16 ? moment1_input_for_xdnn
                          : moment1.template data(),
           moment_in_fp16 ? moment1_output_for_xdnn
@@ -285,7 +283,8 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           reinterpret_cast(dev_ctx.template Alloc(param_out)),
           master_in_data,
           master_out_data,
-          param.numel());
+          param.numel(),
+          round_bf16_output);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
     } else {
       int r = xpu::adamw_v2(
@@ -295,10 +294,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           epsilon_,
           coeff_,
           lr_ratio_,
-          beta1_pow_ptr,
-          nullptr,
-          beta2_pow_ptr,
-          nullptr,
+          nullptr,  // beta1_pow
+          *beta1_pow.data(),  // beta1_pow_scalar
+          nullptr,  // beta2_pow
+          *beta2_pow.data(),  // beta2_pow_scalar
           moment_in_fp16 ? moment1_input_for_xdnn
                          : moment1.template data(),
           moment_in_fp16 ? moment1_output_for_xdnn
@@ -313,7 +312,8 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           reinterpret_cast(dev_ctx.template Alloc(param_out)),
           master_in_data,
           master_out_data,
-          param.numel());
+          param.numel(),
+          round_bf16_output);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
     }
     if (!use_global_beta_pow) {
@@ -324,14 +324,6 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           beta2_ * beta2_pow.data()[0];
     }
   } else {
-    MPDType* beta1_pow_out_ptr = nullptr;
-    MPDType* beta2_pow_out_ptr = nullptr;
-
-    if (!use_global_beta_pow) {
-      beta1_pow_out_ptr = dev_ctx.template Alloc(beta1_pow_out);
-      beta2_pow_out_ptr = dev_ctx.template Alloc(beta2_pow_out);
-    }
-
     if (grad_type == phi::DataType::FLOAT32) {
       int r = xpu::adamw_v2(
           dev_ctx.x_context(),
@@ -340,10 +332,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           epsilon_,
           coeff_,
           lr_ratio_,
-          beta1_pow.data(),
-          nullptr,  // beta1_pow_out_ptr,
-          beta2_pow.data(),
-          nullptr,  // beta2_pow_out_ptr,
+          beta1_pow.data(),  // beta1_pow
+          0.0f,  // beta1_pow_scalar
+          beta2_pow.data(),  // beta2_pow
+          0.0f,  // beta2_pow_scalar
           moment_in_fp16 ? moment1_input_for_xdnn
                          : moment1.template data(),
           moment_in_fp16 ? moment1_output_for_xdnn
@@ -358,7 +350,8 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           reinterpret_cast(dev_ctx.template Alloc(param_out)),
           master_in_data,
           master_out_data,
-          param.numel());
+          param.numel(),
+          round_bf16_output);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
     } else {
       int r = xpu::adamw_v2(
@@ -368,10 +361,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           epsilon_,
           coeff_,
           lr_ratio_,
-          beta1_pow.data(),
-          nullptr,  // beta1_pow_out_ptr,
-          beta2_pow.data(),
-          nullptr,  // beta2_pow_out_ptr,
+          beta1_pow.data(),  // beta1_pow
+          0.0f,  // beta1_pow_scalar
+          beta2_pow.data(),  // beta2_pow
+          0.0f,  // beta2_pow_scalar
           moment_in_fp16 ? moment1_input_for_xdnn
                          : moment1.template data(),
           moment_in_fp16 ? moment1_output_for_xdnn
@@ -386,14 +379,15 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
           reinterpret_cast(dev_ctx.template Alloc(param_out)),
           master_in_data,
           master_out_data,
-          param.numel());
+          param.numel(),
+          round_bf16_output);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
     }
     if (!use_global_beta_pow) {
-      // update beta1_pow and beta2_pow
+      // Update with xpu
      int r = xpu::scale(dev_ctx.x_context(),
                         beta1_pow.data(),
-                        beta1_pow_out_ptr,
+                        dev_ctx.template Alloc(beta1_pow_out),
                         beta1_pow.numel(),
                         false,
                         beta1_,
@@ -401,7 +395,7 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
      r = xpu::scale(dev_ctx.x_context(),
                     beta2_pow.data(),
-                    beta2_pow_out_ptr,
+                    dev_ctx.template Alloc(beta2_pow_out),
                     beta2_pow.numel(),
                     false,
                     beta2_,
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index a41a17a1c7e2c..4c94a85107a4b 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -960,10 +960,8 @@ def _add_accumulator(
             belong_to_optimizer=True,
         )
 
-        if (
-            in_dygraph_mode()
-            and (device == 'cpu' or isinstance(device, core.CPUPlace))
-            and (not core.is_compiled_with_xpu())
+        if in_dygraph_mode() and (
+            device == 'cpu' or isinstance(device, core.CPUPlace)
         ):
             _C_ops.full_(
                 var,
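
Note (not part of the patch): with this change, bf16 output rounding in the XPU adamw_v2 path is opt-in through the XPU_PADDLE_ADAMW_ROUND_BF16_OUTPUT environment variable; the kernel only checks that the variable is set (std::getenv(...) != nullptr), so any value enables it. A minimal sketch of opting in from a training script, assuming the variable is set in the same process before the first optimizer step:

    import os

    # Enable bf16 output rounding for the XPU adamw_v2 kernel. The C++ side only
    # checks that the variable is set, so any non-empty value works. The kernel
    # reads it on each call, so set it before the first optimizer.step().
    os.environ["XPU_PADDLE_ADAMW_ROUND_BF16_OUTPUT"] = "1"

The other behavioral change in the kernel: when beta1_pow/beta2_pow live on the CPU, their scalar values are now passed to adamw_v2 by value (beta1_pow_scalar/beta2_pow_scalar) instead of first being copied to XPU memory, removing the phi::Copy and dev_ctx.Wait() round trip; when they live on the XPU, the pointers are passed as before and the scalar arguments are filled with 0.0f placeholders.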