
Commit

update dataloader
DesmonDay committed Aug 15, 2024
1 parent 65ab0ad commit fd9ffba
Showing 2 changed files with 55 additions and 72 deletions.
6 changes: 4 additions & 2 deletions paddlenlp/data/dist_dataloader.py
@@ -61,10 +61,12 @@ def __init__(
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
-        eval=False,
-        is_iterable_dataset=False,
         **kwargs,
     ):
+
+        eval = kwargs.pop("eval", False)
+        is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
+
         if dataset is None:
             dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
             logger.info("rank has no data, use Dummpy dataset")
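The net effect of this hunk: eval and is_iterable_dataset are no longer declared in the constructor signature and are instead popped out of **kwargs, so the parent DataLoader never sees them. A minimal, self-contained sketch of that pattern (the class and attribute names below are illustrative, not the real DistDataLoader internals):

class LoaderWithOptionalFlags:
    def __init__(self, dataset=None, batch_size=1, **kwargs):
        # Flags that call sites may or may not pass; pop them so the base
        # class never receives keywords it does not accept.
        self.eval = kwargs.pop("eval", False)
        self.is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
        if kwargs:
            raise TypeError(f"unexpected keyword arguments: {sorted(kwargs)}")
        self.dataset = dataset
        self.batch_size = batch_size

loader = LoaderWithOptionalFlags(dataset=[1, 2, 3], eval=True)
assert loader.eval and not loader.is_iterable_dataset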
121 changes: 51 additions & 70 deletions paddlenlp/trainer/trainer.py
@@ -1404,6 +1404,7 @@ def get_train_dataloader(self):
         is_iterable_dataset = self._is_iterable_dataset(train_dataset)
         if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
         if is_iterable_dataset:  # For iterable dataset
             if self.args.dataset_world_size > 1 and train_dataset is not None:
@@ -1417,23 +1418,18 @@ def get_train_dataloader(self):
 
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
-                return DistDataLoader(
-                    train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                    is_iterable_dataset=True,
-                )
+                additional_configs = {"is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                )
+                additional_configs = {}
+            return _DataLoader(
+                train_dataset,
+                batch_size=self.args.per_device_train_batch_size,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
+            )
         else:
             train_sampler = self._get_train_sampler()
-            _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
                 return _DataLoader(
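Each of the three dataloader builders in trainer.py is reshaped the same way: choose the loader class once at the top of the function, gather the DistDataLoader-only keywords into a dict, and make one shared call. A generic sketch of the pattern shown in the hunk above (the loader classes are passed in rather than imported, so this is a shape illustration, not the Trainer's actual method):

def build_train_loader(dataset, collator, args, DistDataLoader, DataLoader):
    # Pick the class once instead of duplicating the constructor call per branch.
    _DataLoader = DistDataLoader if args.distributed_dataloader else DataLoader
    # Only DistDataLoader understands these extra keywords, so gate them.
    additional_configs = {"is_iterable_dataset": True} if args.distributed_dataloader else {}
    return _DataLoader(
        dataset,
        batch_size=args.per_device_train_batch_size,
        collate_fn=collator,
        num_workers=args.dataloader_num_workers,
        **additional_configs,
    )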
@@ -1494,6 +1490,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
         if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and eval_dataset is not None:
@@ -1504,41 +1501,32 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
                     num_processes=self.args.dataset_world_size,
                     process_index=self.args.dataset_rank,
                 )
-
             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                    eval=True,
-                    is_iterable_dataset=True,
-                )
+                logger.info("Eval using DistDataLoader.")
+                additional_configs = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                )
+                additional_configs = {}
+            return _DataLoader(
+                eval_dataset,
+                batch_size=self.args.per_device_eval_batch_size,
+                collate_fn=self.data_collator,
+                num_workers=0,
+                **additional_configs,
+            )
         else:
             eval_sampler = self._get_eval_sampler(eval_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                return DistDataLoader(
-                    eval_dataset,
-                    batch_sampler=eval_sampler,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                    eval=True,
-                )
+                additional_configs = {"eval": True}
             else:
-                return DataLoader(
-                    eval_dataset,
-                    batch_sampler=eval_sampler,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                )
+                additional_configs = {}
+            return _DataLoader(
+                eval_dataset,
+                batch_sampler=eval_sampler,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
+            )
 
     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
@@ -1562,6 +1550,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         is_iterable_dataset = self._is_iterable_dataset(test_dataset)
         if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and test_dataset is not None:
@@ -1574,40 +1563,32 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
                 )
 
             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                    eval=True,
-                    is_iterable_dataset=True,
-                )
+                logger.info("Test using DistDataLoader.")
+                additional_config = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                )
+                additional_config = {}
+            return _DataLoader(
+                test_dataset,
+                batch_size=self.args.per_device_eval_batch_size * self.world_size,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_config,
+            )
         else:
             test_sampler = self._get_eval_sampler(test_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                # We use the same batch_size as for eval.
-                return DistDataLoader(
-                    test_dataset,
-                    batch_sampler=test_sampler,
-                    collate_fn=self.data_collator,
-                    drop_last=self.args.dataloader_drop_last,
-                    eval=True,
-                )
+                additional_config = {"eval": True}
             else:
-                return DataLoader(
-                    test_dataset,
-                    batch_sampler=test_sampler,
-                    collate_fn=self.data_collator,
-                    drop_last=self.args.dataloader_drop_last,
-                )
+                additional_config = {}
+            # We use the same batch_size as for eval.
+            return _DataLoader(
+                test_dataset,
+                batch_sampler=test_sampler,
+                collate_fn=self.data_collator,
+                drop_last=self.args.dataloader_drop_last,
+                **additional_config,
+            )
 
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
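The gating dict is what keeps the shared return statement safe: a plain DataLoader has no eval or is_iterable_dataset keyword, so passing them unconditionally would raise a TypeError, while DistDataLoader (after the dist_dataloader.py change above) pops them from **kwargs. A toy demonstration of that failure mode and how the dict avoids it (stand-in classes, not the real loaders):

class PlainLoader:  # stands in for the framework DataLoader
    def __init__(self, dataset, num_workers=0):
        self.dataset, self.num_workers = dataset, num_workers

class DistLoader(PlainLoader):  # stands in for DistDataLoader after this commit
    def __init__(self, dataset, num_workers=0, **kwargs):
        self.eval = kwargs.pop("eval", False)
        super().__init__(dataset, num_workers=num_workers)

for use_dist in (False, True):
    loader_cls = DistLoader if use_dist else PlainLoader
    extra = {"eval": True} if use_dist else {}  # keywords only the dist loader accepts
    loader_cls([1, 2, 3], num_workers=0, **extra)

# PlainLoader([1, 2, 3], eval=True) would raise:
# TypeError: __init__() got an unexpected keyword argument 'eval'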

