From cd140814057c49cba9a01023e6d04571b377e790 Mon Sep 17 00:00:00 2001
From: kebo01
Date: Mon, 27 May 2024 10:39:18 +0800
Subject: [PATCH] [fea] moe support

---
 paddlenlp/trainer/trainer.py           | 44 ++++++++++++++++++++------
 paddlenlp/trainer/training_args.py     | 25 ++++++++++++++-
 paddlenlp/trainer/utils/sharding_io.py | 18 ++++++++---
 paddlenlp/transformers/utils.py        |  6 ++++
 4 files changed, 78 insertions(+), 15 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 01b902478622..7e7cbae35887 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -945,7 +945,8 @@ def _inner_training_loop(
                     ((step_control + 1) % args.gradient_accumulation_steps != 0)
                     and availiable_no_sync
                     and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                ) or (args.recompute and availiable_no_sync
+                ) or (args.use_moe and availiable_no_sync)
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manualy collect gradient on dp group
@@ -965,6 +966,14 @@ def _inner_training_loop(
                 tr_loss += tr_loss_step
 
+                def fused_allreduce_gradients_no_sync(paramlist, hcg):
+                    paramlist = list(paramlist)
+                    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
+                    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
+                    if moelist and not self.args.use_moe:
+                        logger.warning("found `no_sync` param when `use_moe=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +992,12 @@ def _inner_training_loop(
                     # Case 1: Use recompute and dp / sharding stage1,
                     # manualy collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    if (args.recompute or args.use_moe) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Case 2: hack dp with master_grad
-                    if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    elif dp_master_grad:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Pipeline parallel mode, handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,8 +1016,7 @@ def _inner_training_loop(
                             self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
 
                         if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                            fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
-
+                            fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()
@@ -1028,7 +1036,9 @@ def _inner_training_loop(
                     )
                     optimizer_was_run = True
                     if self.do_grad_scaling:
-                        scale_before = paddle.assign(self.scaler._scale)
+                        if args.pipeline_parallel_degree > 1:
+                            assert not self.args.use_moe, "pipeline moe does not work under fp16"
+                        scale_before = self.scaler._scale.numpy()
                         self.scaler.step(self.optimizer)
                         self.scaler.update()
                         scale_after = self.scaler._scale
@@ -2042,7 +2052,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
         model.train()
         inputs = self._prepare_inputs(inputs)
-
+        self.timers and self.timers(f"forward-acc-{self._cur_acc_step}").start()
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
 
@@ -2053,7 +2063,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             self.scaler.scale(loss).backward()
         else:
             loss.backward()
-
+        self.timers and self.timers(f"backward-acc-{self._cur_acc_step}").stop()
         return loss.detach()
 
     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2142,6 +2152,18 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             if self.args.should_save_model_state and self.args.should_save:
                 # For ckpt integrity
                 paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))
+    def _save_moe_weights(
+        self,
+        output_dir: Optional[str] = None,
+        merge_tensor_parallel: Optional[bool] = False,
+    ):
+        # save moe optimizer and model state  # TODO: stored redundantly by default
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")
+        paddle.save(self.optimizer.state_dict(), os.path.join(output_dir, optimizer_name))
+        with open(saved_signal_path, mode="w+") as f:
+            f.write("1")
 
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
@@ -2245,6 +2267,8 @@ def _save_checkpoint(self, model, metrics=None):
                 os.makedirs(output_dir, exist_ok=True)
                 paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
 
+        if self.args.use_moe and self.args.data_parallel_rank > 0:
+            self._save_moe_weights(output_dir)
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files maybe on different node.
         need_to_rotate_checkpoints = False
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index d30876a0f2d2..d5d43093565e 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -803,6 +803,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
+    use_moe: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to enable MoE (mixture of experts) training"},
+    )
 
     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1149,6 +1153,8 @@ def is_segment_parallel_supported():
                 order = ["dp", "sharding", "pp", "sep", "mp"]
             else:
                 order = ["dp", "sharding", "pp", "mp"]
+            if self.use_moe:
+                order = order[1:-1] + ["dp", "mp"]
 
             if is_segment_parallel_supported():
                 hybrid_configs = {
@@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
+            if self.use_moe:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
+            if self.use_moe:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None
 
     @property
@@ -1652,12 +1662,16 @@ def weight_name_suffix(self):
             name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
+            if self.use_moe:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
+            if self.use_moe:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None
 
-    def sharded_name_suffix(self, shard_id=None, pp_id=None):
+    def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
         if self.use_hybrid_parallel:
             name = []
             if self.tensor_parallel_degree > 1:
@@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
                     shard_id = self.sharding_parallel_rank
                 assert isinstance(shard_id, int)
                 name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree))
+            if self.use_moe:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                assert isinstance(moe_id, int)
+                name.append(self._format_name("moe", moe_id, self.data_parallel_degree))
             return "_".join(name)
         else:
+            if self.use_moe:
+                if moe_id is None:
+                    moe_id = self.data_parallel_rank
+                return self._format_name("moe", moe_id, self.data_parallel_degree)
             return None
 
     @property
diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py
index 56f4c426ce0a..3e2bd2087cbf 100644
--- a/paddlenlp/trainer/utils/sharding_io.py
+++ b/paddlenlp/trainer/utils/sharding_io.py
@@ -444,12 +444,22 @@ def filter_func(name):
             master_weights = reshard_util.all_gather_state_dict(master_weights, filter_func, self.sharding_group)
 
         model_state_dict = self.model.state_dict()
+        logger.info(f"state-dict-keys: {state_dict.keys()}, nums: {len(state_dict.keys())}")
         logger.info("before recover, model_state_dict number: {}".format(len(model_state_dict)))
         for key, param in model_state_dict.items():
-            if param.name in master_weights:
-                assert param.shape == master_weights[param.name].shape
-                paddle.assign(master_weights[param.name].cuda(), model_state_dict[key])
-
+            if param.name in master_weights:
+                assert param.shape == master_weights[param.name].shape
+                paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key])
+            elif key in state_dict:
+                logger.info(f"key: {key} is in state_dict, but not in master_weights")
+                paddle.assign(state_dict[key], model_state_dict[key])
+            if param.name in sharding_group_param_names:
+                paddle.distributed.broadcast(
+                    model_state_dict[key],
+                    src=self.sharding_group.ranks[param2rank[param.name]],
+                    group=self.sharding_group,
+                    sync_op=True,
+                )
         logger.info("after recover, casted model_state_dict number: {}".format(len(model_state_dict)))
         state_dict.update(model_state_dict)
         return state_dict
diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py
index f785a5358af4..4a5c067fed6c 100644
--- a/paddlenlp/transformers/utils.py
+++ b/paddlenlp/transformers/utils.py
@@ -818,8 +818,14 @@ def weight_name_suffix():
             name.append(f"tp{hcg.get_model_parallel_rank():0>2d}")
         if hcg.get_pipe_parallel_world_size() > 1:
             name.append(f"pp{hcg.get_stage_id():0>2d}")
+        if config and getattr(config, "moe_num_experts", 0):
+            dp_group = hcg.get_data_parallel_group()
+            name.append(f"moe{dp_group.rank:0>2d}")
         return "_".join(name)
     else:
+        if config and getattr(config, "moe_num_experts", 0):
+            rank = paddle.distributed.get_rank()
+            return f"moe{rank:0>2d}"
         return None
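
Note on the gradient path: the fused_allreduce_gradients_no_sync helper added in _inner_training_loop assumes the MoE model tags each expert parameter with a truthy `no_sync` attribute. Only the untagged (shared) parameters take part in the data-parallel all-reduce, so expert gradients stay local to their rank. Below is a minimal, Paddle-free sketch of that filtering; FakeParam, fake_allreduce, and collected are illustrative names that do not appear in this patch.

class FakeParam:
    def __init__(self, name, no_sync=False):
        self.name = name
        self.no_sync = no_sync

collected = []

def fake_allreduce(params, hcg=None):
    # Stand-in for fused_allreduce_gradients: just records what would be reduced.
    collected.extend(p.name for p in params)

def fused_allreduce_gradients_no_sync(paramlist, hcg, use_moe=True):
    # Same filtering as the helper in _inner_training_loop: parameters tagged
    # no_sync=True (the experts) are excluded from the data-parallel all-reduce.
    paramlist = list(paramlist)
    nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
    moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
    if moelist and not use_moe:
        print("warning: found `no_sync` param when `use_moe=False`")
    fake_allreduce(nonmoe_list, hcg)

params = [
    FakeParam("shared.linear.weight"),
    FakeParam("expert_0.ffn.weight", no_sync=True),
    FakeParam("expert_1.ffn.weight", no_sync=True),
]
fused_allreduce_gradients_no_sync(params, hcg=None)
assert collected == ["shared.linear.weight"]  # expert gradients are never all-reduced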
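
The checkpoint-name changes follow the same idea on the storage side: when use_moe is on (or, in paddlenlp/transformers/utils.py, when config.moe_num_experts is set), a "moe{data_parallel_rank}" segment is appended to the optimizer and weight suffixes, so each data-parallel rank can write its own expert state (ranks > 0 save via the new _save_moe_weights). The sketch below is illustrative only: format_name is a stand-in for TrainingArguments._format_name and assumes the zero-padded rank style used in transformers/utils.py (e.g. "moe03"); the real padding may differ.

def format_name(prefix, rank):
    # Assumed zero-padded style, e.g. "moe03"; the real helper is
    # TrainingArguments._format_name, whose exact padding may differ.
    return f"{prefix}{rank:0>2d}"

def weight_name_suffix(tp_rank=None, pp_rank=None, moe_rank=None):
    # Mirrors the suffix assembly order in the patch: tp, then pp, then moe.
    name = []
    if tp_rank is not None:
        name.append(format_name("tp", tp_rank))
    if pp_rank is not None:
        name.append(format_name("pp", pp_rank))
    if moe_rank is not None:
        name.append(format_name("moe", moe_rank))
    return "_".join(name) or None

print(weight_name_suffix(tp_rank=1, pp_rank=0, moe_rank=3))  # -> "tp01_pp00_moe03"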