diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 939e1dd4d1de..3b1ee6079331 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -48,6 +48,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
     GroupShardedOptimizerStage2,
 )
+from paddle.utils import map_structure
 
 try:
     from paddle.distributed.fleet.utils.hybrid_parallel_util import (
@@ -143,6 +144,7 @@ from .utils import reshard as reshard_util
 from .utils.helper import (  # nested_truncate,
     broadcast_dp_optimizer,
+    broadcast_moe_optimizer,
     distributed_concat,
     distributed_file,
     distributed_isfile,
@@ -565,7 +567,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 )
                 self.model.set_state_dict(state_dict)
             else:
-                if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_moe):
+                if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):
                     weights_file = os.path.join(
                         resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
                     )
@@ -940,12 +942,17 @@ def _inner_training_loop(
                     forbidden_no_sync = True
 
                 availiable_no_sync = dp_enabled and not forbidden_no_sync
+                has_no_sync = hasattr(model, "no_sync")
                 is_no_sync = (
-                    ((step_control + 1) % args.gradient_accumulation_steps != 0)
-                    and availiable_no_sync
-                    and args._no_sync_in_gradient_accumulation
-                ) or (args.recompute and availiable_no_sync)
+                    (
+                        ((step_control + 1) % args.gradient_accumulation_steps != 0)
+                        and availiable_no_sync
+                        and args._no_sync_in_gradient_accumulation
+                    )
+                    or (args.recompute and availiable_no_sync)
+                    or args.use_expert_parallel
+                )
                 # sharding
                 # stage1. the same as ddp
                 # stage2. manualy collect gradient on dp group
@@ -956,14 +963,25 @@ def _inner_training_loop(
                 if dp_master_grad:
                     is_no_sync = True
 
-                if is_no_sync:
+                if is_no_sync and has_no_sync:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
                         tr_loss_step = self.training_step(model, inputs)
                 else:
                     tr_loss_step = self.training_step(model, inputs)
 
-                tr_loss += tr_loss_step
+                def fused_allreduce_gradients_no_sync(param_list, hcg):
+                    param_list = list(param_list)
+                    nonmoe_list = [p for p in param_list if not getattr(p, "no_sync", False)]
+                    moe_list = [p for p in param_list if getattr(p, "no_sync", False)]
+                    if moe_list and not self.args.use_expert_parallel:
+                        logger.warning("found `no_sync` param when `use_expert_parallel=False`")
+                    fused_allreduce_gradients(nonmoe_list, hcg)
+
+                if tr_loss_step is not None:
+                    if tr_loss is None:
+                        tr_loss = map_structure(lambda x: paddle.zeros_like(x), tr_loss_step)
+                    map_structure(lambda x, y: x.add_(y), tr_loss, tr_loss_step)
 
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
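# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: fused_allreduce_gradients_no_sync above
# assumes expert-local parameters carry a truthy `no_sync` attribute (set by the MoE
# model, not by this diff) and keeps them out of the data-parallel gradient all-reduce.
# The two Linear layers below are stand-ins for a shared layer and an expert layer.
import paddle

shared = paddle.nn.Linear(4, 4)
expert = paddle.nn.Linear(4, 4)
for p in expert.parameters():
    p.no_sync = True  # assumed tagging convention, read back via getattr(p, "no_sync", False)

params = list(shared.parameters()) + list(expert.parameters())
nonmoe_list = [p for p in params if not getattr(p, "no_sync", False)]
moe_list = [p for p in params if getattr(p, "no_sync", False)]
assert len(nonmoe_list) == 2 and len(moe_list) == 2  # weight + bias in each group
# ---------------------------------------------------------------------------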
@@ -983,12 +1001,13 @@ def _inner_training_loop(
 
                     # Case 1: Use recompute and dp / sharding stage1,
                     # manualy collect gradient for dp.
-                    if args.recompute and availiable_no_sync:
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                    # Case 1.1: pure dp + moe should manually collect gradient here.
+                    if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Case 2: hack dp with master_grad
                     if dp_master_grad and not (args.recompute and availiable_no_sync):
-                        fused_allreduce_gradients(list(model.parameters()), None)
+                        fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Pipeline parallel mode,  handle gradient reduce here to overlap
                     pipeline_parallel_config = (
@@ -1007,7 +1026,9 @@ def _inner_training_loop(
                                 self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)
 
                             if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
-                                fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
+                                fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
+                        else:
+                            assert not self.args.use_expert_parallel, "moe should not use `enable_dp_comm_overlap`"
 
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()
@@ -1132,7 +1153,7 @@ def _inner_training_loop(
                     "on multiple nodes, you should activate `--save_on_each_node`."
                 )
 
-        self._total_loss_scalar += tr_loss.item()
+        self._total_loss_scalar += tr_loss.pop("loss").item() if isinstance(tr_loss, dict) else tr_loss.item()
         train_loss = self._total_loss_scalar / self.state.global_step
 
         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
@@ -1250,12 +1271,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             logs: Dict[str, float] = {}
 
             # all_gather + mean() to get average loss over all processes
-            tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())
+            tr_loss_scalar = map_structure(lambda x: self._get_item_from_loss(self._nested_gather(x).mean()), tr_loss)
 
             # reset tr_loss to zero
-            tr_loss.subtract_(tr_loss)
-
-            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
+            map_structure(lambda x: x.zero_(), tr_loss)
+
+            if isinstance(tr_loss_scalar, dict):
+                for k, v in tr_loss_scalar.items():
+                    logs[k] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)
+            elif isinstance(tr_loss_scalar, (list, tuple)):
+                for i, v in enumerate(tr_loss_scalar):
+                    logs[f"loss_{i}"] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)
+            else:
+                logs["loss"] = round(
+                    tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged),
+                    8,
+                )
             logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
             logs["global_step"] = int(self.state.global_step)
@@ -1290,7 +1321,9 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                     )
                 )
 
-            self._total_loss_scalar += tr_loss_scalar
+            self._total_loss_scalar += (
+                tr_loss_scalar.pop("loss") if isinstance(tr_loss_scalar, dict) else tr_loss_scalar
+            )
             self._globalstep_last_logged = self.state.global_step
             self._globalstep_last_start_time = time.time()
@@ -2047,14 +2080,19 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
             loss = self.compute_loss(model, inputs)
 
         if self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
-            loss = loss / self.args.gradient_accumulation_steps
+            loss = map_structure(lambda x: x / self.args.gradient_accumulation_steps, loss)
+
+        if isinstance(loss, dict):
+            total_loss = loss["loss"]
+        else:
+            total_loss = loss
 
         if self.do_grad_scaling:
-            self.scaler.scale(loss).backward()
+            self.scaler.scale(total_loss).backward()
         else:
-            loss.backward()
+            total_loss.backward()
 
-        return loss.detach()
+        return map_structure(lambda v: v.detach(), loss)
 
     def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
         """
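# ---------------------------------------------------------------------------
# Sketch, not part of the patch: the loss convention training_step now accepts.
# compute_loss may return either a plain tensor (unchanged behaviour) or a dict with a
# "loss" entry; only "loss" is back-propagated, and the whole structure is scaled by the
# accumulation steps and detached for logging. The "aux_loss" key and the toy model
# below are made-up examples.
import paddle
from paddle.utils import map_structure

model = paddle.nn.Linear(8, 2)
logits = model(paddle.randn([4, 8]))
ce = paddle.nn.functional.cross_entropy(logits, paddle.randint(0, 2, [4]))
aux = 0.01 * logits.square().mean()  # hypothetical router auxiliary loss
loss = {"loss": ce + aux, "aux_loss": aux}

loss = map_structure(lambda t: t / 4, loss)  # gradient_accumulation_steps == 4
loss["loss"].backward()                      # only the total loss drives backward
logged = map_structure(lambda t: t.detach(), loss)  # what training_step returns
# ---------------------------------------------------------------------------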
@@ -2113,6 +2151,18 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
 
         return loss.detach()
 
+    def _save_moe_weights(
+        self,
+        output_dir,
+        merge_tensor_parallel: Optional[bool] = False,
+    ):
+        self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
+        if not self.args.ignore_save_lr_and_optim:
+            self._save_ckpt_func(
+                self.optimizer.state_dict(),
+                os.path.join(output_dir, _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)),
+            )
+
     def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
@@ -2126,7 +2176,12 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             if ShardingOption.FULL_SHARD in self.args.sharding:
                 self.model_wrapped.get_all_parameters(convert2cpu=True)
 
-        if self.args.should_save_model_state:
+        if not self.is_in_train and self.args.use_expert_parallel:
+            should_save_model_state = self.args.should_save_moe_model_state
+        else:
+            should_save_model_state = self.args.should_save_model_state
+
+        if should_save_model_state:
             unified_checkpoint_config_backup = self.args.unified_checkpoint_config
             # backup and remove unified_checkpoint_config for not trine stage
             if not self.is_in_train:
@@ -2245,6 +2300,10 @@ def _save_checkpoint(self, model, metrics=None):
                 os.makedirs(output_dir, exist_ok=True)
                 paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
 
+        if self.args.use_expert_parallel and self.args.data_parallel_rank > 0:
+            logger.info("Saving moe weights for data_parallel_rank > 0")
+            self._save_moe_weights(output_dir)
+
         # Maybe delete some older checkpoints.
         # For hybrid parallel training, the checkpoint files maybe on different node.
         need_to_rotate_checkpoints = False
@@ -2452,7 +2511,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")
 
         if not use_unified_checkpoint:
-            if self.args.data_parallel_rank == 0:
+            if self.args.use_expert_parallel or self.args.data_parallel_rank == 0:
                 optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
                 path = os.path.join(checkpoint, optimizer_name)
                 if os.path.isfile(path):
@@ -2476,7 +2535,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         # broadcast optimizer state in dp group
         if self.args.local_rank != -1:
             dist.barrier()
-        opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        if not self.args.use_expert_parallel:
+            opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+        else:
+            state_dict = self.model.state_dict()
+            opt_state_dict = broadcast_moe_optimizer(state_dict, opt_state_dict)
 
         if opt_state_dict is not None:
             # Load in optimizer and scheduler states
@@ -2939,6 +3002,8 @@ def prediction_step(
             if has_labels:
                 with self.autocast_smart_context_manager():
                     loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                if isinstance(loss, dict):
+                    loss = loss.pop("loss")
                 loss = loss.mean().detach()
 
                 if isinstance(outputs, dict):
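# ---------------------------------------------------------------------------
# Sketch, not part of the patch: with use_expert_parallel enabled, dp rank 0 keeps the
# usual save path in _save_checkpoint, while every dp rank > 0 additionally writes its
# expert-local weights and optimizer shard via _save_moe_weights. The helper name below
# is made up; it only restates the gating condition added above.
def writes_extra_moe_copy(use_expert_parallel, data_parallel_rank):
    return use_expert_parallel and data_parallel_rank > 0

assert [writes_extra_moe_copy(True, r) for r in range(3)] == [False, True, True]
assert [writes_extra_moe_copy(False, r) for r in range(3)] == [False, False, False]
# ---------------------------------------------------------------------------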
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 6dded0418619..038b727cf2f9 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -804,7 +804,7 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
-    use_moe: Optional[bool] = field(
+    use_expert_parallel: Optional[bool] = field(
         default=False,
         metadata={"help": "Use MoE training."},
     )
@@ -1154,7 +1154,7 @@ def is_segment_parallel_supported():
                 order = ["dp", "sharding", "pp", "sep", "mp"]
             else:
                 order = ["dp", "sharding", "pp", "mp"]
-            if self.use_moe:
+            if self.use_expert_parallel:
                 order = order[1:-1] + ["dp", "mp"]
 
             if is_segment_parallel_supported():
@@ -1649,12 +1649,12 @@ def optimizer_name_suffix(self):
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
             if self.sharding_parallel_degree > 1:
                 name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
-            if self.use_moe:
-                name.append(f"moe{self.data_parallel_rank:0>2d}")
+            if self.use_expert_parallel:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
-            if self.use_moe:
-                return f"moe{self.data_parallel_rank:0>2d}"
+            if self.use_expert_parallel:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None
 
     @property
@@ -1665,13 +1665,13 @@ def weight_name_suffix(self):
             name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
             if self.pipeline_parallel_degree > 1:
                 name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
-            if self.use_moe:
-                name.append(f"moe{self.data_parallel_rank:0>2d}")
+            if self.use_expert_parallel:
+                name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))
             return "_".join(name)
         else:
-            if self.use_moe:
-                return f"moe{self.data_parallel_rank:0>2d}"
+            if self.use_expert_parallel:
+                return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)
             return None
 
     def sharded_name_suffix(self, shard_id=None, pp_id=None):
@@ -1787,6 +1787,29 @@ def should_save_model_state(self):
         else:
             return self.process_index == 0
 
+    @property
+    def should_save_moe_model_state(self):
+        """
+        Whether or not the current process should write to disk, e.g., to save moe models and checkpoints.
+
+        For model state:
+            work for data parallel, tensor parallel, sharding
+        For optimizer state:
+            work for data parallel, tensor parallel
+            not work for sharding
+        """
+        if self.save_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if self.should_save_sharding_stage1_model:
+                return True
+            elif self.enable_auto_parallel:
+                return True
+            elif self.use_hybrid_parallel:
+                return self.sharding_parallel_rank == 0
+            else:
+                return self.process_index == 0
+
     @property
     def _no_sync_in_gradient_accumulation(self):
         """
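# ---------------------------------------------------------------------------
# Minimal illustration, not part of the patch: the device-order rewrite guarded by
# use_expert_parallel above moves "dp" from the outermost position to sit next to "mp"
# (presumably so the dp axis, over which experts are sharded, becomes an inner group).
# Plain list arithmetic, no Paddle required.
order = ["dp", "sharding", "pp", "sep", "mp"]
print(order[1:-1] + ["dp", "mp"])  # ['sharding', 'pp', 'sep', 'dp', 'mp']
# ---------------------------------------------------------------------------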
diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py
index 25f593f71e35..37162e2fed80 100644
--- a/paddlenlp/trainer/utils/helper.py
+++ b/paddlenlp/trainer/utils/helper.py
@@ -226,3 +226,36 @@ def broadcast_dp_optimizer(state_dict):
         state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)
 
     return state_dict
+
+
+def broadcast_moe_optimizer(state_dict, opt_state_dict):
+    no_sync_vname = []
+    for k, v in state_dict.items():
+        if getattr(v, "no_sync", False):
+            no_sync_vname.append(v.name)
+    new_opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
+    # 1. when updating opt_state_dict, skip broadcasting parameters whose `no_sync` attribute is True.
+    # 2. if the keys of opt_state_dict and new_opt_state_dict are exactly the same, there is no need to update.
+    # 3. if they differ, update opt_state_dict based on `no_sync_vname`.
+    if len(opt_state_dict.keys()) != len(new_opt_state_dict.keys()):
+        for op_k, op_v in new_opt_state_dict.items():
+            if op_k == "master_weights":
+                for k, v in new_opt_state_dict["master_weights"].items():
+                    no_sync = False
+                    for no_sync_v in no_sync_vname:
+                        if k.startswith(no_sync_v):
+                            no_sync = True
+                            break
+                    if not no_sync:
+                        opt_state_dict["master_weights"][k] = v
+            elif op_k == "LR_Scheduler":
+                pass
+            else:
+                no_sync = False
+                for no_sync_v in no_sync_vname:
+                    if op_k.startswith(no_sync_v):
+                        no_sync = True
+                        break
+                if not no_sync:
+                    opt_state_dict[op_k] = op_v
+    return opt_state_dict
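# ---------------------------------------------------------------------------
# Standalone sketch, not part of the patch: a simplified version of the filtering rule
# broadcast_moe_optimizer applies when merging the broadcast optimizer state into the
# local one. Keys that start with a `no_sync` parameter name keep the local (expert)
# value; everything else takes the value broadcast from dp rank 0. Names and values are
# made up, and the real function additionally handles the nested "master_weights" dict.
no_sync_vname = ["expert_0.w_0"]  # expert-local parameter names collected from state_dict
local = {"expert_0.w_0_moment1_0": "local", "linear_1.w_0_moment1_0": "local"}
broadcasted = {"expert_0.w_0_moment1_0": "rank0", "linear_1.w_0_moment1_0": "rank0", "LR_Scheduler": {}}

merged = dict(local)
for op_k, op_v in broadcasted.items():
    if op_k == "LR_Scheduler":
        continue
    if not any(op_k.startswith(name) for name in no_sync_vname):
        merged[op_k] = op_v
print(merged)  # {'expert_0.w_0_moment1_0': 'local', 'linear_1.w_0_moment1_0': 'rank0'}
# ---------------------------------------------------------------------------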