
Commit

update dataloader
DesmonDay committed Aug 15, 2024
1 parent 65ab0ad commit 2384c4d
Showing 2 changed files with 55 additions and 72 deletions.
paddlenlp/data/dist_dataloader.py (6 changes: 4 additions & 2 deletions)
@@ -61,10 +61,12 @@ def __init__(
timeout=0,
worker_init_fn=None,
persistent_workers=False,
eval=False,
is_iterable_dataset=False,
**kwargs,
):

eval = kwargs.pop("eval", False)
is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)

if dataset is None:
dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
logger.info("rank has no data, use Dummpy dataset")
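For reference, a minimal sketch of the kwargs-handling pattern this hunk switches to: the eval and is_iterable_dataset flags move out of the explicit __init__ signature and are popped from **kwargs instead, so existing call sites keep working and whatever remains in kwargs can still be forwarded to the underlying loader. The class and attribute names below are illustrative only, not the real DistDataLoader API.

# Illustrative sketch only: optional flags are popped from **kwargs so they
# never reach the base DataLoader, and callers that omit them get defaults.
class SketchDistDataLoader:
    def __init__(self, dataset=None, batch_size=1, **kwargs):
        self.eval_mode = kwargs.pop("eval", False)
        self.is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
        self.dataset = dataset
        self.batch_size = batch_size
        # Anything left in kwargs could be forwarded to paddle.io.DataLoader.
        self.remaining_kwargs = kwargs

loader = SketchDistDataLoader(dataset=[0, 1, 2], batch_size=8, eval=True)
assert loader.eval_mode is True and loader.is_iterable_dataset is False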
paddlenlp/trainer/trainer.py (121 changes: 51 additions & 70 deletions)
@@ -1404,6 +1404,7 @@ def get_train_dataloader(self):
is_iterable_dataset = self._is_iterable_dataset(train_dataset)
if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if is_iterable_dataset: # For iterable dataset
if self.args.dataset_world_size > 1 and train_dataset is not None:
@@ -1417,23 +1418,18 @@

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
return DistDataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
is_iterable_dataset=True,
)
additional_configs = {"is_iterable_dataset": True}
else:
return DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)
additional_configs = {}
return _DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)
else:
train_sampler = self._get_train_sampler()
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
return _DataLoader(
@@ -1494,6 +1490,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if is_iterable_dataset:
if self.args.dataset_world_size > 1 and eval_dataset is not None:
@@ -1504,41 +1501,32 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
num_processes=self.args.dataset_world_size,
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
return DistDataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
eval=True,
is_iterable_dataset=True,
)
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True, "is_iterable_dataset": True}

Check warning on line 1506 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1505-L1506

Added lines #L1505 - L1506 were not covered by tests
else:
return DataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
)
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
**additional_configs,
)
else:
eval_sampler = self._get_eval_sampler(eval_dataset)
if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")
return DistDataLoader(
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
eval=True,
)
additional_configs = {"eval": True}
else:
return DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)

def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
"""
@@ -1562,6 +1550,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
is_iterable_dataset = self._is_iterable_dataset(test_dataset)
if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if is_iterable_dataset:
if self.args.dataset_world_size > 1 and test_dataset is not None:
@@ -1574,40 +1563,32 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
)

if self.args.distributed_dataloader:
return DistDataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
eval=True,
is_iterable_dataset=True,
)
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True, "is_iterable_dataset": True}

Check warning on line 1567 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1566-L1567

Added lines #L1566 - L1567 were not covered by tests
else:
return DataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
)
additional_config = {}
return _DataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
**additional_config,
)
else:
test_sampler = self._get_eval_sampler(test_dataset)
if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")
# We use the same batch_size as for eval.
return DistDataLoader(
test_dataset,
batch_sampler=test_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
eval=True,
)
additional_config = {"eval": True}
else:
return DataLoader(
test_dataset,
batch_sampler=test_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
)
additional_config = {}
# We use the same batch_size as for eval.
return _DataLoader(
test_dataset,
batch_sampler=test_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
**additional_config,
)

def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
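For reference, a condensed sketch of the control flow that all three refactored getters (get_train_dataloader, get_eval_dataloader, get_test_dataloader) now share: pick the loader class once, collect the distributed-only options in a dict, and make a single constructor call with **additional_configs. ToyLoader, ToyDistLoader and build_loader below are stand-ins invented for illustration, not PaddleNLP APIs; the real code also passes batch_size or batch_sampler, collate_fn and num_workers as shown in the diff above.

class ToyLoader:
    """Stand-in for paddle.io.DataLoader in this sketch."""
    def __init__(self, dataset, **kwargs):
        self.dataset = dataset
        self.configs = kwargs

class ToyDistLoader(ToyLoader):
    """Stand-in for DistDataLoader; also accepts eval / is_iterable_dataset."""

def build_loader(dataset, distributed_dataloader, for_eval, is_iterable_dataset):
    # Choose the class once instead of duplicating the constructor call per branch.
    _DataLoader = ToyDistLoader if distributed_dataloader else ToyLoader
    # Only the distributed loader understands the extra flags.
    if distributed_dataloader:
        additional_configs = {"eval": True} if for_eval else {}
        if is_iterable_dataset:
            additional_configs["is_iterable_dataset"] = True
    else:
        additional_configs = {}
    # Single call site; branch-specific options ride along as keyword arguments.
    return _DataLoader(dataset, **additional_configs)

loader = build_loader([0, 1, 2], distributed_dataloader=True, for_eval=True, is_iterable_dataset=True)
assert isinstance(loader, ToyDistLoader)
assert loader.configs == {"eval": True, "is_iterable_dataset": True}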