Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support nested list of dict inputs #8876

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions paddlenlp/trainer/auto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,43 @@ def _split_batches_for_accumulation(self, inputs):
if self.args.gradient_accumulation_steps == 1:
return [inputs]

# if self.args.to_static:
if self.args.to_static and self.args.pipeline_parallel_degree > 1:
return [inputs]

local_batches = [{} for i in range(self.args.gradient_accumulation_steps)]
assert isinstance(inputs, dict)

for key, value in inputs.items():
ori_mesh, ori_placements = value.process_mesh, value.placements
replicate_value = dist.reshard(value, ori_mesh, [dist.Replicate(), dist.Replicate()])
def split_dtensor_by_axis(dtensor, axis):
mesh = dtensor.process_mesh
placements = [dist.Replicate() for _ in range(len(mesh.shape))]
replicate_value = dist.reshard(dtensor, mesh, placements)
local_datas = replicate_value.split(self.args.gradient_accumulation_steps, axis=0)

for index, data in enumerate(local_datas):
local_batches[index].update({key: dist.reshard(data, ori_mesh, ori_placements)})

return local_datas

for key, dtensors in inputs.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/utils/nested.py#L25-L37

参考这种写,定义函数,递归遍历,可能清晰一些。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已线下讨论

if isinstance(dtensors, paddle.Tensor):
mesh, placements = dtensors.process_mesh, dtensors.placements
local_datas = split_dtensor_by_axis(dtensors, 0)
for index, data in enumerate(local_datas):
local_batches[index].update({key: dist.reshard(data, mesh, placements)})
elif isinstance(dtensors, (list, tuple)):
if len(dtensors) == 0:
for i in range(self.args.gradient_accumulation_steps):
local_batches[i].update({key: []})
else:
for dtensor in dtensors:
if isinstance(dtensor, paddle.Tensor):
mesh, placements = dtensor.process_mesh, dtensor.placements
local_datas = split_dtensor_by_axis(dtensor, 0)
for index, data in enumerate(local_datas):
if key in local_batches[index].keys():
local_batches[index][key].append(dist.reshard(data, mesh, placements))
else:
local_batches[index].update({key: [dist.reshard(data, mesh, placements)]})
else:
raise ValueError(f"unsupported type: {type(dtensor)}")
else:
raise ValueError(f"unsupported type: {type(dtensors)}")
return local_batches

def _inner_training_loop(
Expand Down
Loading