
[Auto Parallel] Support semi-auto trainer and fit Llama2 training #7885

Merged: 10 commits merged into PaddlePaddle:develop on Jan 31, 2024

Conversation

haohongxiang (Contributor)

PR types

Bug fixes

PR changes

Others

Description

[Auto Parallel] Support semi-auto trainer and fit Llama2 training

paddle-bot (bot) commented on Jan 23, 2024

Thanks for your contribution!

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch 6 times, most recently from 9668320 to 97498b9, on January 23, 2024 06:22
@ZHUI (Collaborator) left a review:

The changes are quite large, so I'm submitting a Request changes first.

        )

        return optimizer

    def _wrap_dist_loader(self, train_dataloader):
Collaborator: Is it not used in dynamic mode?

haohongxiang (Contributor, Author): Done, it's now used in both dynamic and static mode.
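For reference, a minimal sketch of what such a dist-loader wrapper can look like with paddle.distributed.shard_dataloader, assuming one ProcessMesh per pipeline stage as the surrounding snippets suggest (the attribute names are illustrative, not the PR's exact code):

import paddle.distributed as dist

def _wrap_dist_loader(self, train_dataloader):
    # One mesh per pipeline stage; the runtime shards the loader's batch
    # dimension along the "dp" axis of each mesh.
    meshes = []
    for pp_idx in range(self.args.pipeline_parallel_degree):  # assumed arg name
        meshes.append(_get_mesh(pp_idx))
    return dist.shard_dataloader(
        dataloader=train_dataloader,
        meshes=meshes,
        shard_dims="dp",
    )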

            meshes.append(_get_mesh(pp_idx))
        return meshes

    def _wrap_dist_loader(self, train_dataloader):
Collaborator: What's the difference from the _wrap_dist_loader in run_pretrain_3D_auto.py?

            shard_dims="dp",
        )

    def _wrap_for_static(self, model, train_dataloader):
Collaborator: It seems unused?

haohongxiang (Contributor, Author) commented on Jan 24, 2024: It's called by the Trainer in paddlenlp/trainer/trainer.py to wrap the model into a DistModel in static mode.
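Roughly, such a wrapper hands the dygraph model, the sharded loader, the criterion and the optimizer to paddle.distributed.to_static, which returns a DistModel for static-graph execution. A hedged sketch (self.criterion and self.optimizer are assumed attributes, not necessarily the PR's exact code):

import paddle.distributed as dist

def _wrap_for_static(self, model, train_dataloader):
    # Bind loader, loss and optimizer at conversion time and get back a
    # DistModel that runs the static-graph semi-auto pipeline.
    dist_loader = self._wrap_dist_loader(train_dataloader)
    dist_model = dist.to_static(model, dist_loader, self.criterion, self.optimizer)
    return dist_model, dist_loader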

@@ -939,15 +877,16 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
             # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel
-            position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Shard(0), dist.Replicate()])
+            position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()])
Collaborator: Why change it to replicated?

haohongxiang (Contributor, Author): Because in static mode, infer spmd doesn't yet support the [seq_len] --> [batch, seq_len] case.
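To illustrate the two placements being discussed, here is a standalone sketch (not the model code; it needs to run under a distributed launch with 4 ranks): [Shard(0), Replicate()] splits dim 0 across the dp axis of the mesh, while [Replicate(), Replicate()] keeps a full copy on every rank.

import paddle
import paddle.distributed as dist

# 2-D mesh: axis 0 is data parallel, axis 1 is model parallel (assumes 4 ranks).
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

position_ids = paddle.arange(1024, dtype="int64").expand((8, 1024))

# Shard the batch dimension across dp, replicate along mp.
sharded = dist.shard_tensor(position_ids, mesh, [dist.Shard(0), dist.Replicate()])

# Fully replicated: every rank holds the complete [8, 1024] tensor.
replicated = dist.shard_tensor(position_ids, mesh, [dist.Replicate(), dist.Replicate()])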

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch 6 times, most recently from 0afdc96 to 624abd7, on January 24, 2024 08:26
Comment on lines 723 to 726:

        if self.args.use_auto_parallel and self.args.run_static_semi_auto:
            model = self._wrap_for_static(model, train_dataloader)

        self.model = model

Collaborator left a suggested change:

        if self.args.use_auto_parallel and self.args.run_static_semi_auto:
            model = self._wrap_for_static(model, train_dataloader)
        self.model = model

haohongxiang (Contributor, Author): Done

paddlenlp/trainer/trainer.py: outdated review thread, resolved
codecov (bot) commented on Jan 25, 2024

Codecov Report

Attention: 398 lines in your changes are missing coverage. Please review.

Comparison: base (44bfeb0) is at 56.80% coverage, head (f13a0bf) is at 56.57%.

Files                                              Patch %   Missing lines
paddlenlp/trainer/auto_trainer.py                    0.00%      326 ⚠️
paddlenlp/transformers/llama/modeling_3D_auto.py     4.54%       42 ⚠️
paddlenlp/trainer/training_args.py                  47.36%       20 ⚠️
paddlenlp/trainer/trainer_utils.py                  30.76%        9 ⚠️
paddlenlp/trainer/trainer.py                        88.88%        1 ⚠️
@@             Coverage Diff             @@
##           develop    #7885      +/-   ##
===========================================
- Coverage    56.80%   56.57%   -0.23%     
===========================================
  Files          588      589       +1     
  Lines        89536    89900     +364     
===========================================
+ Hits         50858    50865       +7     
- Misses       38678    39035     +357     


@haohongxiang force-pushed the semi_auto_trainer_llama2 branch 4 times, most recently from e3dfa0b to eda936c, on January 28, 2024 22:12
paddlenlp/trainer/trainer.py: two outdated review threads, resolved
    def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
        if self.control.should_log:

            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
-           tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+           tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())
Collaborator:

    def _get_item_from_loss(self, loss):
        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
        return loss.item()

    def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
        if self.control.should_log:
            logs: Dict[str, float] = {}
            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())

I see you're reusing the _maybe_log_save_evaluate function here, and it's already guarded on the outside, so why add the assert isinstance(loss, paddle.Tensor) and loss._is_initialized() check here?

haohongxiang (Contributor, Author): This can be removed here; the semi-auto check logic can simply be overridden in auto_trainer.

Collaborator: Could you please remove it then?

@@ -747,6 +748,8 @@ class TrainingArguments:
         default=False,
         metadata={"help": "reshard pp even if pp degree in the model and pp degree in script match"},
     )
+    parallel_mode: str = field(default="hybrid", metadata={"help": ""})
Collaborator: Could you write this help string in more detail? Note that it is only used for auto parallel / semi-auto parallel, and list the available options.

haohongxiang (Contributor, Author): Done

@@ -747,6 +748,8 @@ class TrainingArguments:
         default=False,
         metadata={"help": "reshard pp even if pp degree in the model and pp degree in script match"},
     )
+    parallel_mode: str = field(default="hybrid", metadata={"help": ""})
+    run_static_semi_auto: bool = field(default=True, metadata={"help": ""})
Collaborator: What exactly does this parameter mean? Is there any possibility of merging the two options?

haohongxiang (Contributor, Author): This parameter distinguishes dynamic semi-auto execution from static semi-auto execution. The default is True, which runs the end-to-end training flow in static semi-auto mode; if manually set to False, training runs in dynamic semi-auto mode, which makes it easier for users to debug things like the sharding annotations in the network.
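Putting the two answers above together, the more detailed field definitions could look roughly like this (an illustrative sketch of the help strings, not the merged text; the stand-in dataclass exists only for this example):

from dataclasses import dataclass, field

@dataclass
class _AutoParallelArgsSketch:  # stand-in for TrainingArguments, for illustration only
    parallel_mode: str = field(
        default="hybrid",
        metadata={
            "help": "Which parallel engine to use: 'hybrid' for the manual hybrid-parallel "
            "path, 'auto' for the (semi-)auto parallel path."
        },
    )
    run_static_semi_auto: bool = field(
        default=True,
        metadata={
            "help": "Only used by semi-auto parallel. If True, run the end-to-end training "
            "flow in static semi-auto mode; if False, run in dynamic semi-auto mode, which "
            "makes it easier to debug sharding annotations in the network."
        },
    )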

llm/llama/auto_parallel/run_pretrain_3D_auto.py: outdated review thread, resolved
paddlenlp/trainer/auto_trainer.py: outdated review thread, resolved
        if kwargs.get("args", None) is not None and kwargs["args"].run_static_semi_auto:
            if kwargs.get("criterion", None) is None:

                def loss_func(loss, outputs):
Collaborator: Do we need to define a criterion? In current PaddleNLP models the loss is basically computed inside the model; we don't define one separately.

haohongxiang (Contributor, Author): The static semi-auto architecture needs a dummy criterion to run, even one that simply returns the loss.
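In other words, a pass-through criterion is enough. A minimal sketch matching the loss_func signature shown in the snippet above:

def loss_func(loss, outputs):
    # The static semi-auto engine expects a criterion to be provided, but the
    # loss is already computed inside the model, so just return it unchanged.
    return loss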

        total_batch_size_per_acc_step = self.args.per_device_train_batch_size * self.args.dataset_world_size
        total_batch_size = total_batch_size_per_acc_step * self.args.gradient_accumulation_steps
        batch_size = total_batch_size if self.args.run_static_semi_auto else total_batch_size_per_acc_step

Collaborator: I don't quite follow. Do you mean gradient accumulation is handled inside the run_static_semi_auto path, so the batch size is larger?

Could enabling run_static_semi_auto or not lead to data inconsistency?

haohongxiang (Contributor, Author): Fixed. The batch sampler now always receives the global batch size. Except for the static semi-auto PP strategy, gradient accumulation in all other scenarios reads the global batch, splits it along the batch dim, and then runs a for loop over the splits to complete the accumulation.
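A rough sketch of that accumulation scheme (illustrative, not the PR's exact code), assuming the global batch is a dict of tensors: split each tensor along the batch dimension into gradient_accumulation_steps micro-batches and loop over them.

import paddle

def run_accumulation_step(model, optimizer, global_batch, acc_steps):
    # Split the global batch along the batch dim into acc_steps micro-batches
    # and accumulate gradients over them. (The static semi-auto PP path instead
    # lets the engine handle accumulation internally.)
    micro_batches = {key: paddle.split(value, acc_steps, axis=0) for key, value in global_batch.items()}
    for step in range(acc_steps):
        inputs = {key: chunks[step] for key, chunks in micro_batches.items()}
        loss = model(**inputs)  # assumes the model computes and returns the loss
        (loss / acc_steps).backward()
    optimizer.step()
    optimizer.clear_grad()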

@haohongxiang force-pushed the semi_auto_trainer_llama2 branch 2 times, most recently from 79f68f3 to 3c8be71, on January 30, 2024 03:56
            )
        },
    )
    run_static_auto: bool = field(default=True, metadata={"help": "whether to run static graph in auto parallel mode"})
Collaborator: How about these values instead?

"hybrid"
"auto"
"auto_static"
"auto_semi"
"auto_semi_static"

@ZHUI (Collaborator) left a review:

LGTM

@wawltor merged commit 3a704ea into PaddlePaddle:develop on Jan 31, 2024
7 of 10 checks passed
Labels: none yet. Projects: none yet.

4 participants