diff --git a/docs/trainer.md b/docs/trainer.md index a1dde0af4f94..736226dfb270 100644 --- a/docs/trainer.md +++ b/docs/trainer.md @@ -719,4 +719,8 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并 Whether use flatten_param_grads method in optimizer, only used on NPU devices.(default:False) + --use_expert_parallel + Whether to enable MoE (Mixture of Experts) expert parallel training. + (default: False) + ``` diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 01b902478622..eacb4a6780a9 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -143,6 +143,7 @@ from .utils import reshard as reshard_util from .utils.helper import ( # nested_truncate, broadcast_dp_optimizer, + broadcast_moe_optimizer, distributed_concat, distributed_file, distributed_isfile, @@ -565,7 +566,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None): ) self.model.set_state_dict(state_dict) else: - if resume_from_checkpoint is not None and self.args.dataset_rank == 0: + if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel): weights_file = os.path.join( resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix) @@ -930,22 +931,17 @@ def _inner_training_loop( self.control = self.callback_handler.on_step_begin(args, self.state, self.control) self.timers and self.timers("forward-backward").start() - dp_enabled = ( - self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1 - ) - forbidden_no_sync = False # stage2 and stage3 should not no_sync, because the is no DDP wrapper and no_sync API # hybrid_parallel (tp or pp or sharding stage 1) should not no_sync - if self.args.use_hybrid_parallel: - forbidden_no_sync = True - - availiable_no_sync = dp_enabled and not forbidden_no_sync - + availiable_no_sync = hasattr(model, "no_sync") is_no_sync = ( - ((step_control + 1) % args.gradient_accumulation_steps != 0) - and availiable_no_sync - and args._no_sync_in_gradient_accumulation - ) or (args.recompute and availiable_no_sync) + ( + ((step_control + 1) % args.gradient_accumulation_steps != 0) + and args._no_sync_in_gradient_accumulation + ) + or args.recompute + or args.use_expert_parallel + ) and availiable_no_sync # sharding # stage1. the same as ddp # stage2. manualy collect gradient on dp group @@ -965,6 +961,14 @@ def _inner_training_loop( tr_loss += tr_loss_step + def fused_allreduce_gradients_no_sync(paramlist, hcg): + paramlist = list(paramlist) + nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)] + moelist = [p for p in paramlist if getattr(p, "no_sync", False)] + if moelist and not self.args.use_expert_parallel: + logger.warning("found `no sync` param when `use_expert_parallel=False`") + fused_allreduce_gradients(nonmoe_list, hcg) + if (step_control + 1) % args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps steps_in_epoch <= args.gradient_accumulation_steps @@ -983,12 +987,12 @@ def _inner_training_loop( # Case 1: Use recompute and dp / sharding stage1, # manualy collect gradient for dp. - if args.recompute and availiable_no_sync: - fused_allreduce_gradients(list(model.parameters()), None) + if (args.recompute or args.use_expert_parallel) and availiable_no_sync: + fused_allreduce_gradients_no_sync(list(model.parameters()), None) # Case 2: hack dp with master_grad - if dp_master_grad and not (args.recompute and availiable_no_sync): - fused_allreduce_gradients(list(model.parameters()), None) + elif dp_master_grad: + fused_allreduce_gradients_no_sync(list(model.parameters()), None) # Pipeline parallel mode, handle gradient reduce here to overlap pipeline_parallel_config = ( @@ -1007,8 +1011,7 @@ def _inner_training_loop( self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg) if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False): - fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg) - + fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg) self.timers and self.timers("all-reduce").stop() self.timers and self.timers("optimizer-step").start() @@ -1028,6 +1031,8 @@ def _inner_training_loop( ) optimizer_was_run = True if self.do_grad_scaling: + if args.pipeline_parallel_degree > 1: + assert not self.args.use_expert_parallel, "pipeline moe not work under fp16" scale_before = paddle.assign(self.scaler._scale) self.scaler.step(self.optimizer) self.scaler.update() @@ -2042,7 +2047,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, model.train() inputs = self._prepare_inputs(inputs) - with self.autocast_smart_context_manager(): loss = self.compute_loss(model, inputs) @@ -2053,7 +2057,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, self.scaler.scale(loss).backward() else: loss.backward() - return loss.detach() def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor: @@ -2143,6 +2146,26 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op # For ckpt integrity paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done")) + def _filter_moe_no_sync_optimizer_params(self): + """ + filter optimizer params which should not sync + """ + state_dict = self.model.state_dict() + optimzier_state_dict = self.optimizer.state_dict() + filter_optimzier_state_dict = OrderedDict() + param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else [] + filter_optimzier_state_dict["master_weights"] = OrderedDict() + for k, v in state_dict.items(): + if getattr(v, "no_sync", False): + if v.name in param_names_in_master_weights: + filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][ + v.name + ] + for op_k, op_v in optimzier_state_dict.items(): + if op_k.startswith(v.name): + filter_optimzier_state_dict[op_k] = op_v + return filter_optimzier_state_dict + def _save_checkpoint(self, model, metrics=None): # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" self.runtime_timer.start("checkpoint saving time") @@ -2165,7 +2188,7 @@ def _save_checkpoint(self, model, metrics=None): optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix) if self.args.use_hybrid_parallel: - if self.dp_group.rank <= 0: + if self.dp_group.rank <= 0 or self.args.use_expert_parallel: os.makedirs(output_dir, exist_ok=True) logger.info("Saving optimizer files.") if self.args.unified_checkpoint: @@ -2177,12 +2200,18 @@ def _save_checkpoint(self, model, metrics=None): safe_serialization=True, ) else: - self._save_ckpt_func( - self.optimizer.state_dict(), - os.path.join(output_dir, optimizer_name), - ) + if self.dp_group.rank > 0: # this should only work for MoE saving + self._save_ckpt_func( + self._filter_moe_no_sync_optimizer_params(), + os.path.join(output_dir, optimizer_name), + ) + else: + self._save_ckpt_func( + self.optimizer.state_dict(), + os.path.join(output_dir, optimizer_name), + ) - if self.args.should_save: + if self.args.should_save or self.args.use_expert_parallel: if not self.args.use_hybrid_parallel: logger.info("Saving optimizer files.") if self.args.unified_checkpoint: @@ -2194,7 +2223,12 @@ def _save_checkpoint(self, model, metrics=None): safe_serialization=True, ) else: - self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + if self.dp_group.rank > 0: + self._save_ckpt_func( + self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME) + ) + else: + self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) # FIXME: maybe only save one copy paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -2452,7 +2486,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint") if not use_unified_checkpoint: - if self.args.data_parallel_rank == 0: + if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel: optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix) path = os.path.join(checkpoint, optimizer_name) if os.path.isfile(path): @@ -2476,7 +2510,11 @@ def _load_optimizer_and_scheduler(self, checkpoint): # broadcast optimizer state in dp group if self.args.local_rank != -1: dist.barrier() - opt_state_dict = broadcast_dp_optimizer(opt_state_dict) + if self.args.use_expert_parallel: + opt_state_dict = broadcast_moe_optimizer(opt_state_dict) + else: + if not self.args.should_load_sharding_stage1_model: + opt_state_dict = broadcast_dp_optimizer(opt_state_dict) if opt_state_dict is not None: # Load in optimizer and scheduler states diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index d30876a0f2d2..51b9ac4027dd 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -803,6 +803,10 @@ class TrainingArguments: default=False, metadata={"help": "whether to run distributed training in auto parallel mode"}, ) + use_expert_parallel: Optional[bool] = field( + default=False, + metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"}, + ) def __post_init__(self): env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1)) @@ -1149,6 +1153,8 @@ def is_segment_parallel_supported(): order = ["dp", "sharding", "pp", "sep", "mp"] else: order = ["dp", "sharding", "pp", "mp"] + if self.use_expert_parallel: + order = order[1:-1] + ["dp", "mp"] if is_segment_parallel_supported(): hybrid_configs = { @@ -1640,8 +1646,12 @@ def optimizer_name_suffix(self): name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree)) if self.sharding_parallel_degree > 1: name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree)) + if self.use_expert_parallel: + name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)) return "_".join(name) else: + if self.use_expert_parallel: + return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree) return None @property @@ -1652,12 +1662,16 @@ def weight_name_suffix(self): name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree)) if self.pipeline_parallel_degree > 1: name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree)) + if self.use_expert_parallel: + name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)) return "_".join(name) else: + if self.use_expert_parallel: + return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree) return None - def sharded_name_suffix(self, shard_id=None, pp_id=None): + def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None): if self.use_hybrid_parallel: name = [] if self.tensor_parallel_degree > 1: @@ -1672,8 +1686,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None): shard_id = self.sharding_parallel_rank assert isinstance(shard_id, int) name.append(self._format_name("shard", shard_id, self.sharding_parallel_degree)) + if self.use_expert_parallel: + if moe_id is None: + moe_id = self.data_parallel_rank + assert isinstance(moe_id, int) + name.append(self._format_name("moe", moe_id, self.data_parallel_degree)) return "_".join(name) else: + if self.use_expert_parallel: + if moe_id is None: + moe_id = self.data_parallel_rank + return self._format_name("moe", moe_id, self.data_parallel_degree) return None @property @@ -1766,9 +1789,9 @@ def should_save_model_state(self): return True elif self.use_hybrid_parallel: # save on dataset rank 0 - return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0 + return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel) else: - return self.process_index == 0 + return self.process_index == 0 or self.use_expert_parallel @property def _no_sync_in_gradient_accumulation(self): diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py index 25f593f71e35..12aec88bc41b 100644 --- a/paddlenlp/trainer/utils/helper.py +++ b/paddlenlp/trainer/utils/helper.py @@ -226,3 +226,59 @@ def broadcast_dp_optimizer(state_dict): state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group) return state_dict + + +def broadcast_moe_optimizer(state_dict): + + try: + hcg = fleet.get_hybrid_communicate_group() + dp_group = hcg.get_data_parallel_group() + src_rank = hcg.get_data_parallel_group_src_rank() + data_parallel_rank = hcg.get_data_parallel_rank() + # Don't broadcast optimizer for dp rank is 1. + if dp_group.nranks <= 1: + return state_dict + except: + dp_group = None + src_rank = 0 + data_parallel_rank = 0 + + def _broadcast_moe_optimizer_state(state_dict): + # boardcast_keys + base_state_dict = {"master_weights": {}} + buf = [ + {i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]}, + {i: j.shape for i, j in state_dict["master_weights"].items()}, + {"LR_Scheduler": state_dict.get("LR_Scheduler", {})}, + ] + + dist.broadcast_object_list(buf, src=src_rank, group=dp_group) + # logger.info(f"moe-optimizer-gather-keys{buf}") + for k, s in buf[0].items(): + v = state_dict.get(k, paddle.zeros(s, "float32")).cuda() + v.name = k + # k = k.replace("_fp32_master_0", "") + dist.broadcast(v, src=src_rank, group=dp_group) + logger.info(f"broadcast moe optimizer {k} from {src_rank}") + base_state_dict[k] = v.cpu() + for k, s in buf[1].items(): + v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda() + v.name = k + dist.broadcast(v, src=src_rank, group=dp_group) + logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}") + base_state_dict["master_weights"][k] = v.cpu() + base_state_dict.update(buf[2]) + return base_state_dict + + base_state_dict = _broadcast_moe_optimizer_state(state_dict) + if data_parallel_rank > 0: + master_weight = state_dict.pop("master_weights", {}) + base_state_dict.update(state_dict) + if master_weight: + if "master_weights" in base_state_dict: + base_state_dict["master_weights"].update(master_weight) + else: + base_state_dict["master_weights"] = master_weight + state_dict = base_state_dict + del base_state_dict + return state_dict diff --git a/paddlenlp/trainer/utils/reshard/common.py b/paddlenlp/trainer/utils/reshard/common.py index cc834862e299..66e3c3569916 100644 --- a/paddlenlp/trainer/utils/reshard/common.py +++ b/paddlenlp/trainer/utils/reshard/common.py @@ -266,6 +266,16 @@ def _opt_name_to_tname(tensor_names, opt_names): all_names.extend(opt_names) all_names.sort() pre_t_name = "" + suffix = [ + "_fp32_master_0_beta1_pow_acc_0", + "_fp32_master_0_beta2_pow_acc_0", + "_fp32_master_0_moment1_0", + "_fp32_master_0_moment2_0", + "_beta1_pow_acc_0", + "_beta2_pow_acc_0", + "_moment1_0", + "_moment2_0", + ] opt_to_t = {} for n in all_names: if n in tensor_names: @@ -274,6 +284,16 @@ def _opt_name_to_tname(tensor_names, opt_names): else: assert pre_t_name opt_to_t[n] = pre_t_name + + for t in opt_names: + _find = False + for s in suffix: + if t.endswith(s): + logger.info(f"{t}-{t[:-len(s)]}--{t[:-len(s)] in tensor_names}") + opt_to_t[t] = t[: -len(s)] + _find = True + break + assert _find return opt_to_t if structure_name_mapping is not None: @@ -291,7 +311,7 @@ def _opt_name_to_tname(tensor_names, opt_names): (self._model_weights, model_weights_tmp) = (model_weights_tmp, self._model_weights) for k in list(model_weights_tmp.keys()): t_name = structure_name_mapping[k] - self._model_weights[(k, t_name)] = model_weights_tmp[k].cpu() + self._model_weights[(k, t_name)] = paddle.to_tensor(model_weights_tmp[k]).cpu() del model_weights_tmp[k] # opt diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py index 56f4c426ce0a..4fe55d175005 100644 --- a/paddlenlp/trainer/utils/sharding_io.py +++ b/paddlenlp/trainer/utils/sharding_io.py @@ -67,11 +67,14 @@ def filter_sharded_params(state_dict, optimizer, sharding_group): if reshard_util.get_sharding_strategy(optimizer) == reshard_util.SHARDING_STRATEGY_V1: optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizer) for (k, v) in state_dict.items(): - assert v.name in optimizer._param2rank - sharded_rank = optimizer._param2rank[v.name] - if sharded_rank != sharding_rank: - continue - filtered_state_dict[k] = v + if v.name in optimizer._param2rank: + sharded_rank = optimizer._param2rank[v.name] + if sharded_rank != sharding_rank: + continue + filtered_state_dict[k] = v + else: + if sharding_rank == 0: + filtered_state_dict[k] = v else: optimizer = unwrap_optimizer(optimizer, DygraphShardingOptimizerV2) parameters = optimizer._parameter_list @@ -352,7 +355,7 @@ def manipulate_state_dict_and_config(self, model_to_save, merge_tensor_parallel= ) logger.info( "param_names_in_master_weights len:{}, bf16 state_dict len:{}, :{}".format( - len(param_names_in_master_weights), len(state_dict), state_dict + len(param_names_in_master_weights), len(state_dict), state_dict.keys() ) ) return state_dict, config_to_save, weight_name_suffix @@ -444,12 +447,17 @@ def filter_func(name): master_weights = reshard_util.all_gather_state_dict(master_weights, filter_func, self.sharding_group) model_state_dict = self.model.state_dict() + logger.info(f"state-dict-keys: {state_dict.keys()}, nums: {len(state_dict.keys())}") logger.info("before recover, model_state_dict number: {}".format(len(model_state_dict))) for key, param in model_state_dict.items(): if param.name in master_weights: assert param.shape == master_weights[param.name].shape - paddle.assign(master_weights[param.name].cuda(), model_state_dict[key]) - + paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key]) + elif key in state_dict: + logger.info(f"key: {key} is in state_dict, but not in master_weights") + paddle.assign(state_dict[key], model_state_dict[key]) + else: + logger.info(f"key: {key} is not in state_dict and master_weights") logger.info("after recover, casted model_state_dict number: {}".format(len(model_state_dict))) state_dict.update(model_state_dict) return state_dict