From e25aeb073bf7bd6ac0f35e302b2d4413b5e7ab3d Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Wed, 4 Dec 2024 20:25:03 +0800
Subject: [PATCH] support amsgrad

---
 .../gpt-3/ppfleetx/optims/optimizer.py | 82 ++++++++++++++-----
 1 file changed, 61 insertions(+), 21 deletions(-)

diff --git a/slm/model_zoo/gpt-3/ppfleetx/optims/optimizer.py b/slm/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
index 31669ba8ae5f..81e6075be7c6 100644
--- a/slm/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
+++ b/slm/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import warnings
 
 import paddle
@@ -95,6 +96,11 @@ def __init__(self, learning_rate, parameters, grad_clip, **config):
                 stop_gradient=True,
             ).cast(paddle.float32)
 
+        # support amsgrad: https://github.com/PaddlePaddle/Paddle/pull/68079
+        self.support_amsgrad = (
+            "amsgrad" in inspect.signature(paddle.optimizer.AdamW.__init__).parameters
+        )
+
     def _add_moments_pows(self, p):
         acc_dtype = p.dtype
         if self._is_dtype_fp16_or_bf16(acc_dtype):
@@ -174,6 +180,15 @@ def _append_optimize_op(self, block, param_and_grad):
         moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0])
         moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0])
+        if self.support_amsgrad:
+            moment2_max = (
+                self._get_accumulator_master(
+                    self._moment2_acc_max_str, param_and_grad[0]
+                )
+                if self._amsgrad
+                else None
+            )
+
         beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0])
         beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0])
         find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
 
@@ -191,27 +206,52 @@ def _append_optimize_op(self, block, param_and_grad):
             cpu_fp32_param = param_and_grad[0].cpu().cast(paddle.float32)
             cpu_fp32_grad = param_and_grad[1].cpu().cast(paddle.float32)
 
-            _, _, _, _, _, _ = _C_ops.adamw_(
-                cpu_fp32_param,
-                cpu_fp32_grad,
-                lr.cpu(),
-                moment1.cpu(),
-                moment2.cpu(),
-                beta1_pow_acc.cpu(),
-                beta2_pow_acc.cpu(),
-                master_weight.cpu() if master_weight is not None else None,
-                None,
-                _beta1,
-                _beta2,
-                self._epsilon,
-                lr_ratio_,
-                self._weight_decay,
-                with_decay,
-                self._lazy_mode,
-                1000,
-                find_master,
-                False,
-            )
+            if self.support_amsgrad:
+                _, _, _, _, _, _, _ = _C_ops.adamw_(
+                    cpu_fp32_param,
+                    cpu_fp32_grad,
+                    lr.cpu(),
+                    moment1.cpu(),
+                    moment2.cpu(),
+                    moment2_max,  # moment2_max
+                    beta1_pow_acc.cpu(),
+                    beta2_pow_acc.cpu(),
+                    master_weight.cpu() if master_weight is not None else None,
+                    None,
+                    _beta1,
+                    _beta2,
+                    self._epsilon,
+                    lr_ratio_,
+                    self._weight_decay,
+                    with_decay,
+                    self._lazy_mode,
+                    1000,
+                    find_master,
+                    False,
+                    self._amsgrad,  # amsgrad
+                )
+            else:
+                _, _, _, _, _, _ = _C_ops.adamw_(
+                    cpu_fp32_param,
+                    cpu_fp32_grad,
+                    lr.cpu(),
+                    moment1.cpu(),
+                    moment2.cpu(),
+                    beta1_pow_acc.cpu(),
+                    beta2_pow_acc.cpu(),
+                    master_weight.cpu() if master_weight is not None else None,
+                    None,
+                    _beta1,
+                    _beta2,
+                    self._epsilon,
+                    lr_ratio_,
+                    self._weight_decay,
+                    with_decay,
+                    self._lazy_mode,
+                    1000,
+                    find_master,
+                    False,
+                )
 
             param_and_grad[0]._clear_data()
             cpu_fp32_param.cuda(self._dev_id).cast(origin_dtype)._share_buffer_to(param_and_grad[0])
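
Note on the compatibility pattern used above (a minimal sketch; the SUPPORTS_AMSGRAD constant and make_adamw helper below are illustrative and not part of the patch): the change probes the installed Paddle for an amsgrad parameter on paddle.optimizer.AdamW.__init__ (added in https://github.com/PaddlePaddle/Paddle/pull/68079) and only then passes the extra moment2_max / amsgrad arguments to _C_ops.adamw_, so the optimizer keeps working on older Paddle releases. The same runtime check can gate the public API:

import inspect

import paddle

# True only on Paddle builds whose AdamW exposes the amsgrad option
# (https://github.com/PaddlePaddle/Paddle/pull/68079).
SUPPORTS_AMSGRAD = "amsgrad" in inspect.signature(paddle.optimizer.AdamW.__init__).parameters


def make_adamw(parameters, learning_rate=1e-4):
    # Hypothetical helper: request AMSGrad only when the running Paddle accepts it,
    # otherwise fall back to plain AdamW so the same code runs on older releases.
    kwargs = {"amsgrad": True} if SUPPORTS_AMSGRAD else {}
    return paddle.optimizer.AdamW(learning_rate=learning_rate, parameters=parameters, **kwargs)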