Skip to content

Commit

Permalink
update dd
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Aug 15, 2024
1 parent 21c7dd9 commit 65ab0ad
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,7 @@ def get_train_dataloader(self):

train_dataset = self.train_dataset
if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_dd(train_dataset)
is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)
else:

Check warning on line 1403 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1403

Added line #L1403 was not covered by tests
is_iterable_dataset = self._is_iterable_dataset(train_dataset)
if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
Expand Down Expand Up @@ -1489,7 +1489,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa

eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_dd(eval_dataset)
is_iterable_dataset = self._is_iterable_dataset_distributed(eval_dataset)
else:

Check warning on line 1493 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1493

Added line #L1493 was not covered by tests
is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
Expand Down Expand Up @@ -1557,7 +1557,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
raise ValueError("We don't need test_dataset when should_load_dataset is False.")

if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_dd(test_dataset)
is_iterable_dataset = self._is_iterable_dataset_distributed(test_dataset)
else:

Check warning on line 1561 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1560-L1561

Added lines #L1560 - L1561 were not covered by tests
is_iterable_dataset = self._is_iterable_dataset(test_dataset)
if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
Expand Down Expand Up @@ -3220,7 +3220,7 @@ def _get_collator_with_removed_columns(
def _is_iterable_dataset(self, dataset):
return isinstance(dataset, paddle.io.IterableDataset)

def _is_iterable_dataset_dd(self, dataset):
def _is_iterable_dataset_distributed(self, dataset):
# For distributed dataloaer.
is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
if dist.get_world_size() > 1:
Expand Down

0 comments on commit 65ab0ad

Please sign in to comment.