diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index c9f836f37971..b8fcbaab4adc 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -92,9 +92,12 @@ def _get_meshes_for_loader(self): def _get_mesh(pp_idx=0): return self.global_mesh.get_mesh_with_dim("pp")[pp_idx] + # Note(lizhiyu): If the values returned by `DataLoader` don't have the format `[images, labels]`, + # error may occurs here. meshes = [] - for pp_idx in range(self.args.pipeline_parallel_degree): - meshes.append(_get_mesh(pp_idx)) + meshes.append(_get_mesh(0)) + if self.args.pipeline_parallel_degree > 1: + meshes.append(_get_mesh(self.args.pipeline_parallel_degree - 1)) return meshes def _wrap_for_dist_loader(self, train_dataloader):