diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 0497cfba304e..9ecfb137d688 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -89,6 +89,7 @@
 from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
+    _load_state_dict_into_model,
     load_sharded_checkpoint,
     unwrap_model,
 )
@@ -149,6 +150,7 @@
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 from .utils.helper import (  # nested_truncate,
+    broadcast_dataset_rank0_model,
     broadcast_dp_optimizer,
     broadcast_moe_optimizer,
     distributed_concat,
@@ -1161,6 +1163,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                         self.state.best_model_checkpoint,
                         safe_serialization=True,
                     )
+                    if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
+                        state_dict = broadcast_dataset_rank0_model(self.model.state_dict())
+                        if self.args.dataset_rank > 0:
+                            _load_state_dict_into_model(self.model, state_dict, "")
                 else:
                     weight_name = PADDLE_WEIGHTS_NAME
                     best_model_path = os.path.join(
@@ -1203,6 +1209,10 @@ def _load_best_model_from_peft_checkpoint(self):
                 self.state.best_model_checkpoint,
                 safe_serialization=True,
             )
+            if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
+                state_dict = broadcast_dataset_rank0_model(self.model.get_trainable_state_dict())
+                if self.args.dataset_rank > 0:
+                    _load_state_dict_into_model(self.model, state_dict, "")
             return
 
         convert_tp = False
diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py
index b9fc7b0002e0..7f0e87d0f9e2 100644
--- a/paddlenlp/trainer/utils/helper.py
+++ b/paddlenlp/trainer/utils/helper.py
@@ -309,3 +309,21 @@ def _broadcast_moe_optimizer_state(state_dict):
         state_dict = base_state_dict
         del base_state_dict
     return state_dict
+
+
+def broadcast_dataset_rank0_model(state_dict):
+    if paddle.distributed.get_world_size() <= 1:
+        return state_dict
+
+    logger.info("Start broadcast model in sharding group or data parallel group.")
+    hcg = fleet.get_hybrid_communicate_group()
+    sharding_group = hcg.get_sharding_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
+
+    if sharding_group.nranks > 1:
+        for k in state_dict.keys():
+            dist.broadcast(state_dict[k], src=hcg.get_sharding_parallel_group_src_rank(), group=sharding_group)
+    if dp_group.nranks > 1:
+        for k in state_dict.keys():
+            dist.broadcast(state_dict[k], src=hcg.get_data_parallel_group_src_rank(), group=dp_group)
+    return state_dict
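For reference, a minimal standalone sketch of the per-tensor broadcast pattern the new helper relies on, written against a plain data-parallel run initialized with dist.init_parallel_env() instead of the hybrid fleet setup; the names sync_state_dict_from_rank0, build_model, and the checkpoint path are illustrative placeholders, not part of this change:

# Illustrative sketch only: broadcast every tensor of a state dict from global
# rank 0, mirroring what broadcast_dataset_rank0_model does per sharding/DP group.
import paddle
import paddle.distributed as dist


def sync_state_dict_from_rank0(state_dict):
    """Broadcast each tensor in ``state_dict`` from global rank 0 to all ranks."""
    if dist.get_world_size() <= 1:
        return state_dict
    for name in state_dict.keys():
        dist.broadcast(state_dict[name], src=0)  # in-place collective broadcast
    return state_dict


# Hypothetical usage: only rank 0 loads the checkpoint, the other ranks receive it.
# dist.init_parallel_env()
# model = build_model()
# if dist.get_rank() == 0:
#     model.set_state_dict(paddle.load("best_model.pdparams"))
# sync_state_dict_from_rank0(model.state_dict())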