From fd9ffba19bec9bac1da7d8fe775f6d688f71c365 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Thu, 15 Aug 2024 20:25:53 +0800
Subject: [PATCH] update dataloader

---
 paddlenlp/data/dist_dataloader.py |   6 +-
 paddlenlp/trainer/trainer.py      | 121 +++++++++++++-----------------
 2 files changed, 55 insertions(+), 72 deletions(-)

diff --git a/paddlenlp/data/dist_dataloader.py b/paddlenlp/data/dist_dataloader.py
index 9480a5cb4a61..a6330ce1fe08 100644
--- a/paddlenlp/data/dist_dataloader.py
+++ b/paddlenlp/data/dist_dataloader.py
@@ -61,10 +61,12 @@ def __init__(
         timeout=0,
         worker_init_fn=None,
         persistent_workers=False,
-        eval=False,
-        is_iterable_dataset=False,
+        **kwargs,
     ):
+        eval = kwargs.pop("eval", False)
+        is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
+
         if dataset is None:
             dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
             logger.info("rank has no data, use Dummpy dataset")
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index ff6d79e8dd0f..782776fdd46e 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1404,6 +1404,7 @@ def get_train_dataloader(self):
         is_iterable_dataset = self._is_iterable_dataset(train_dataset)
         if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

         if is_iterable_dataset:  # For iterable dataset
             if self.args.dataset_world_size > 1 and train_dataset is not None:
@@ -1417,23 +1418,18 @@ def get_train_dataloader(self):

             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
-                return DistDataLoader(
-                    train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                    is_iterable_dataset=True,
-                )
+                additional_configs = {"is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    train_dataset,
-                    batch_size=self.args.per_device_train_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                )
+                additional_configs = {}
+            return _DataLoader(
+                train_dataset,
+                batch_size=self.args.per_device_train_batch_size,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
+            )
         else:
             train_sampler = self._get_train_sampler()
-            _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
             return _DataLoader(
@@ -1494,6 +1490,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
         is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
         if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and eval_dataset is not None:
@@ -1504,41 +1501,32 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
                     num_processes=self.args.dataset_world_size,
                     process_index=self.args.dataset_rank,
                 )
             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                    eval=True,
-                    is_iterable_dataset=True,
-                )
+                logger.info("Eval using DistDataLoader.")
+                additional_configs = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    eval_dataset,
-                    batch_size=self.args.per_device_eval_batch_size,
-                    collate_fn=self.data_collator,
-                    num_workers=0,
-                )
+                additional_configs = {}
+            return _DataLoader(
+                eval_dataset,
+                batch_size=self.args.per_device_eval_batch_size,
+                collate_fn=self.data_collator,
+                num_workers=0,
+                **additional_configs,
+            )
         else:
             eval_sampler = self._get_eval_sampler(eval_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                return DistDataLoader(
-                    eval_dataset,
-                    batch_sampler=eval_sampler,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                    eval=True,
-                )
+                additional_configs = {"eval": True}
             else:
-                return DataLoader(
-                    eval_dataset,
-                    batch_sampler=eval_sampler,
-                    collate_fn=self.data_collator,
-                    num_workers=self.args.dataloader_num_workers,
-                )
+                additional_configs = {}
+            return _DataLoader(
+                eval_dataset,
+                batch_sampler=eval_sampler,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
+            )

     def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         """
@@ -1562,6 +1550,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
         is_iterable_dataset = self._is_iterable_dataset(test_dataset)
         if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
+        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and test_dataset is not None:
@@ -1574,40 +1563,32 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
                 )

             if self.args.distributed_dataloader:
-                return DistDataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                    eval=True,
-                    is_iterable_dataset=True,
-                )
+                logger.info("Test using DistDataLoader.")
+                additional_config = {"eval": True, "is_iterable_dataset": True}
             else:
-                return DataLoader(
-                    test_dataset,
-                    batch_size=self.args.per_device_eval_batch_size * self.world_size,
-                    collate_fn=self.data_collator,  # _get_collator_with_removed_columns
-                    num_workers=self.args.dataloader_num_workers,
-                )
+                additional_config = {}
+            return _DataLoader(
+                test_dataset,
+                batch_size=self.args.per_device_eval_batch_size * self.world_size,
+                collate_fn=self.data_collator,
+                num_workers=self.args.dataloader_num_workers,
+                **additional_config,
+            )
         else:
             test_sampler = self._get_eval_sampler(test_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                # We use the same batch_size as for eval.
-                return DistDataLoader(
-                    test_dataset,
-                    batch_sampler=test_sampler,
-                    collate_fn=self.data_collator,
-                    drop_last=self.args.dataloader_drop_last,
-                    eval=True,
-                )
+                additional_config = {"eval": True}
             else:
-                return DataLoader(
-                    test_dataset,
-                    batch_sampler=test_sampler,
-                    collate_fn=self.data_collator,
-                    drop_last=self.args.dataloader_drop_last,
-                )
+                additional_config = {}
+            # We use the same batch_size as for eval.
+            return _DataLoader(
+                test_dataset,
+                batch_sampler=test_sampler,
+                collate_fn=self.data_collator,
+                drop_last=self.args.dataloader_drop_last,
+                **additional_config,
+            )

     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
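
Below is a minimal, self-contained sketch (illustrative only, not part of the patch) of the pattern the change introduces: DistDataLoader now pops the PaddleNLP-specific "eval" and "is_iterable_dataset" flags out of **kwargs before the remaining keyword arguments reach the base loader, and the trainer picks the loader class once, forwarding loader-specific options through a small dict that stays empty for the plain DataLoader. Class and function names in the sketch are hypothetical, not PaddleNLP APIs.

# Illustrative sketch of the kwargs contract; names here are made up for the example.
class BaseLoader:
    def __init__(self, dataset, **kwargs):
        self.dataset = dataset
        self.kwargs = kwargs  # e.g. batch_size, collate_fn, num_workers

class DistLoader(BaseLoader):
    def __init__(self, dataset, **kwargs):
        # Pop the extra flags first so they never reach the base-loader signature.
        self.eval = kwargs.pop("eval", False)
        self.is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
        super().__init__(dataset, **kwargs)

def build_loader(dataset, distributed, iterable):
    # Same selection pattern as the patched trainer: choose the class once,
    # then pass loader-specific options through a small dict.
    loader_cls = DistLoader if distributed else BaseLoader
    additional_configs = {"eval": True, "is_iterable_dataset": iterable} if distributed else {}
    return loader_cls(dataset, batch_size=8, num_workers=0, **additional_configs)

print(type(build_loader(range(16), distributed=True, iterable=False)).__name__)  # DistLoader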