Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Nov 12, 2024
1 parent 10a62c7 commit 1e9f456
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 19 deletions.
12 changes: 11 additions & 1 deletion paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,13 @@ def get_expected_keys(inputs, keys):
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

if (
hasattr(self.args, "enable_sharding_comm_overlap")
and self.args.enable_sharding_comm_overlap
and self.args.unified_checkpoint
):
model.register_sharding_comm_overlap_hook(self.optimizer)

# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
Expand Down Expand Up @@ -2840,8 +2847,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
else:
opt_state_dict = None
else:
model = self.model
if hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap:
model = self.model_wrapped
opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
model=self.model,
model=model,
optimizer=self.optimizer,
resume_from_checkpoint=checkpoint,
)
Expand Down
17 changes: 4 additions & 13 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,29 +1155,20 @@ def split_parallel_config(parallel_config):
or "enable_dp_comm_overlap" in pipeline_parallel_config
)
enable_dp_comm_overlap = using_comm_overlap and self.data_parallel_degree > 1
enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1
self.enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1
assert not (
enable_dp_comm_overlap and enable_sharding_comm_overlap
enable_dp_comm_overlap and self.enable_sharding_comm_overlap
), "dp_comm_overlap and sharding_comm_overlap cannot be enabled at the same time"

if enable_sharding_comm_overlap and not self.amp_master_grad:
if self.enable_sharding_comm_overlap and not self.amp_master_grad:
raise ValueError(
"If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
)
if (
enable_sharding_comm_overlap
and self.unified_checkpoint
and "split_param" in split_parallel_config(self.sharding_parallel_config)
):
logger.warning(
"Currently unified checkpoint do not support using `sharding_comm_overlap` and `split_param` at the same time, delete `sharding_comm_overlap`."
)
enable_sharding_comm_overlap = False

dygraph_pp_configs = {
"delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
"dp_comm_overlap": enable_dp_comm_overlap,
"sharding_comm_overlap": enable_sharding_comm_overlap,
"sharding_comm_overlap": self.enable_sharding_comm_overlap,
"enable_timer": "enable_timer" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads,
"overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/trainer/unified_checkpoint/load_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _remove_unused_keys(
def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint)
returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)
return returned_optim_state_dict

# init and get optimizer LR_Scheduler
Expand Down
14 changes: 11 additions & 3 deletions paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

import gc
import os
from itertools import chain

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from tqdm.auto import tqdm

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.transformers.model_utils import load_state_dict
from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
from paddlenlp.utils.env import (
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_OPTIMIZER_INDEX_NAME,
Expand Down Expand Up @@ -97,6 +98,7 @@ def gather_splited_param_for_optimizer(optimizer):
global_rank = dist.get_rank()
param_slice_info = {}
param_shape_info = {}

for buffer in optimizer._inner_opt._comm_buffer_list:
for key in buffer._sharding_param_grad_view.keys():
param_slice_info[key] = (
Expand Down Expand Up @@ -153,7 +155,7 @@ def gather_splited_param_for_optimizer(optimizer):
return optim_state_dict, master_weights


def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint):
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
returned_optim_state_dict = nested_copy(optimizer.state_dict())

index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
Expand All @@ -177,7 +179,13 @@ def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint)
expected_keys = []
param_slice_info = {}
param_shape_info = {}
for buffer in optimizer._inner_opt._comm_buffer_list:

comm_buffer_list = optimizer._inner_opt._comm_buffer_list
if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
model = unwrap_model(model)

for buffer in comm_buffer_list:
for key in buffer._sharding_param_grad_view.keys():
begin = buffer._sharding_param_grad_view[key]._param_begin
end = buffer._sharding_param_grad_view[key]._param_end
Expand Down
8 changes: 7 additions & 1 deletion paddlenlp/trainer/unified_checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption
from paddlenlp.trainer.utils.helper import distributed_isfile
from paddlenlp.transformers.model_utils import PretrainedModel, get_parameter_dtype
from paddlenlp.transformers.model_utils import (
PretrainedModel,
get_parameter_dtype,
unwrap_model,
)
from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.env import (
Expand Down Expand Up @@ -193,6 +197,8 @@ def get_expected_state_dict(model_to_save):
"""
Get trainable state_dict of model_to_save.
"""
model_to_save = unwrap_model(model_to_save)

if isinstance(model_to_save, PretrainedModel):
state_dict = model_to_save.state_dict()
if (
Expand Down

0 comments on commit 1e9f456

Please sign in to comment.