[XPU] AdamW: fp16 for moment1/moment2 #62688

Merged Mar 28, 2024 · 4 commits

Changes from all commits
229 changes: 209 additions & 20 deletions paddle/phi/kernels/xpu/adamw_kernel.cc
@@ -140,6 +140,109 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
MPDType* master_out_data =
multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_outs)
: nullptr;

// check moment_dtype
auto moment1_dtype = moment1.dtype();
auto moment2_dtype = moment2.dtype();
PADDLE_ENFORCE_EQ(moment1_dtype,
moment1_out->dtype(),
errors::InvalidArgument(
"moment1.dtype does not match moment1_out->dtype"));
PADDLE_ENFORCE_EQ(moment2_dtype,
moment2_out->dtype(),
errors::InvalidArgument(
"moment2.dtype does not match moment2_out->dtype"));
PADDLE_ENFORCE_EQ(
moment1_dtype,
moment2_dtype,
errors::InvalidArgument("moment1.dtype does not match moment2.dtype"));

bool moment_in_fp16 = false;
if (moment1_dtype == phi::DataType::FLOAT16) {
moment_in_fp16 = true;
} else {
PADDLE_ENFORCE_EQ(
moment1_dtype,
phi::DataType::FLOAT32,
errors::InvalidArgument("moment1.dtype is neither fp32 nor fp16"));
}

float* moment1_input_for_xdnn = nullptr;
float* moment2_input_for_xdnn = nullptr;
float* moment1_output_for_xdnn = nullptr;
float* moment2_output_for_xdnn = nullptr;

xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
if (moment_in_fp16) {
// allocate temp buffer on XPU
moment1_input_for_xdnn = RAII_GUARD.alloc_l3_or_gm<float>(moment1.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(moment1_input_for_xdnn);
moment2_input_for_xdnn = RAII_GUARD.alloc_l3_or_gm<float>(moment2.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(moment2_input_for_xdnn);
moment1_output_for_xdnn =
RAII_GUARD.alloc_l3_or_gm<float>(moment1_out->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(moment1_output_for_xdnn);
moment2_output_for_xdnn =
RAII_GUARD.alloc_l3_or_gm<float>(moment2_out->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(moment2_output_for_xdnn);

int r = 0;
using XPUType16 = typename XPUTypeTrait<phi::dtype::float16>::Type;

// cast moment1 and moment2, from fp16 to fp32
// int cast(Context* ctx, const TX* x, TY* y, int64_t len);
r = xpu::cast<XPUType16, float>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType16*>(
moment1.template data<phi::dtype::float16>()),
moment1_input_for_xdnn,
moment1.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment1 from fp16 to float");
r = xpu::cast<XPUType16, float>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType16*>(
moment2.template data<phi::dtype::float16>()),
moment2_input_for_xdnn,
moment2.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment2 from fp16 to float");

// acquire xpu_scale_value
float moment1_scale_value = XPUStorageProperties::default_xpu_scale_value;
if (moment1.storage_properties_initialized()) {
moment1_scale_value =
moment1.storage_properties<XPUStorageProperties>().xpu_scale_value;
}
float moment2_scale_value = XPUStorageProperties::default_xpu_scale_value;
if (moment2.storage_properties_initialized()) {
moment2_scale_value =
moment2.storage_properties<XPUStorageProperties>().xpu_scale_value;
}

// de-scale using scale_value
// int scale(Context* ctx, const T* x, T* y, int64_t len, bool
// bias_after_scale, float _scale, float _bias);
if (moment1_scale_value > 0) {
r = xpu::scale<float>(dev_ctx.x_context(),
moment1_input_for_xdnn,
moment1_input_for_xdnn,
moment1.numel(),
false,
1.0f / moment1_scale_value,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "de-scale for moment1");
}
if (moment2_scale_value > 0) {
r = xpu::scale<float>(dev_ctx.x_context(),
moment2_input_for_xdnn,
moment2_input_for_xdnn,
moment2.numel(),
false,
1.0f / moment2_scale_value,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "de-scale for moment2");
}
}
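Taken together, the block above decodes the moments from their scaled fp16 storage: cast up to fp32, then divide out the recorded per-tensor scale. A minimal host-side sketch of that decode step, with plain vectors standing in for the XPU buffers and xdnn calls (the Fp16 alias and function name are illustrative, not the kernel's API):

#include <cstddef>
#include <vector>

using Fp16 = _Float16;  // GCC/Clang fp16 extension, standing in for phi::dtype::float16

// stored[i] holds moment[i] * scale_value in fp16; recover the fp32 moments.
std::vector<float> decode_moment(const std::vector<Fp16>& stored,
                                 float scale_value) {
  std::vector<float> m(stored.size());
  for (size_t i = 0; i < stored.size(); ++i) {
    m[i] = static_cast<float>(stored[i]);      // xpu::cast<fp16, float>
    if (scale_value > 0) m[i] /= scale_value;  // xpu::scale by 1 / scale_value
  }
  return m;
}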

// template <typename T, typename TG, typename MT> DLL_EXPORT int
// adamw_v2(Context* ctx, MT beta1, MT beta2, MT epsilon, MT coeff, MT
// lr_ratio, const MT* beta1_pow, MT* beta1_pow_out, const MT* beta2_pow, MT*
@@ -168,18 +271,22 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
nullptr,
beta2_pow_ptr,
nullptr,
- moment1.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment1_out),
- moment2.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment2_out),
+ moment_in_fp16 ? moment1_input_for_xdnn
+                : moment1.template data<MPDType>(),
+ moment_in_fp16 ? moment1_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment1_out),
+ moment_in_fp16 ? moment2_input_for_xdnn
+                : moment2.template data<MPDType>(),
+ moment_in_fp16 ? moment2_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
reinterpret_cast<const XPUType*>(param.data<T>()),
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
} else {
int r = xpu::adamw_v2<XPUType, XPUType, MPDType>(
dev_ctx.x_context(),
@@ -192,18 +299,22 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
nullptr,
beta2_pow_ptr,
nullptr,
- moment1.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment1_out),
- moment2.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment2_out),
+ moment_in_fp16 ? moment1_input_for_xdnn
+                : moment1.template data<MPDType>(),
+ moment_in_fp16 ? moment1_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment1_out),
+ moment_in_fp16 ? moment2_input_for_xdnn
+                : moment2.template data<MPDType>(),
+ moment_in_fp16 ? moment2_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
reinterpret_cast<const XPUType*>(grad.data<T>()),
reinterpret_cast<const XPUType*>(param.data<T>()),
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
}
if (!use_global_beta_pow) {
// Cpu update
@@ -233,18 +344,22 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
nullptr, // beta1_pow_out_ptr,
beta2_pow.data<MPDType>(),
nullptr, // beta2_pow_out_ptr,
- moment1.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment1_out),
- moment2.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment2_out),
+ moment_in_fp16 ? moment1_input_for_xdnn
+                : moment1.template data<MPDType>(),
+ moment_in_fp16 ? moment1_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment1_out),
+ moment_in_fp16 ? moment2_input_for_xdnn
+                : moment2.template data<MPDType>(),
+ moment_in_fp16 ? moment2_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
grad.data<float>(),
reinterpret_cast<const XPUType*>(param.data<T>()),
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
} else {
int r = xpu::adamw_v2<XPUType, XPUType, MPDType>(
dev_ctx.x_context(),
@@ -257,18 +372,22 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
nullptr, // beta1_pow_out_ptr,
beta2_pow.data<MPDType>(),
nullptr, // beta2_pow_out_ptr,
- moment1.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment1_out),
- moment2.data<MPDType>(),
- dev_ctx.template Alloc<MPDType>(moment2_out),
+ moment_in_fp16 ? moment1_input_for_xdnn
+                : moment1.template data<MPDType>(),
+ moment_in_fp16 ? moment1_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment1_out),
+ moment_in_fp16 ? moment2_input_for_xdnn
+                : moment2.template data<MPDType>(),
+ moment_in_fp16 ? moment2_output_for_xdnn
+                : dev_ctx.template Alloc<MPDType>(moment2_out),
learning_rate.data<MPDType>(),
reinterpret_cast<const XPUType*>(grad.data<T>()),
reinterpret_cast<const XPUType*>(param.data<T>()),
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
master_in_data,
master_out_data,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2");
}
if (!use_global_beta_pow) {
// update beta1_pow and beta2_pow
Expand All @@ -290,6 +409,76 @@ void AdamwDenseKernelKL3(const Context& dev_ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
}

if (moment_in_fp16) {
int r = 0;
using XPUType16 = typename XPUTypeTrait<phi::dtype::float16>::Type;

// findmax and calculate scale_value for moment1 and moment2
int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
float* buffer_for_findmax = RAII_GUARD.alloc_l3_or_gm<float>(max_ptr_size);

// for moment1
float moment1_max = GetAbsMax<Context>(dev_ctx,
moment1_output_for_xdnn,
buffer_for_findmax,
moment1_out->numel());
float moment1_scale_value = 65504.0f / moment1_max / 2.0f;
Contributor:

A question for 珏爷: won't this hard-coded value be a problem here?

Contributor (Author):

This follows the per-tensor scale design. I won't paste the internal documentation externally.
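For context on the constant: 65504 is the largest finite fp16 value, so scale = 65504 / max / 2 places the tensor's absolute maximum at half the fp16 range, leaving 2x headroom before the stored values overflow. A small worked example with invented numbers:

// Worked example of the per-tensor scale choice (numbers invented).
constexpr float kFp16Max = 65504.0f;         // largest finite fp16 value

float moment_max = 1.5e-3f;                  // suppose findmax reports this
float scale = kFp16Max / moment_max / 2.0f;  // ~2.18e7
float encoded_max = moment_max * scale;      // 32752.0f == kFp16Max / 2
float decoded = encoded_max / scale;         // 1.5e-3f, recovered on the next step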

// int scale(Context* ctx, const T* x, T* y, int64_t len, bool
// bias_after_scale, float _scale, float _bias);
r = xpu::scale<float>(dev_ctx.x_context(),
moment1_output_for_xdnn,
moment1_output_for_xdnn,
moment1_out->numel(),
false,
moment1_scale_value,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(
r, "scale before convert to fp16, for moment1_output_for_xdnn");
// write to moment1_out
std::unique_ptr<phi::StorageProperties> moment1_out_sp =
std::make_unique<phi::XPUStorageProperties>(moment1_scale_value);
moment1_out->set_storage_properties(std::move(moment1_out_sp));

// for moment2
float moment2_max = GetAbsMax<Context>(dev_ctx,
moment2_output_for_xdnn,
buffer_for_findmax,
moment2_out->numel());
float moment2_scale_value = 65504.0f / moment2_max / 2.0f;
// int scale(Context* ctx, const T* x, T* y, int64_t len, bool
// bias_after_scale, float _scale, float _bias);
r = xpu::scale<float>(dev_ctx.x_context(),
moment2_output_for_xdnn,
moment2_output_for_xdnn,
moment2_out->numel(),
false,
moment2_scale_value,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(
r, "scale before convert to fp16, for moment2_output_for_xdnn");
// write to moment2_out
std::unique_ptr<phi::StorageProperties> moment2_out_sp =
std::make_unique<phi::XPUStorageProperties>(moment2_scale_value);
moment2_out->set_storage_properties(std::move(moment2_out_sp));

// cast moment1 and moment2 output, from fp32 to fp16
// int cast(Context* ctx, const TX* x, TY* y, int64_t len);
r = xpu::cast<float, XPUType16>(
dev_ctx.x_context(),
moment1_output_for_xdnn,
reinterpret_cast<XPUType16*>(
dev_ctx.template Alloc<phi::dtype::float16>(moment1_out)),
moment1.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment1_out from float to fp16");
r = xpu::cast<float, XPUType16>(
dev_ctx.x_context(),
moment2_output_for_xdnn,
reinterpret_cast<XPUType16*>(
dev_ctx.template Alloc<phi::dtype::float16>(moment2_out)),
moment2.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment2_out from float to fp16");
}
return;
}
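The encode path mirrors the decode at the top of the kernel: findmax over the fp32 result, rescale so the maximum lands at half the fp16 range, record the scale in the output's storage properties, then cast down. A matching host-side sketch under the same illustrative assumptions as the decode sketch above:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Fp16 = _Float16;  // stand-in for phi::dtype::float16, as above

// Returns the scaled fp16 values plus the scale a later step divides by
// (what the kernel records as xpu_scale_value). Assumes a nonzero max,
// as the kernel's formula does.
std::pair<std::vector<Fp16>, float> encode_moment(const std::vector<float>& m) {
  float max_abs = 0.0f;
  for (float v : m) max_abs = std::max(max_abs, std::fabs(v));  // findmax
  float scale = 65504.0f / max_abs / 2.0f;     // half-range headroom
  std::vector<Fp16> out(m.size());
  for (size_t i = 0; i < m.size(); ++i)
    out[i] = static_cast<Fp16>(m[i] * scale);  // xpu::scale, then xpu::cast
  return {out, scale};
}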
