
[Unified Checkpoint] Support sharding_comm_overlap #9392

Merged: 3 commits, merged on Nov 14, 2024
18 changes: 17 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -141,6 +141,7 @@
set_seed,
should_skip_data,
speed_metrics,
split_parallel_config,
)
from .training_args import TrainingArguments
from .unified_checkpoint import UnifiedCheckpointHandler
@@ -2053,6 +2054,14 @@
self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
self.optimizer = fleet.distributed_optimizer(self.optimizer)

if (
    hasattr(self.args, "enable_sharding_comm_overlap")
    and self.args.enable_sharding_comm_overlap
    and self.args.unified_checkpoint
    and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
    model.register_sharding_comm_overlap_hook(self.optimizer)

Contributor Author:
@ZHUI Please take a look at whether this switch should be enabled specifically for the unified checkpoint (uc).

Collaborator:
Yes, it should, so that the impact on other strategies is kept to a minimum.

Contributor Author:
Added a condition so that it is only enabled when split_param is turned on.


# No pipeline mode, sharding only
if not in_pipeline_parallel_mode and in_sharding_parallel_mode:
# Sharded DDP!
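For reference, a minimal self-contained sketch (not part of this PR) of the flag combination the new guard above checks for. `SimpleNamespace` is only a stand-in for `TrainingArguments`, the concrete values are placeholders, and the import of `split_parallel_config` assumes a PaddleNLP version that already contains this PR:

```python
from types import SimpleNamespace

from paddlenlp.trainer.trainer_utils import split_parallel_config

# Stand-in for TrainingArguments; real runs set these via the launch arguments.
args = SimpleNamespace(
    enable_sharding_comm_overlap=True,   # derived from pipeline_parallel_config
    unified_checkpoint=True,             # unified checkpoint enabled
    sharding_parallel_config="split_param",
)

# Mirrors the guard added above: the hook is registered only when all three hold.
register_hook = (
    getattr(args, "enable_sharding_comm_overlap", False)
    and args.unified_checkpoint
    and "split_param" in split_parallel_config(args.sharding_parallel_config)
)
print(register_hook)  # True for this combination
```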
@@ -2840,8 +2849,15 @@
else:
opt_state_dict = None
else:
model = self.model
if (
    hasattr(self.args, "enable_sharding_comm_overlap")
    and self.args.enable_sharding_comm_overlap
    and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
):
    model = self.model_wrapped

opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
model=self.model,
model=model,
Collaborator:
Hmm, could the state_dict names end up wrapped in yet another layer, e.g. model.model.embedding?

Contributor Author:
That case has been handled.

optimizer=self.optimizer,
resume_from_checkpoint=checkpoint,
)
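Regarding the prefix concern above, here is a hedged sketch of the intent, not code from the PR: the wrapped model (`self.model_wrapped`) is passed so that its communication buffers are reachable, while `get_expected_state_dict` (see utils.py below) calls `unwrap_model` so that state_dict keys keep their original names rather than gaining an extra wrapper prefix such as `model.model.embedding`:

```python
from paddlenlp.transformers.model_utils import unwrap_model


def expected_keys(model):
    # Strip distributed wrappers (pipeline / sharding) before reading the
    # state_dict, so keys stay structured like "embeddings.weight" instead of
    # picking up an extra wrapper prefix.
    return set(unwrap_model(model).state_dict().keys())

# Illustrative sanity check (model_wrapped and raw_model are hypothetical names):
# assert expected_keys(model_wrapped) == expected_keys(raw_model)
```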
8 changes: 8 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -1126,3 +1126,11 @@
skip_flag = True
break
return skip_flag


def split_parallel_config(parallel_config):
    if "," in parallel_config:
        parallel_config = set(parallel_config.split(","))
    else:
        parallel_config = set(parallel_config.split(" "))
    return parallel_config

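A quick illustration (not from the PR) of the helper's behavior on its two accepted formats; the second option name is just an example token:

```python
from paddlenlp.trainer.trainer_utils import split_parallel_config

# Comma-separated form.
print(split_parallel_config("split_param,enable_stage1_overlap"))
# -> {'split_param', 'enable_stage1_overlap'}

# Space-separated form gives the same result.
print(split_parallel_config("split_param enable_stage1_overlap"))
# -> {'split_param', 'enable_stage1_overlap'}

# Membership tests are the intended use, as in the trainer and training_args code.
print("split_param" in split_parallel_config("split_param"))  # True
```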
25 changes: 5 additions & 20 deletions paddlenlp/trainer/training_args.py
@@ -37,6 +37,7 @@
OptimizerNames,
SchedulerType,
ShardingOption,
split_parallel_config,
)

try:
@@ -1096,13 +1097,6 @@
logger.warning("set amp_master_grad to false since amp is disabled.")
self.amp_master_grad = False

def split_parallel_config(parallel_config):
    if "," in parallel_config:
        parallel_config = set(parallel_config.split(","))
    else:
        parallel_config = set(parallel_config.split(" "))
    return parallel_config

# use_hybrid_parallel
if self.use_hybrid_parallel:

@@ -1155,29 +1149,20 @@
or "enable_dp_comm_overlap" in pipeline_parallel_config
)
enable_dp_comm_overlap = using_comm_overlap and self.data_parallel_degree > 1
enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1
self.enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1

assert not (
enable_dp_comm_overlap and enable_sharding_comm_overlap
enable_dp_comm_overlap and self.enable_sharding_comm_overlap
), "dp_comm_overlap and sharding_comm_overlap cannot be enabled at the same time"

if enable_sharding_comm_overlap and not self.amp_master_grad:
if self.enable_sharding_comm_overlap and not self.amp_master_grad:

raise ValueError(
"If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
)
if (
enable_sharding_comm_overlap
and self.unified_checkpoint
and "split_param" in split_parallel_config(self.sharding_parallel_config)
):
logger.warning(
"Currently unified checkpoint do not support using `sharding_comm_overlap` and `split_param` at the same time, delete `sharding_comm_overlap`."
)
enable_sharding_comm_overlap = False

dygraph_pp_configs = {
"delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
"dp_comm_overlap": enable_dp_comm_overlap,
"sharding_comm_overlap": enable_sharding_comm_overlap,
"sharding_comm_overlap": self.enable_sharding_comm_overlap,
"enable_timer": "enable_timer" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads,
"overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
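A note on the change above: `enable_sharding_comm_overlap` was previously a local variable inside `__post_init__`; storing it on `self` lets later code, such as the trainer, consult it. A hedged sketch of the downstream pattern, written defensively because older argument objects may not carry the attribute:

```python
# Not code from the PR; a minimal sketch of the downstream check.
if getattr(args, "enable_sharding_comm_overlap", False):
    # ... take the sharding-comm-overlap-specific paths ...
    pass
```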
3 changes: 2 additions & 1 deletion paddlenlp/trainer/unified_checkpoint/check_completion.py
@@ -150,7 +150,8 @@
sharding_group = hcg.get_sharding_parallel_group()
sharding_rank = sharding_group.rank
dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()}
model_state_dict = get_expected_state_dict(model)
struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}


if is_sharding_split_param_mode(args):
# We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume.
2 changes: 1 addition & 1 deletion paddlenlp/trainer/unified_checkpoint/load_local.py
@@ -150,7 +150,7 @@
def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False):
# Special process with split param.
if is_sharding_split_param_mode(args):
returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint)
returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint)

return returned_optim_state_dict

# init and get optimizer LR_Scheduler
14 changes: 11 additions & 3 deletions paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -15,14 +15,15 @@

import gc
import os
from itertools import chain

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from tqdm.auto import tqdm

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.transformers.model_utils import load_state_dict
from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
from paddlenlp.utils.env import (
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_OPTIMIZER_INDEX_NAME,
@@ -97,6 +98,7 @@
global_rank = dist.get_rank()
param_slice_info = {}
param_shape_info = {}

for buffer in optimizer._inner_opt._comm_buffer_list:
for key in buffer._sharding_param_grad_view.keys():
param_slice_info[key] = (
@@ -153,7 +155,7 @@
return optim_state_dict, master_weights


def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint):
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
returned_optim_state_dict = nested_copy(optimizer.state_dict())

index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME
@@ -177,7 +179,13 @@
expected_keys = []
param_slice_info = {}
param_shape_info = {}
for buffer in optimizer._inner_opt._comm_buffer_list:

comm_buffer_list = optimizer._inner_opt._comm_buffer_list
if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
model = unwrap_model(model)


for buffer in comm_buffer_list:

for key in buffer._sharding_param_grad_view.keys():
begin = buffer._sharding_param_grad_view[key]._param_begin
end = buffer._sharding_param_grad_view[key]._param_end
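When sharding comm overlap is enabled, the comm buffers come from `model._chunk_2_comm_buffers` rather than the optimizer; a small self-contained illustration (placeholder data, not the real buffer objects) of the `chain(*dict.values())` flattening used above:

```python
from itertools import chain

# Placeholder stand-in for model._chunk_2_comm_buffers: chunk id -> list of buffers.
chunk_2_comm_buffers = {0: ["buffer_a", "buffer_b"], 1: ["buffer_c"]}

comm_buffer_list = list(chain(*chunk_2_comm_buffers.values()))
print(comm_buffer_list)  # ['buffer_a', 'buffer_b', 'buffer_c']

# Each real buffer then exposes _sharding_param_grad_view, from which the
# per-parameter begin/end offsets are read, as in the loop above.
```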
10 changes: 8 additions & 2 deletions paddlenlp/trainer/unified_checkpoint/utils.py
@@ -24,7 +24,11 @@
from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption
from paddlenlp.trainer.utils.helper import distributed_isfile
from paddlenlp.transformers.model_utils import PretrainedModel, get_parameter_dtype
from paddlenlp.transformers.model_utils import (
PretrainedModel,
get_parameter_dtype,
unwrap_model,
)
from paddlenlp.transformers.utils import dtype_byte_size
from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.env import (
@@ -193,6 +197,8 @@
"""
Get trainable state_dict of model_to_save.
"""
model_to_save = unwrap_model(model_to_save)


if isinstance(model_to_save, PretrainedModel):
state_dict = model_to_save.state_dict()
if (
@@ -221,7 +227,7 @@
params2rank = optimizer._param2rank

model_state_dict = get_expected_state_dict(model)
struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()}
struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}


expected_keys = []
for key in list(sharded_metadata["all_optimizer_keys"]):