fix pir dtype (#9130)
wanghuancoder authored Sep 13, 2024
1 parent cd3dc95 commit 399490b
Showing 1 changed file with 9 additions and 3 deletions.
legacy/model_zoo/moe/dygraph/run_moe_pretrain.py: 9 additions & 3 deletions
@@ -158,27 +158,33 @@ def initialize_mp_dp_parameters(model, hcg):
 def unscale_method(self, optimizer):
     if not self._enable:
         return
+    if paddle.framework.use_pir_api():
+        type_float16 = core.DataType.FLOAT16
+        type_float32 = core.DataType.FLOAT32
+    else:
+        type_float16 = core.VarDesc.VarType.FP16
+        type_float32 = core.VarDesc.VarType.FP32
 
     if getattr(optimizer, "_param_groups", None) and isinstance(optimizer._param_groups[0], dict):
         param_grads_fp16 = []
         param_grads_fp32 = []
         for group in optimizer._param_groups:
             for param in group["params"]:
                 if param._grad_ivar() is not None:
-                    if param._grad_ivar().dtype == core.VarDesc.VarType.FP16:
+                    if param._grad_ivar().dtype == type_float16:
                         param_grads_fp16.append(param._grad_ivar())
                     else:
                         param_grads_fp32.append(param._grad_ivar())
     else:
         param_grads_fp16 = [
             param._grad_ivar()
             for param in optimizer._parameter_list
-            if (param._grad_ivar() is not None) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
+            if (param._grad_ivar() is not None) and (param._grad_ivar().dtype == type_float16)
         ]
         param_grads_fp32 = [
             param._grad_ivar()
             for param in optimizer._parameter_list
-            if (param._grad_ivar() is not None) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
+            if (param._grad_ivar() is not None) and (param._grad_ivar().dtype == type_float32)
         ]
         temp_found_inf_fp16 = paddle.to_tensor(np.array([0]).astype(np.bool_))
         temp_found_inf_fp32 = paddle.to_tensor(np.array([0]).astype(np.bool_))
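
Context for the change: under Paddle's PIR API, tensor dtypes are reported as core.DataType members, while the legacy API reports core.VarDesc.VarType members, so the old comparison against core.VarDesc.VarType.FP16 never matches under PIR and every gradient falls into the FP32 bucket. Below is a minimal standalone sketch of the selection pattern this commit applies. It assumes `core` resolves to paddle.base.core (the file's actual import is outside this hunk), and `select_float_dtypes` is a hypothetical helper name introduced here for illustration, not part of the commit.

# Minimal sketch of the dtype-selection pattern, assuming `core` is
# paddle.base.core; the helper name below is hypothetical.
import paddle
from paddle.base import core


def select_float_dtypes():
    # PIR reports dtypes as core.DataType members; the legacy API
    # reports them as core.VarDesc.VarType members.
    if paddle.framework.use_pir_api():
        return core.DataType.FLOAT16, core.DataType.FLOAT32
    return core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32


type_float16, type_float32 = select_float_dtypes()
# Gradients can now be bucketed with `grad.dtype == type_float16`
# regardless of which API family the runtime is using.

Resolving the enums once at the top of unscale_method, as the diff does, keeps the later comparisons identical in both code paths instead of branching at every dtype check.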