refine dtype use (PaddlePaddle#9366)
wanghuancoder authored and DesmonDay committed Nov 7, 2024
1 parent 2cd5991 commit bdd71a1
Showing 6 changed files with 7 additions and 13 deletions.
2 changes: 1 addition & 1 deletion legacy/model_zoo/bert/static/run_glue.py
@@ -159,7 +159,7 @@ def reset_program_state_dict(args, model, state_dict, pretrained_state_dict):
             reset_parameter_names.append(n)
         else:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             reset_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     logger.info("the following parameter had reset, please check. {}".format(reset_parameter_names))
2 changes: 1 addition & 1 deletion legacy/model_zoo/bert/static/run_glue_with_sparaity.py
@@ -159,7 +159,7 @@ def reset_program_state_dict(args, model, state_dict, pretrained_state_dict):
             reset_parameter_names.append(n)
         else:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             reset_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     logger.info("the following parameter had reset, please check. {}".format(reset_parameter_names))
2 changes: 1 addition & 1 deletion legacy/model_zoo/bert/static/run_pretrain.py
@@ -151,7 +151,7 @@ def reset_program_state_dict(model, state_dict):
     for n, p in state_dict.items():
         if "layer_norm" not in p.name:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             new_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     return new_state_dict
2 changes: 1 addition & 1 deletion legacy/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
@@ -98,7 +98,7 @@ def __init__(self, learning_rate, parameters, grad_clip, **config):
     def _add_moments_pows(self, p):
         acc_dtype = p.dtype
         if self._is_dtype_fp16_or_bf16(acc_dtype):
-            acc_dtype = core.VarDesc.VarType.FP32
+            acc_dtype = paddle.float32
         self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype, device="cpu")
         self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, device="cpu")
         self._add_accumulator(
3 changes: 1 addition & 2 deletions legacy/model_zoo/moe/dygraph/framework/group_sharded.py
@@ -37,7 +37,6 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
     GroupShardedStage2,
 )
-from paddle.framework import core
 from paddle.incubate.distributed.models.moe.grad_clip import ClipGradForMOEByGlobalNorm
 from paddle.optimizer import Optimizer
 
@@ -99,7 +98,7 @@ def _dygraph_clip(self, params_grads):
                 params_and_grads.append((p, g))
                 continue
             # TODO(wangxi): use inplace elementwise_mul
-            clip_input = clip_var.astype("float16") if g.dtype == core.VarDesc.VarType.FP16 else clip_var
+            clip_input = clip_var.astype("float16") if g.dtype == paddle.float16 else clip_var
             new_grad = paddle.multiply(x=g, y=clip_input)
             params_and_grads.append((p, new_grad))
         return params_and_grads
9 changes: 2 additions & 7 deletions legacy/model_zoo/moe/dygraph/run_moe_pretrain.py
@@ -37,7 +37,6 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.framework import core
 from paddle.incubate.distributed.models import moe
 from utils import get_timers, set_timers
 from visualdl import LogWriter
@@ -158,12 +157,8 @@ def initialize_mp_dp_parameters(model, hcg):
 def unscale_method(self, optimizer):
     if not self._enable:
         return
-    if paddle.framework.use_pir_api():
-        type_float16 = core.DataType.FLOAT16
-        type_float32 = core.DataType.FLOAT32
-    else:
-        type_float16 = core.VarDesc.VarType.FP16
-        type_float32 = core.VarDesc.VarType.FP32
+    type_float16 = paddle.float16
+    type_float32 = paddle.float32
 
     if getattr(optimizer, "_param_groups", None) and isinstance(optimizer._param_groups[0], dict):
         param_grads_fp16 = []
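For reference, a minimal standalone sketch (not taken from the repository, values illustrative) of the dtype-check pattern this commit moves to: comparing a tensor's `dtype` directly against the public `paddle.float64` / `paddle.float16` objects instead of matching the string `"VarType.FP64"` or the `core.VarDesc.VarType` / `core.DataType` enums, which also removes the PIR-vs-legacy-IR branch.

```python
# Sketch of the refined dtype checks; tensor shapes and scale are hypothetical.
import numpy as np
import paddle

p = paddle.ones([4, 8], dtype="float64")  # stand-in for a model parameter

# Old style removed by this commit:
#   if str(p.dtype) == "VarType.FP64": ...
#   acc_dtype = core.VarDesc.VarType.FP32
# New style: compare against paddle dtype objects directly.
dtype_str = "float32"
if p.dtype == paddle.float64:
    dtype_str = "float64"

scale = 0.02  # hypothetical initializer range
sample = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
print(sample.dtype)  # float64, because p is a float64 tensor

# The same comparison replaces core.VarDesc.VarType.FP16 for half precision.
g = paddle.ones([4, 8]).astype("float16")  # stand-in for a gradient
clip_var = paddle.to_tensor(0.5, dtype="float32")
clip_input = clip_var.astype("float16") if g.dtype == paddle.float16 else clip_var
print(clip_input.dtype)  # paddle.float16
```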
