From a3c3800339e456da7d929e406bf52a83218b77cc Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Thu, 7 Nov 2024 13:21:38 +0800
Subject: [PATCH 1/2] [Cherry-pick] Cherry-pick some PRs. (#9381)

* [Unified Checkpoint] Support empty state_dict saving (#9380)

* fix empty state_dict

* refine dtype use (#9366)

---------

Co-authored-by: wanghuancoder
---
 legacy/model_zoo/bert/static/run_glue.py                | 2 +-
 legacy/model_zoo/bert/static/run_glue_with_sparaity.py  | 2 +-
 legacy/model_zoo/bert/static/run_pretrain.py            | 2 +-
 legacy/model_zoo/gpt-3/ppfleetx/optims/optimizer.py     | 2 +-
 legacy/model_zoo/moe/dygraph/framework/group_sharded.py | 3 +--
 legacy/model_zoo/moe/dygraph/run_moe_pretrain.py        | 9 ++-------
 paddlenlp/trainer/training_args.py                      | 9 +++++++++
 paddlenlp/trainer/unified_checkpoint/async_handler.py   | 5 +++++
 8 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/legacy/model_zoo/bert/static/run_glue.py b/legacy/model_zoo/bert/static/run_glue.py
index d6cb54dc960f..6a42cdb84cd9 100644
--- a/legacy/model_zoo/bert/static/run_glue.py
+++ b/legacy/model_zoo/bert/static/run_glue.py
@@ -159,7 +159,7 @@ def reset_program_state_dict(args, model, state_dict, pretrained_state_dict):
             reset_parameter_names.append(n)
         else:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             reset_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     logger.info("the following parameter had reset, please check. {}".format(reset_parameter_names))
diff --git a/legacy/model_zoo/bert/static/run_glue_with_sparaity.py b/legacy/model_zoo/bert/static/run_glue_with_sparaity.py
index 023c58b0b812..fe3701304ba6 100644
--- a/legacy/model_zoo/bert/static/run_glue_with_sparaity.py
+++ b/legacy/model_zoo/bert/static/run_glue_with_sparaity.py
@@ -159,7 +159,7 @@ def reset_program_state_dict(args, model, state_dict, pretrained_state_dict):
             reset_parameter_names.append(n)
         else:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             reset_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     logger.info("the following parameter had reset, please check. {}".format(reset_parameter_names))
diff --git a/legacy/model_zoo/bert/static/run_pretrain.py b/legacy/model_zoo/bert/static/run_pretrain.py
index 38a172c734d0..8dfa8c354de3 100644
--- a/legacy/model_zoo/bert/static/run_pretrain.py
+++ b/legacy/model_zoo/bert/static/run_pretrain.py
@@ -151,7 +151,7 @@ def reset_program_state_dict(model, state_dict):
     for n, p in state_dict.items():
         if "layer_norm" not in p.name:
             dtype_str = "float32"
-            if str(p.dtype) == "VarType.FP64":
+            if p.dtype == paddle.float64:
                 dtype_str = "float64"
             new_state_dict[p.name] = np.random.normal(loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     return new_state_dict
diff --git a/legacy/model_zoo/gpt-3/ppfleetx/optims/optimizer.py b/legacy/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
index 795bcbabfe60..4fd081d69c7a 100644
--- a/legacy/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
+++ b/legacy/model_zoo/gpt-3/ppfleetx/optims/optimizer.py
@@ -98,7 +98,7 @@ def __init__(self, learning_rate, parameters, grad_clip, **config):
     def _add_moments_pows(self, p):
         acc_dtype = p.dtype
         if self._is_dtype_fp16_or_bf16(acc_dtype):
-            acc_dtype = core.VarDesc.VarType.FP32
+            acc_dtype = paddle.float32
         self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype, device="cpu")
         self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, device="cpu")
         self._add_accumulator(
diff --git a/legacy/model_zoo/moe/dygraph/framework/group_sharded.py b/legacy/model_zoo/moe/dygraph/framework/group_sharded.py
index 27e753abf3f3..af3361604ced 100644
--- a/legacy/model_zoo/moe/dygraph/framework/group_sharded.py
+++ b/legacy/model_zoo/moe/dygraph/framework/group_sharded.py
@@ -37,7 +37,6 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
     GroupShardedStage2,
 )
-from paddle.framework import core
 from paddle.incubate.distributed.models.moe.grad_clip import ClipGradForMOEByGlobalNorm
 from paddle.optimizer import Optimizer
 
@@ -99,7 +98,7 @@ def _dygraph_clip(self, params_grads):
                 params_and_grads.append((p, g))
                 continue
             # TODO(wangxi): use inplace elementwise_mul
-            clip_input = clip_var.astype("float16") if g.dtype == core.VarDesc.VarType.FP16 else clip_var
+            clip_input = clip_var.astype("float16") if g.dtype == paddle.float16 else clip_var
             new_grad = paddle.multiply(x=g, y=clip_input)
             params_and_grads.append((p, new_grad))
         return params_and_grads
diff --git a/legacy/model_zoo/moe/dygraph/run_moe_pretrain.py b/legacy/model_zoo/moe/dygraph/run_moe_pretrain.py
index de8ccd08fb8a..47a707b6edab 100644
--- a/legacy/model_zoo/moe/dygraph/run_moe_pretrain.py
+++ b/legacy/model_zoo/moe/dygraph/run_moe_pretrain.py
@@ -37,7 +37,6 @@
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.framework import core
 from paddle.incubate.distributed.models import moe
 from utils import get_timers, set_timers
 from visualdl import LogWriter
@@ -158,12 +157,8 @@ def initialize_mp_dp_parameters(model, hcg):
 def unscale_method(self, optimizer):
     if not self._enable:
         return
-    if paddle.framework.use_pir_api():
-        type_float16 = core.DataType.FLOAT16
-        type_float32 = core.DataType.FLOAT32
-    else:
-        type_float16 = core.VarDesc.VarType.FP16
-        type_float32 = core.VarDesc.VarType.FP32
+    type_float16 = paddle.float16
+    type_float32 = paddle.float32
 
     if getattr(optimizer, "_param_groups", None) and isinstance(optimizer._param_groups[0], dict):
         param_grads_fp16 = []
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 66cb257d37c8..f861f03222b9 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1136,6 +1136,15 @@ def split_parallel_config(parallel_config):
                     raise ValueError(
                         "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
                     )
+                if (
+                    enable_sharding_comm_overlap
+                    and self.unified_checkpoint
+                    and "split_param" in split_parallel_config(self.sharding_parallel_config)
+                ):
+                    logger.warning(
+                        "Currently unified checkpoint do not support using `sharding_comm_overlap` and `split_param` at the same time, delete `sharding_comm_overlap`."
+                    )
+                    enable_sharding_comm_overlap = False
 
                 dygraph_pp_configs = {
                     "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
diff --git a/paddlenlp/trainer/unified_checkpoint/async_handler.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py
index 4206821b50e5..942ea41508bf 100644
--- a/paddlenlp/trainer/unified_checkpoint/async_handler.py
+++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py
@@ -77,6 +77,11 @@ def _file_save_async_or_sync(
                     state_dict[k] = state_dict.pop(k).cpu().numpy()
             safe_save_file(state_dict, path, metadata={"format": "np"})
         else:
+            if len(state_dict.keys()) == 0:
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{self.global_rank}")
+                paddle.save(self.global_rank, saved_signal_path)
+                return
+
             if state_dict_type == "model_weight":
                 if self._shm_model_weight is None:
                     self._meta_dict_model, buffer_size = create_meta_dict(state_dict)

From d02c40658c31ac8e25ea0bf2061f5bf52c9bc13c Mon Sep 17 00:00:00 2001
From: lugimzzz <63761690+lugimzzz@users.noreply.github.com>
Date: Thu, 7 Nov 2024 20:13:51 +0800
Subject: [PATCH 2/2] fix dpo&kto oom (#9390)

---
 paddlenlp/trl/dpo_criterion.py | 8 ++++----
 paddlenlp/trl/kto_criterion.py | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/paddlenlp/trl/dpo_criterion.py b/paddlenlp/trl/dpo_criterion.py
index 2af2a8ef2096..be454e2ce4d1 100644
--- a/paddlenlp/trl/dpo_criterion.py
+++ b/paddlenlp/trl/dpo_criterion.py
@@ -287,10 +287,10 @@ def forward(
         )
         loss = dpo_loss + sft_loss
         if self.use_infohub:
-            infohub.policy_chosen_logps.append(policy_chosen_logps)
-            infohub.policy_rejected_logps.append(policy_rejected_logps)
-            infohub.sft_loss.append(sft_loss)
-            infohub.dpo_loss.append(dpo_loss)
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.sft_loss.append(sft_loss.detach())
+            infohub.dpo_loss.append(dpo_loss.detach())
             return loss
         else:
             return policy_chosen_logps, policy_rejected_logps, sft_loss, dpo_loss, loss
diff --git a/paddlenlp/trl/kto_criterion.py b/paddlenlp/trl/kto_criterion.py
index 52745e996999..a6ca6c4c837a 100644
--- a/paddlenlp/trl/kto_criterion.py
+++ b/paddlenlp/trl/kto_criterion.py
@@ -247,10 +247,10 @@ def forward(
             reference_kl_logps,
         )
         if self.use_infohub:
-            infohub.policy_chosen_logps.append(policy_chosen_logps)
-            infohub.policy_rejected_logps.append(policy_rejected_logps)
-            infohub.policy_kl_logps.append(policy_kl_logps)
-            infohub.kl.append(kl)
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.policy_kl_logps.append(policy_kl_logps.detach())
+            infohub.kl.append(kl.detach())
             return loss
         else:
             return (