Cherry-pick some PRs from incubate/paddlenlp-fleety (#9253)
* support pp-sharding reshard (#9153)

* support best unbalanced pp scheduler (#9235)

* remove pp hack (#9189)

---------

Co-authored-by: Meiyim <chenxuyi@baidu.com>
LiYuRio and Meiyim authored Oct 14, 2024
1 parent 3007c79 commit 0e96b0f
Showing 3 changed files with 34 additions and 12 deletions.
paddlenlp/trainer/trainer.py (0 additions, 9 deletions)
@@ -2270,13 +2270,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         self._pp_data_buffer = []

         model.train()
-        # hack pipeline-layers
-        # since the pipeline layer will check input is valid every iter.
-        # in same case, for example, batch size warmup, we need dynamic change gradient_accumulation_steps to implement.
-        config_backup = model.micro_batch_size, model.accumulate_steps
-        model.micro_batch_size = self.args.per_device_train_batch_size
-        model.accumulate_steps = self.args.gradient_accumulation_steps
-
         if model._dp_comm_overlap or model._sharding_comm_overlap:
             for _, buffers in model._chunk_2_comm_buffers.items():
                 for buffer in buffers:
@@ -2291,8 +2284,6 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle
         with self.autocast_smart_context_manager():
             loss = model.forward_backward_pipeline(inputs, self.scaler if self.do_grad_scaling else None)

-        model.micro_batch_size, model.accumulate_steps = config_backup
-
         return loss.detach()

     def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
paddlenlp/trainer/training_args.py (2 additions, 0 deletions)
@@ -1120,6 +1120,7 @@ def split_parallel_config(parallel_config)
                     "enable_clear_every_step_cache",
                     "enable_overlap_p2p_comm",
                     "disable_batch_p2p_comm",
+                    "best_unbalanced_scheduler",
                 ]:
                     raise ValueError(
                         f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -1158,6 +1159,7 @@ def split_parallel_config(parallel_config):
                 "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
                 "clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
                 "use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
+                "best_unbalanced_scheduler": "best_unbalanced_scheduler" in pipeline_parallel_config,
             }
             if dygraph_pp_configs["dp_comm_overlap"]:
                 raise ValueError("overlap has accuracy issue")  # TODO: fix `overalap` + `delay_scale` issue
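For context on how the new option reaches the runtime: pipeline_parallel_config is a plain option string on TrainingArguments; split_parallel_config turns it into a set of flags that is validated against the whitelist in the first hunk and then mapped into dygraph_pp_configs in the second. A minimal sketch of that flow for the new flag follows; the split_parallel_config body and the train.py launch line are assumptions for illustration, not code from this commit.

# Sketch only: mirrors the mapping shown in the hunks above without importing paddlenlp.
# On a real run the option string would be passed on the launch command line, e.g.
#   python -m paddle.distributed.launch train.py \
#       --pipeline_parallel_degree 4 \
#       --pipeline_parallel_config "disable_partial_send_recv best_unbalanced_scheduler"
def split_parallel_config(parallel_config: str) -> set:
    # Assumed behavior: options may be comma- or space-separated.
    sep = "," if "," in parallel_config else " "
    return {x for x in parallel_config.split(sep) if x}


pipeline_parallel_config = split_parallel_config("disable_partial_send_recv best_unbalanced_scheduler")

dygraph_pp_configs = {
    # Only the key added by this commit is shown.
    "best_unbalanced_scheduler": "best_unbalanced_scheduler" in pipeline_parallel_config,
}
print(dygraph_pp_configs)  # {'best_unbalanced_scheduler': True}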
paddlenlp/trainer/utils/reshard/pp_reshard.py (32 additions, 3 deletions)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from collections import OrderedDict

 from paddle.distributed.fleet.model import PipelineParallel
@@ -46,6 +45,25 @@ def get_index_layer_func():
     return _GLOBAL_INDEX_LAYER_FUNC


+_GLOBAL_SNAME_TO_TNAME_FUNC = None
+
+
+def register_sname_to_tname_func(func):
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    _GLOBAL_SNAME_TO_TNAME_FUNC = func
+
+
+def has_register_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    return _GLOBAL_SNAME_TO_TNAME_FUNC is not None
+
+
+def get_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    assert _GLOBAL_SNAME_TO_TNAME_FUNC is not None, "sname to tname func is not registered yet"
+    return _GLOBAL_SNAME_TO_TNAME_FUNC
+
+
 class LayerNameScope:
     """
     layer name scope for a layer, layer name of the same kind of layer will be named consecutively
@@ -206,6 +224,7 @@ def __init__(self):
         self._segments = OrderedDict()
         self._layer_to_segment = OrderedDict()
         self._param_to_tname = OrderedDict()
+        self._wname_to_rname = OrderedDict()

     def add_segment(self, start_index, end_index):
         segment = PipeLineSegment(start_index, end_index)
@@ -218,19 +237,24 @@ def add_layer(self, layer_index, layer_name, param_names):
         segment = self._layer_to_segment[layer_index]
         segment.add_layer(layer_name, param_names)

-    def build_name_mapping(self):
+    def build_name_mapping(self, sname_to_tname=None):
         for (k, segment) in self._segments.items():
             for (i, layer) in segment.layers.items():
                 for param in layer.params.items():
                     (param_name, tensor_name) = param
                     # map to a new name
                     n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name)
+                    if sname_to_tname is not None:
+                        if param_name in sname_to_tname.keys():
+                            self._wname_to_rname[param_name] = sname_to_tname[param_name]
                     # logger.info(f"{param_name} {tensor_name}=>{n_name}")
                     self._param_to_tname[param_name] = (tensor_name, n_name)

     def map_name(self, param_name, t_name):
         assert param_name in self._param_to_tname
         tensor_name, n_name = self._param_to_tname[param_name]
+        if param_name in self._wname_to_rname:
+            n_name = self._wname_to_rname[param_name]
         assert tensor_name == t_name
         return n_name

@@ -261,6 +285,11 @@ def __init__(
         self._index_layers()

         stage_segments = self._segment()
+        if has_register_sname_to_tname_func():
+            self._sname_to_tname = get_sname_to_tname_func()(pp_model)
+        else:
+            self._sname_to_tname = None
+
         for (i, stage_seg) in enumerate(stage_segments):
             pipe_stage = PipeLineStage()
             self._stages.append(pipe_stage)
@@ -275,7 +304,7 @@ def __init__(
                 self._layer_name_to_stage[layer_name] = i

         for stage in self._stages:
-            stage.build_name_mapping()
+            stage.build_name_mapping(self._sname_to_tname)

     def _index_layers(self):
         for layer_name in self._param_names_by_layer.keys():
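The registry added at module level above lets model-side code supply a mapping from a parameter's structured name to the tensor name the reshard should emit; build_name_mapping records that mapping in _wname_to_rname, map_name prefers it over the sequentially generated name, and the __init__ shown in the last two hunks calls the registered function with the pipeline model. A minimal registration sketch follows; the mapping policy inside my_sname_to_tname is a hypothetical placeholder, and the import path simply mirrors the file location shown above.

# Hypothetical registration sketch; only the register/has/get helpers come from this commit.
from paddlenlp.trainer.utils.reshard.pp_reshard import (
    get_sname_to_tname_func,
    has_register_sname_to_tname_func,
    register_sname_to_tname_func,
)


def my_sname_to_tname(pp_model):
    # Illustrative policy: map a subset of structured names to their current tensor names.
    # What belongs here depends on the model; this condition is an assumption.
    mapping = {}
    for sname, param in pp_model.named_parameters():
        if "shared_layers" in sname:
            mapping[sname] = param.name
    return mapping


register_sname_to_tname_func(my_sname_to_tname)
assert has_register_sname_to_tname_func()

# Later, during resharding, the __init__ shown above effectively does:
#   sname_to_tname = get_sname_to_tname_func()(pp_model)
# and build_name_mapping(sname_to_tname) uses these names for matching parameters.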