[XPU] update xdnn adamw_v2 (#63108)
* [XPU] update xdnn adamw_v2

* revert #48626
houj04 authored May 27, 2024
1 parent 2ce5261 commit f3c7518
Showing 2 changed files with 39 additions and 47 deletions.
80 changes: 37 additions & 43 deletions paddle/phi/kernels/xpu/adamw_kernel.cc
@@ -245,20 +245,18 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,

// template <typename T, typename TG, typename MT> DLL_EXPORT int
// adamw_v2(Context* ctx, MT beta1, MT beta2, MT epsilon, MT coeff, MT
// lr_ratio, const MT* beta1_pow, MT* beta1_pow_out, const MT* beta2_pow, MT*
// beta2_pow_out, const MT* moment1, MT* moment1_out, const MT* moment2, MT*
// moment2_out, const MT* lr, const TG* grad, const T* param, T* param_out,
// const MT* master_param, MT* master_param_out, int64_t n);
// lr_ratio, const MT* beta1_pow, MT beta1_pow_scalar, const MT* beta2_pow, MT
// beta2_pow_scalar, const MT* moment1, MT* moment1_out, const MT* moment2,
// MT* moment2_out, const MT* lr, const TG* grad, const T* param, T*
// param_out, const MT* master_param, MT* master_param_out, int64_t n, bool
// round_bf16_output);
bool round_bf16_output = false;
if (std::getenv("XPU_PADDLE_ADAMW_ROUND_BF16_OUTPUT") != nullptr) {
round_bf16_output = true;
}

if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
DenseTensor xpu_beta1_pow;
DenseTensor xpu_beta2_pow;
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, &xpu_beta1_pow);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, &xpu_beta2_pow);
dev_ctx.Wait();
const MPDType* beta1_pow_ptr = xpu_beta1_pow.data<MPDType>();
const MPDType* beta2_pow_ptr = xpu_beta2_pow.data<MPDType>();

// Compute with betapow in REG
if (grad_type == phi::DataType::FLOAT32) {
int r = xpu::adamw_v2<XPUType, float, MPDType>(
dev_ctx.x_context(),
@@ -267,10 +265,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow_ptr,
nullptr,
beta2_pow_ptr,
nullptr,
nullptr, // beta1_pow
*beta1_pow.data<MPDType>(), // beta1_pow_scalar
nullptr, // beta2_pow
*beta2_pow.data<MPDType>(), // beta2_pow_scalar
moment_in_fp16 ? moment1_input_for_xdnn
: moment1.template data<MPDType>(),
moment_in_fp16 ? moment1_output_for_xdnn
@@ -285,7 +283,8 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
param.numel(),
round_bf16_output);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
} else {
int r = xpu::adamw_v2<XPUType, XPUType, MPDType>(
@@ -295,10 +294,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow_ptr,
nullptr,
beta2_pow_ptr,
nullptr,
nullptr, // beta1_pow
*beta1_pow.data<MPDType>(), // beta1_pow_scalar
nullptr, // beta2_pow
*beta2_pow.data<MPDType>(), // beta2_pow_scalar
moment_in_fp16 ? moment1_input_for_xdnn
: moment1.template data<MPDType>(),
moment_in_fp16 ? moment1_output_for_xdnn
@@ -313,7 +312,8 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
param.numel(),
round_bf16_output);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
}
if (!use_global_beta_pow) {
@@ -324,14 +324,6 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
beta2_ * beta2_pow.data<MPDType>()[0];
}
} else {
MPDType* beta1_pow_out_ptr = nullptr;
MPDType* beta2_pow_out_ptr = nullptr;

if (!use_global_beta_pow) {
beta1_pow_out_ptr = dev_ctx.template Alloc<MPDType>(beta1_pow_out);
beta2_pow_out_ptr = dev_ctx.template Alloc<MPDType>(beta2_pow_out);
}

if (grad_type == phi::DataType::FLOAT32) {
int r = xpu::adamw_v2<XPUType, float, MPDType>(
dev_ctx.x_context(),
@@ -340,10 +332,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
nullptr, // beta1_pow_out_ptr,
beta2_pow.data<MPDType>(),
nullptr, // beta2_pow_out_ptr,
beta1_pow.data<MPDType>(), // beta1_pow
0.0f, // beta1_pow_scalar
beta2_pow.data<MPDType>(), // beta2_pow
0.0f, // beta2_pow_scalar
moment_in_fp16 ? moment1_input_for_xdnn
: moment1.template data<MPDType>(),
moment_in_fp16 ? moment1_output_for_xdnn
@@ -358,7 +350,8 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
param.numel(),
round_bf16_output);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
} else {
int r = xpu::adamw_v2<XPUType, XPUType, MPDType>(
@@ -368,10 +361,10 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
epsilon_,
coeff_,
lr_ratio_,
beta1_pow.data<MPDType>(),
nullptr, // beta1_pow_out_ptr,
beta2_pow.data<MPDType>(),
nullptr, // beta2_pow_out_ptr,
beta1_pow.data<MPDType>(), // beta1_pow
0.0f, // beta1_pow_scalar
beta2_pow.data<MPDType>(), // beta2_pow
0.0f, // beta2_pow_scalar
moment_in_fp16 ? moment1_input_for_xdnn
: moment1.template data<MPDType>(),
moment_in_fp16 ? moment1_output_for_xdnn
@@ -386,22 +379,23 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
param.numel(),
round_bf16_output);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
}
if (!use_global_beta_pow) {
// update beta1_pow and beta2_pow
// Update with xpu
int r = xpu::scale(dev_ctx.x_context(),
beta1_pow.data<MPDType>(),
beta1_pow_out_ptr,
dev_ctx.template Alloc<MPDType>(beta1_pow_out),
beta1_pow.numel(),
false,
beta1_,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::scale(dev_ctx.x_context(),
beta2_pow.data<MPDType>(),
beta2_pow_out_ptr,
dev_ctx.template Alloc<MPDType>(beta2_pow_out),
beta2_pow.numel(),
false,
beta2_,
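As the diff above shows, the kernel now passes beta1_pow/beta2_pow to xdnn `adamw_v2` either as a host-side scalar (when the pow tensors live on the CPU) or as a device pointer with a `0.0f` scalar placeholder, and it forwards a `round_bf16_output` flag that is switched on whenever the `XPU_PADDLE_ADAMW_ROUND_BF16_OUTPUT` environment variable is set. A minimal user-side sketch of enabling that flag, assuming a Paddle build with XPU support; the model, shapes, and hyperparameters below are placeholders:

```python
import os

# Opt in to bf16 rounding of param_out in the XPU AdamW kernel.
# Per the kernel change above, only the variable's presence is checked,
# not its value.
os.environ["XPU_PADDLE_ADAMW_ROUND_BF16_OUTPUT"] = "1"

import paddle

paddle.set_device("xpu")  # assumes a Paddle build with XPU support

model = paddle.nn.Linear(16, 4)     # placeholder model
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=model.parameters(),
    weight_decay=0.01,
)

x = paddle.randn([8, 16])           # placeholder data
loss = model(x).mean()
loss.backward()
opt.step()        # on an XPU place this runs the AdamW kernel patched above
opt.clear_grad()
```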
6 changes: 2 additions & 4 deletions python/paddle/optimizer/optimizer.py
@@ -960,10 +960,8 @@ def _add_accumulator(
belong_to_optimizer=True,
)

if (
in_dygraph_mode()
and (device == 'cpu' or isinstance(device, core.CPUPlace))
and (not core.is_compiled_with_xpu())
if in_dygraph_mode() and (
device == 'cpu' or isinstance(device, core.CPUPlace)
):
_C_ops.full_(
var,
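The optimizer.py hunk reverts the guard added in #48626: in dygraph mode, an accumulator placed on the CPU (e.g. the beta1_pow/beta2_pow tensors) is now initialized in place with `_C_ops.full_` even when Paddle is compiled with XPU. A standalone sketch of the updated predicate, with the CPU-place check passed in as a callable so the snippet runs without Paddle (hypothetical helper names, for illustration only):

```python
from typing import Callable, Union


def uses_inplace_cpu_fill(
    in_dygraph: bool,
    device: Union[str, object],
    is_cpu_place: Callable[[object], bool],
) -> bool:
    """Mirror of the updated condition in Optimizer._add_accumulator.

    The former `not core.is_compiled_with_xpu()` clause (from #48626) is
    gone, so XPU builds take the same in-place `_C_ops.full_` path for
    CPU-resident accumulators as every other build.
    """
    return in_dygraph and (device == "cpu" or is_cpu_place(device))


# A CPU-pinned accumulator in dygraph mode now always takes the in-place path.
print(uses_inplace_cpu_fill(True, "cpu", lambda d: False))  # True
```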
