[Auto Parallel] fix bugs for split_batches_for_accumulation && fix bu… #9217

Merged
76 changes: 50 additions & 26 deletions paddlenlp/trainer/auto_trainer.py
@@ -118,11 +118,19 @@
dist_loader = self._wrap_for_dist_loader(train_dataloader)

if ShardingOption.SHARD_OP in self.args.sharding:
self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage1())
self.optimizer = dist.shard_optimizer(
self.optimizer, dist.ShardingStage1(), self.args.gradient_accumulation_steps
)
elif ShardingOption.SHARD_GRAD_OP in self.args.sharding:
self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage2())
self.optimizer = dist.shard_optimizer(
self.optimizer, dist.ShardingStage2(), self.args.gradient_accumulation_steps
)
elif ShardingOption.FULL_SHARD in self.args.sharding:
self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage3())
self.optimizer = dist.shard_optimizer(
self.optimizer, dist.ShardingStage3(), self.args.gradient_accumulation_steps
)
else:
self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps)

if self.args.to_static:
unified_strategy = dist.Strategy()
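Editor's note: the only functional change in the hunk above is that every dist.shard_optimizer call now receives self.args.gradient_accumulation_steps as a third argument alongside the sharding stage. A minimal standalone sketch of that selection logic follows; it reuses the ShardingOption flags and paddle.distributed APIs seen in the diff, the import path and helper name are assumptions for illustration only.

import paddle.distributed as dist
from paddlenlp.trainer.trainer_utils import ShardingOption  # assumed import path

def wrap_optimizer_for_sharding(optimizer, args):
    # Sketch, not the PR code: map the configured sharding option to a stage.
    if ShardingOption.SHARD_OP in args.sharding:
        stage = dist.ShardingStage1()      # shard optimizer states
    elif ShardingOption.SHARD_GRAD_OP in args.sharding:
        stage = dist.ShardingStage2()      # additionally shard gradients
    elif ShardingOption.FULL_SHARD in args.sharding:
        stage = dist.ShardingStage3()      # additionally shard parameters
    else:
        stage = None                       # no sharding
    # Passing the accumulation step count tells the sharded optimizer where the
    # accumulation boundary is, so gradients can be handled per full step.
    return dist.shard_optimizer(optimizer, stage, args.gradient_accumulation_steps)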
@@ -173,41 +181,54 @@
if self.args.to_static and self._in_pir_mode and self.args.gradient_accumulation_steps > 1:
return [inputs]

local_batches = [{} for i in range(self.args.gradient_accumulation_steps)]
global_micro_batchs = [{} for i in range(self.args.gradient_accumulation_steps)]
assert isinstance(inputs, dict)

def split_dtensor_by_axis(dtensor, axis):
mesh = dtensor.process_mesh
placements = [dist.Replicate() for _ in range(len(mesh.shape))]
replicate_value = dist.reshard(dtensor, mesh, placements)
local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0)
return local_datas
def split_dtensor_by_axis(dtensor, axis=0):
if not dtensor._is_initialized():
return dtensor.split(self.args.gradient_accumulation_steps, axis=axis)

micro_batch_shape = dtensor.shape
micro_batch_shape[axis] = int(dtensor.shape[axis] / self.args.gradient_accumulation_steps)

global_micro_batchs = [
paddle.zeros(micro_batch_shape, dtype=dtensor.dtype)
for _ in range(self.args.gradient_accumulation_steps)
]
global_micro_batchs = [
dist.shard_tensor(b, dtensor.process_mesh, dtensor.placements) for b in global_micro_batchs
]

local_micro_batchs = dtensor._local_value().split(self.args.gradient_accumulation_steps, axis=axis)
for local_micro_batch, global_micro_batch in zip(local_micro_batchs, global_micro_batchs):
paddle.assign(local_micro_batch, global_micro_batch._local_value())
return global_micro_batchs

for key, dtensors in inputs.items():
if isinstance(dtensors, paddle.Tensor):
mesh, placements = dtensors.process_mesh, dtensors.placements
local_datas = split_dtensor_by_axis(dtensors, 0)
for index, data in enumerate(local_datas):
local_batches[index].update({key: dist.reshard(data, mesh, placements)})
global_datas = split_dtensor_by_axis(dtensors, 0)
for index, data in enumerate(global_datas):
global_micro_batchs[index].update({key: dist.reshard(data, mesh, placements)})
elif isinstance(dtensors, (list, tuple)):
if len(dtensors) == 0:
for i in range(self.args.gradient_accumulation_steps):
local_batches[i].update({key: []})
global_micro_batchs[i].update({key: []})
else:
for dtensor in dtensors:
if isinstance(dtensor, paddle.Tensor):
mesh, placements = dtensor.process_mesh, dtensor.placements
local_datas = split_dtensor_by_axis(dtensor, 0)
for index, data in enumerate(local_datas):
if key in local_batches[index].keys():
local_batches[index][key].append(dist.reshard(data, mesh, placements))
global_datas = split_dtensor_by_axis(dtensor, 0)
for index, data in enumerate(global_datas):
if key in global_micro_batchs[index].keys():
global_micro_batchs[index][key].append(dist.reshard(data, mesh, placements))
else:
local_batches[index].update({key: [dist.reshard(data, mesh, placements)]})
global_micro_batchs[index].update({key: [dist.reshard(data, mesh, placements)]})
else:
raise ValueError(f"unsupported type: {type(dtensor)}")
else:
raise ValueError(f"unsupported type: {type(dtensors)}")
return local_batches
return global_micro_batchs

def _inner_training_loop(
self,
@@ -372,6 +393,9 @@

self.timers and self.timers("optimizer-step").start()

if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss():
tr_loss /= self.args.gradient_accumulation_steps

# Optimizer step
self.callback_handler.on_optimizer_begin(
args, self.state, self.control, scaler=self.scaler if self.do_grad_scaling else None
@@ -503,11 +527,11 @@

return (loss, outputs) if return_outputs else loss

def dynamic_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

if loss is not None and self.args.gradient_accumulation_steps > 1:
if loss is not None and self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
loss = loss / self.args.gradient_accumulation_steps

if self.do_grad_scaling:
@@ -517,11 +541,11 @@

return loss

def static_traning(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
def static_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
input_ids, labels = tuple(inputs.values())
loss = model(input_ids, labels)

if loss is not None and self.args.gradient_accumulation_steps > 1:
if loss is not None and self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
loss = loss / self.args.gradient_accumulation_steps

return loss
@@ -532,9 +556,9 @@
inputs = self._prepare_inputs(inputs)

if not self.args.to_static:
loss = self.dynamic_traning(model, inputs)
loss = self.dynamic_training(model, inputs)
else:
loss = self.static_traning(model, inputs)
loss = self.static_training(model, inputs)

if isinstance(loss, paddle.Tensor):
return loss.detach() if loss._is_initialized() else float(0.0)
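Editor's note: as a complement to the rewritten _split_batches_for_accumulation above, here is a minimal sketch of the underlying idea for plain (non-distributed) Paddle tensors: the global batch is cut into gradient_accumulation_steps equal micro-batches along the batch axis. The helper name is illustrative; the PR additionally rebuilds each micro-batch as a distributed tensor with the original process mesh and placements, which this sketch omits.

import paddle

def split_batch_for_accumulation(batch, accumulation_steps, axis=0):
    # Sketch only: split every tensor in the input dict into equal micro-batches
    # along `axis`, mirroring the per-key loop in _split_batches_for_accumulation.
    micro_batches = [{} for _ in range(accumulation_steps)]
    for key, tensor in batch.items():
        # Requires tensor.shape[axis] to be divisible by accumulation_steps.
        parts = tensor.split(accumulation_steps, axis=axis)
        for i, part in enumerate(parts):
            micro_batches[i][key] = part
    return micro_batches

# Example: a global batch of 8 sequences becomes 4 micro-batches of 2.
batch = {"input_ids": paddle.arange(8 * 16).reshape([8, 16])}
micro = split_batch_for_accumulation(batch, accumulation_steps=4)
assert micro[0]["input_ids"].shape == [2, 16]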
3 changes: 3 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -2197,6 +2197,9 @@
return (loss, outputs) if return_outputs else loss

def _enable_delay_scale_loss(self):
if in_auto_parallel_align_mode():
return True

key = "enable_delay_scale_loss"
if self.args.pipeline_parallel_degree > 1:
return key in self.args.pipeline_parallel_config
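Editor's note: for intuition on why _enable_delay_scale_loss now also returns True in auto-parallel align mode. With delayed scaling, per-micro-batch losses are accumulated unscaled and divided by gradient_accumulation_steps once, just before the optimizer step (the tr_loss /= ... change above); without it, each micro-batch loss is divided up front in dynamic_training/static_training. Both conventions yield the same averaged loss, as this toy check illustrates (the numbers are made up):

micro_losses = [2.0, 4.0, 6.0, 8.0]   # hypothetical per-micro-batch losses
steps = len(micro_losses)             # gradient_accumulation_steps

# Delayed scaling: accumulate raw losses, divide once at the optimizer step.
delayed = sum(micro_losses) / steps

# Eager scaling: divide each micro-batch loss before accumulating.
eager = sum(loss / steps for loss in micro_losses)

assert abs(delayed - eager) < 1e-12   # both equal the mean micro-batch loss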
3 changes: 2 additions & 1 deletion paddlenlp/trainer/training_args.py
@@ -1472,7 +1472,7 @@ def is_segment_parallel_supported():
"enable_send_recv_overlap",
# "disable_p2p_cache_shape", # no need for auto_parallel
# "disable_partial_send_recv", # no implemenation for auto_parallel
# "enable_delay_scale_loss", # default True in auto_parallel, non-configurable
"enable_delay_scale_loss",
# "enable_dp_comm_overlap", # no implemenation for auto_parallel
# "enable_sharding_comm_overlap", # no implemenation for auto_parallel
# "enable_timer", # no implemenation for auto_parallel
@@ -1521,6 +1521,7 @@ def is_segment_parallel_supported():
if len(x) > 0:
if x not in [
"enable_mp_async_allreduce", # allreduce_matmul_grad_overlapping in auto_parallel
"enable_delay_scale_loss",
# "enable_mp_skip_c_identity",
# "enable_mp_fused_linear_param_grad_add",
]:
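Editor's note: the net effect of the two training_args.py edits is that enable_delay_scale_loss is now treated as a recognized, configurable flag instead of being rejected during config validation. A rough sketch of that style of validation follows; the accepted-flag set and the whitespace-separated config format are assumptions for illustration, not the exact PaddleNLP parser.

ACCEPTED_PIPELINE_PARALLEL_FLAGS = {
    "enable_send_recv_overlap",
    "enable_delay_scale_loss",  # accepted (and configurable) after this PR
}

def validate_pipeline_parallel_config(config: str) -> None:
    # Reject any flag the auto-parallel path does not implement.
    for flag in config.split():
        if flag and flag not in ACCEPTED_PIPELINE_PARALLEL_FLAGS:
            raise ValueError(f"unsupported pipeline_parallel_config flag: {flag}")

validate_pipeline_parallel_config("enable_send_recv_overlap enable_delay_scale_loss")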
40 changes: 18 additions & 22 deletions scripts/distribute/ci_case_auto.sh
@@ -561,7 +561,7 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.51876831
if [ $IS_A100 -ne 0 ];then
loss_base=9.53083992
loss_base=9.53084087
fi
ips_base=-1
mem_base=-1
@@ -629,9 +629,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.35078526
loss_base=9.3507843
if [ $IS_A100 -ne 0 ];then
loss_base=9.38577652
loss_base=9.38577747
fi
ips_base=-1
mem_base=-1
@@ -699,7 +699,7 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.35139465
loss_base=9.3513937
if [ $IS_A100 -ne 0 ];then
loss_base=9.39356422
fi
@@ -770,7 +770,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.35162354
loss_base=9.35162258
if [ $IS_A100 -ne 0 ];then
loss_base=9.39368534
fi
@@ -1063,7 +1063,7 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() {
loss_base=9.16783295
loss_md5_base=8ea72495fba4e1b9ba004b4431e27218
if [ $IS_A100 -ne 0 ];then
loss_base=9.37966919
loss_base=9.38009949
fi
ips_base=-1
mem_base=-1
@@ -1253,7 +1253,7 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP() {
loss_base=9.25199432
loss_md5_base=83531e98ee11cd271db175150ab254bb
if [ $IS_A100 -ne 0 ];then
loss_base=9.44203949
loss_base=9.44241714
fi
ips_base=-1
mem_base=-1
@@ -1546,10 +1546,10 @@ function llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4() {
>>${log_path}/$FUNCNAME 2>&1

loss=$(grep "global_step: 10," "$case_log_dir/workerlog.0" | grep -oP '(?<=loss: )\d+(\.\d+)?' | awk -F ',' '{print $1}')
if [ "$pp_mode" == "FThenB" ]; then
loss1=loss
if [ "$pp_mode" == "1F1B" ]; then
loss1=($loss)
else
loss2=loss
loss2=($loss)
fi
echo "result: $pp_mode loss=$loss"
done
@@ -1887,12 +1887,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
loss_base=10.59205246
loss_md5_base=0ebf68698887b33b33a46518621cf412
loss_base=10.59368134
ips_base=-1
mem_base=-1
if [ $IS_A100 -ne 0 ];then
loss_base=10.60499191
loss_base=10.60190201
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
@@ -1960,12 +1959,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
loss_base=10.58860683
loss_md5_base=6df87d01bd08113a92930f6349514b35
loss_base=10.5913763
ips_base=-1
mem_base=-1
if [ $IS_A100 -ne 0 ];then
loss_base=10.59338379
loss_base=10.5915575
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
@@ -2034,12 +2032,11 @@ function llm_gpt_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
# loss_base=10.59993172 # note: need to debug
loss_base=10.59993267
loss_md5_base=6cb4e151b35f026190df90ab240d9a95
loss_base=10.59891224
ips_base=-1
mem_base=-1
if [ $IS_A100 -ne 0 ];then
loss_base=10.59612274
loss_base=10.60014629
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
@@ -2108,13 +2105,12 @@ function llm_gpt_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem loss_md5=$loss_md5"
# loss_base=10.58456802 # note: need to debug
loss_base=10.6004734
loss_md5_base=e82a1f5668870d18a2d45b3ee0a25386
loss_base=10.59941483
ips_base=-1
mem_base=-1
if [ $IS_A100 -ne 0 ];then
# loss_base=10.58141422 # note: need to debug
loss_base=10.59650803
loss_base=10.60039139
fi
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="