adapt code to amsgrad supported adamw #9568

Merged 1 commit on Dec 6, 2024
82 changes: 61 additions & 21 deletions slm/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings

import paddle
@@ -95,6 +96,11 @@ def __init__(self, learning_rate, parameters, grad_clip, **config):
stop_gradient=True,
).cast(paddle.float32)

# support amsgrad: https://github.com/PaddlePaddle/Paddle/pull/68079
self.support_amsgrad = (
"amsgrad" in inspect.signature(paddle.optimizer.AdamW.__init__).parameters
)
Collaborator:

Does writing it this way enable amsgrad by default?

Contributor Author:

"Does writing it this way enable amsgrad by default?"

No, support_amsgrad indicates whether the currently installed Paddle supports amsgrad. Once amsgrad support was added, the C_ops interface of adam/adamw changed and its positional arguments shifted, so this flag tells whether the running Paddle predates or postdates that change. Whether amsgrad is actually enabled is determined by self._amsgrad.
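As a standalone sketch of the detection the reply above describes (illustrative only, not part of this diff; the helper name adamw_supports_amsgrad is hypothetical), the flag can be derived once from the signature of the installed Paddle's AdamW:

```python
import inspect

import paddle


def adamw_supports_amsgrad() -> bool:
    # True when the installed Paddle exposes an `amsgrad` parameter on AdamW,
    # i.e. a build that includes PaddlePaddle/Paddle#68079 and therefore the
    # wider _C_ops.adamw_ positional signature.
    return "amsgrad" in inspect.signature(paddle.optimizer.AdamW.__init__).parameters


support_amsgrad = adamw_supports_amsgrad()
# The flag only selects which _C_ops.adamw_ call signature to use; whether
# amsgrad is actually enabled remains a separate setting (self._amsgrad in
# the optimizer below).
print("amsgrad-capable AdamW:", support_amsgrad)
```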


def _add_moments_pows(self, p):
acc_dtype = p.dtype
if self._is_dtype_fp16_or_bf16(acc_dtype):
@@ -174,6 +180,15 @@ def _append_optimize_op(self, block, param_and_grad):

moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0])
moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0])
if self.support_amsgrad:
moment2_max = (
self._get_accumulator_master(
self._moment2_acc_max_str, param_and_grad[0]
)
if self._amsgrad
else None
)

beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0])
beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0])
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
@@ -191,27 +206,52 @@ def _append_optimize_op(self, block, param_and_grad):
cpu_fp32_param = param_and_grad[0].cpu().cast(paddle.float32)
cpu_fp32_grad = param_and_grad[1].cpu().cast(paddle.float32)

_, _, _, _, _, _ = _C_ops.adamw_(
cpu_fp32_param,
cpu_fp32_grad,
lr.cpu(),
moment1.cpu(),
moment2.cpu(),
beta1_pow_acc.cpu(),
beta2_pow_acc.cpu(),
master_weight.cpu() if master_weight is not None else None,
None,
_beta1,
_beta2,
self._epsilon,
lr_ratio_,
self._weight_decay,
with_decay,
self._lazy_mode,
1000,
find_master,
False,
)
if self.support_amsgrad:
_, _, _, _, _, _, _ = _C_ops.adamw_(
cpu_fp32_param,
cpu_fp32_grad,
lr.cpu(),
moment1.cpu(),
moment2.cpu(),
moment2_max, # moment2_max
beta1_pow_acc.cpu(),
beta2_pow_acc.cpu(),
master_weight.cpu() if master_weight is not None else None,
None,
_beta1,
_beta2,
self._epsilon,
lr_ratio_,
self._weight_decay,
with_decay,
self._lazy_mode,
1000,
find_master,
False,
self._amsgrad, # amsgrad
)
else:
_, _, _, _, _, _ = _C_ops.adamw_(
cpu_fp32_param,
cpu_fp32_grad,
lr.cpu(),
moment1.cpu(),
moment2.cpu(),
beta1_pow_acc.cpu(),
beta2_pow_acc.cpu(),
master_weight.cpu() if master_weight is not None else None,
None,
_beta1,
_beta2,
self._epsilon,
lr_ratio_,
self._weight_decay,
with_decay,
self._lazy_mode,
1000,
find_master,
False,
)

param_and_grad[0]._clear_data()
cpu_fp32_param.cuda(self._dev_id).cast(origin_dtype)._share_buffer_to(param_and_grad[0])
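For context, a minimal end-to-end sketch of the same feature gating at the public paddle.optimizer.AdamW level (an illustration under the assumption that the amsgrad keyword only exists on Paddle builds containing PaddlePaddle/Paddle#68079; it is not code from this PR):

```python
import inspect

import paddle

# Pass `amsgrad` only when the installed Paddle accepts it, mirroring how the
# diff gates the extra _C_ops.adamw_ arguments on support_amsgrad.
support_amsgrad = "amsgrad" in inspect.signature(paddle.optimizer.AdamW.__init__).parameters

linear = paddle.nn.Linear(4, 4)
kwargs = {"learning_rate": 1e-3, "parameters": linear.parameters(), "weight_decay": 0.01}
if support_amsgrad:
    kwargs["amsgrad"] = True  # on older builds the key is simply not passed
opt = paddle.optimizer.AdamW(**kwargs)

loss = paddle.mean(linear(paddle.randn([2, 4])) ** 2)
loss.backward()
opt.step()
opt.clear_grad()
```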