Commit

fix load best
DesmonDay committed Aug 14, 2024
1 parent 75c7636 commit 2097916
Showing 2 changed files with 28 additions and 0 deletions.
10 changes: 10 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -89,6 +89,7 @@
 from ..transformers.model_utils import (
     PretrainedModel,
     _add_variant,
+    _load_state_dict_into_model,
     load_sharded_checkpoint,
     unwrap_model,
 )
@@ -149,6 +150,7 @@
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 from .utils.helper import (  # nested_truncate,
+    broadcast_dataset_rank0_model,
     broadcast_dp_optimizer,
     broadcast_moe_optimizer,
     distributed_concat,
@@ -1161,6 +1163,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                         self.state.best_model_checkpoint,
                         safe_serialization=True,
                     )
+                    if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
+                        state_dict = broadcast_dataset_rank0_model(self.model.state_dict())
+                        if self.args.dataset_rank > 0:
+                            _load_state_dict_into_model(self.model, state_dict, "")
                 else:
                     weight_name = PADDLE_WEIGHTS_NAME
                     best_model_path = os.path.join(
@@ -1203,6 +1209,10 @@ def _load_best_model_from_peft_checkpoint(self):
                 self.state.best_model_checkpoint,
                 safe_serialization=True,
             )
+            if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
+                state_dict = broadcast_dataset_rank0_model(self.model.get_trainable_state_dict())
+                if self.args.dataset_rank > 0:
+                    _load_state_dict_into_model(self.model, state_dict, "")
             return
 
         convert_tp = False
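Both call sites in trainer.py follow the same pattern: after the best unified checkpoint is loaded, the state dict is broadcast from the source rank of the sharding and data-parallel groups so that every dataset rank ends up with the same weights, and ranks with dataset_rank > 0 reload the broadcast tensors into their local model. A minimal sketch of that pattern, assuming a hybrid-parallel run is already set up (sync_best_model is a hypothetical wrapper written for illustration, not part of this commit):

from paddlenlp.trainer.utils.helper import broadcast_dataset_rank0_model
from paddlenlp.transformers.model_utils import _load_state_dict_into_model


def sync_best_model(model, args):
    # Only needed when the model is replicated across sharding or data-parallel groups.
    if args.sharding_parallel_degree > 1 or args.data_parallel_degree > 1:
        # Broadcast every tensor in the state dict from the source rank of each group.
        state_dict = broadcast_dataset_rank0_model(model.state_dict())
        # Ranks other than dataset rank 0 overwrite their local weights with the broadcast copy.
        if args.dataset_rank > 0:
            _load_state_dict_into_model(model, state_dict, "")
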
18 changes: 18 additions & 0 deletions paddlenlp/trainer/utils/helper.py
@@ -309,3 +309,21 @@ def _broadcast_moe_optimizer_state(state_dict):
         state_dict = base_state_dict
         del base_state_dict
     return state_dict
+
+
+def broadcast_dataset_rank0_model(state_dict):
+    if paddle.distributed.get_world_size() <= 1:
+        return state_dict
+
+    logger.info("Start broadcast model in sharding group or data parallel group.")
+    hcg = fleet.get_hybrid_communicate_group()
+    sharding_group = hcg.get_sharding_parallel_group()
+    dp_group = hcg.get_data_parallel_group()
+
+    if sharding_group.nranks > 1:
+        for k in state_dict.keys():
+            dist.broadcast(state_dict[k], src=hcg.get_sharding_parallel_group_src_rank(), group=sharding_group)
+    if dp_group.nranks > 1:
+        for k in state_dict.keys():
+            dist.broadcast(state_dict[k], src=hcg.get_data_parallel_group_src_rank(), group=dp_group)
+    return state_dict
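For context, broadcast_dataset_rank0_model only makes sense once fleet's hybrid communicate group has been initialized, since that is where the sharding and data-parallel groups it iterates over come from. A minimal, hypothetical setup sketch (the degree values are placeholders, and the exact hybrid_configs keys accepted may depend on the Paddle version):

from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 2,        # data-parallel replicas
    "mp_degree": 1,        # tensor (model) parallel degree
    "pp_degree": 1,        # pipeline parallel degree
    "sharding_degree": 2,  # size of each sharding group
}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
# These are the two groups broadcast_dataset_rank0_model broadcasts over.
sharding_group = hcg.get_sharding_parallel_group()
dp_group = hcg.get_data_parallel_group()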