[Auto Parallel] fix data stream bug of dist.to_static (#9077)
zhangyuqin1998 authored Sep 5, 2024
1 parent 70da482 commit 8ee99a4
Showing 1 changed file with 6 additions and 1 deletion.
paddlenlp/trainer/auto_trainer.py
@@ -127,7 +127,12 @@ def _wrap_for_auto(self, model, train_dataloader):
         if self.args.to_static:
             unified_strategy = dist.Strategy()
             unified_strategy._from_legacy_strategy(self.args.strategy)
-            model = dist.to_static(model, dist_loader, self.criterion, self.optimizer, strategy=unified_strategy)
+            # dist.to_static() obtains the input spec information through next(dataloader), but this has side effects
+            # on the passed-in dataloader, altering the state of the sampler of the dataloader. In some cases, once
+            # the state of the sampler is changed, it cannot be reverted. Therefore, a temporary dataloader is
+            # constructed here to avoid side effects on the dataloader used for actual training.
+            temp_loader = self._wrap_for_dist_loader(self.get_train_dataloader())
+            model = dist.to_static(model, temp_loader, self.criterion, self.optimizer, strategy=unified_strategy)
 
         self.model_wrapped = model
         return model, dist_loader
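
For context, here is a minimal, self-contained Python sketch of the failure mode this commit addresses. It is not from the repository; make_loader is a hypothetical stand-in for a dataloader whose sampler tracks its position. Peeking at a loader with next() consumes an element and advances that state, so training silently starts one batch late unless a throwaway loader is built just for the peek.

def make_loader(data):
    # Hypothetical stand-in for a dataloader whose sampler tracks its position.
    return iter(data)

batches = [0, 1, 2, 3]

# Before the fix: peeking at the training loader itself consumes a batch.
loader = make_loader(batches)
_ = next(loader)                    # side effect: the sampler has advanced
print(list(loader))                 # [1, 2, 3] -- training silently skips batch 0

# After the fix: peek at a throwaway loader; the training loader is untouched.
loader = make_loader(batches)
temp_loader = make_loader(batches)
_ = next(temp_loader)               # input specs would be inspected here
print(list(loader))                 # [0, 1, 2, 3] -- full epoch preserved

The same design choice appears in the diff above: dist.to_static() needs one batch to infer input specs, so it is handed a freshly constructed temp_loader rather than the dist_loader that the training loop will later consume.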
