From de25e8ac2baea8e36c276e17899fab6909d8d2b5 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Wed, 14 Aug 2024 15:42:55 +0800 Subject: [PATCH 1/4] fix ddloader, fix uc unittest --- paddlenlp/data/dist_dataloader.py | 12 +++- paddlenlp/trainer/trainer.py | 75 +++++++++++++++++++----- tests/trainer/test_unified_checkpoint.py | 22 +------ 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/paddlenlp/data/dist_dataloader.py b/paddlenlp/data/dist_dataloader.py index 5d5c6cc7512c..9480a5cb4a61 100644 --- a/paddlenlp/data/dist_dataloader.py +++ b/paddlenlp/data/dist_dataloader.py @@ -33,6 +33,11 @@ def __len__(self): return 0 +class IterableDummyDataset(paddle.io.IterableDataset): + def __iter__(self): + return None + + class DistDataLoader(paddle.io.DataLoader): """ DistDataLoader is a wrapper of paddle.io.DataLoader. @@ -57,10 +62,11 @@ def __init__( worker_init_fn=None, persistent_workers=False, eval=False, + is_iterable_dataset=False, ): if dataset is None: - dataset = DummyDataset() + dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset() logger.info("rank has no data, use Dummpy dataset") super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers) @@ -200,7 +206,7 @@ def __next__(self): try: data = next(self._dataloader_iter) data = nested_copy_place(data, place=paddle.framework._current_expected_place()) - except: - pass + except Exception as e: + logger.debug(e) data = self._broadcast_data(data) return data diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index 0497cfba304e..ef1fe2aa35c2 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1398,9 +1398,12 @@ def get_train_dataloader(self): raise ValueError("We don't need train_dataset when should_load_dataset is False.") train_dataset = self.train_dataset + if self.args.distributed_dataloader: + is_iterable_dataset = self._is_iterable_dataset_dd(train_dataset) + else: + is_iterable_dataset = self._is_iterable_dataset(train_dataset) if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): train_dataset = self._remove_unused_columns(train_dataset, description="training") - _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader if self._is_iterable_dataset(train_dataset): if self.args.dataset_world_size > 1: @@ -1412,24 +1415,42 @@ def get_train_dataloader(self): process_index=self.args.dataset_rank, ) - return _DataLoader( - train_dataset, - batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - ) + if self.args.distributed_dataloader: + return DistDataLoader( + train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + is_iterable_dataset=True, + ) + else: + return DataLoader( + train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + ) train_sampler = self._get_train_sampler() if self.args.distributed_dataloader: logger.info("Training using DistDataLoader.") - return _DataLoader( - train_dataset, - batch_sampler=train_sampler, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - ) + if self.args.distributed_dataloader: + return DistDataLoader( + train_dataset, + batch_sampler=train_sampler if not is_iterable_dataset else 
None, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + is_iterable_dataset=is_iterable_dataset, + ) + else: + return DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + ) def _get_eval_sampler(self, eval_dataset: Dataset): if eval_dataset is None or not has_length(eval_dataset): @@ -1476,7 +1497,10 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa raise ValueError("We don't need eval_dataset when should_load_dataset is False.") eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - + if self.args.distributed_dataloader: + is_iterable_dataset = self._is_iterable_dataset_dd(eval_dataset) + else: + is_iterable_dataset = self._is_iterable_dataset(eval_dataset) if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset): eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") @@ -1497,6 +1521,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa collate_fn=self.data_collator, num_workers=0, eval=True, + is_iterable_dataset=True, ) else: return DataLoader( @@ -1513,10 +1538,11 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa return DistDataLoader( eval_dataset, - batch_sampler=eval_sampler, + batch_sampler=eval_sampler if not is_iterable_dataset else None, collate_fn=self.data_collator, num_workers=self.args.dataloader_num_workers, eval=True, + is_iterable_dataset=is_iterable_dataset, ) else: return DataLoader( @@ -1542,6 +1568,10 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: if not self.args.should_load_dataset and test_dataset is not None: raise ValueError("We don't need test_dataset when should_load_dataset is False.") + if self.args.distributed_dataloader: + is_iterable_dataset = self._is_iterable_dataset_dd(test_dataset) + else: + is_iterable_dataset = self._is_iterable_dataset(test_dataset) if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset): test_dataset = self._remove_unused_columns(test_dataset, description="test") @@ -1562,6 +1592,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: collate_fn=self.data_collator, # _get_collator_with_removed_columns num_workers=self.args.dataloader_num_workers, eval=True, + is_iterable_dataset=True, ) else: return DataLoader( @@ -1579,10 +1610,11 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: # We use the same batch_size as for eval. 
return DistDataLoader( test_dataset, - batch_sampler=test_sampler, + batch_sampler=test_sampler if not is_iterable_dataset else None, collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, eval=True, + is_iterable_dataset=is_iterable_dataset, ) else: return DataLoader( @@ -1694,6 +1726,8 @@ def _load_rng_state(self, checkpoint): if self.args.use_hybrid_parallel: if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state: + if self.args.tensor_parallel_degree <= 1: + checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None) fleet.meta_parallel.get_rng_state_tracker().set_states_tracker( checkpoint_rng_state["hybrid_parallel_rng_state_tracker"] ) @@ -3201,6 +3235,15 @@ def _get_collator_with_removed_columns( def _is_iterable_dataset(self, dataset): return isinstance(dataset, paddle.io.IterableDataset) + def _is_iterable_dataset_dd(self, dataset): + # For distributed dataloaer. + is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1]) + if dist.get_world_size() > 1: + dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX) + if is_iterable_dataset_tensor.item() == 1: + return True + return False + def print_config(self, args=None, key=""): """ print config values diff --git a/tests/trainer/test_unified_checkpoint.py b/tests/trainer/test_unified_checkpoint.py index 5044eeaad5f5..17fe0f14f9ea 100644 --- a/tests/trainer/test_unified_checkpoint.py +++ b/tests/trainer/test_unified_checkpoint.py @@ -659,7 +659,7 @@ def setUp(self): self.need_allclose = True self.rtol = 1e-7 - self.run_pretrain_file = "llm/llama/run_pretrain.py" + self.run_pretrain_file = "llm/run_pretrain.py" def runfirst(self, train_args): train_args["unified_checkpoint"] = 0 @@ -701,7 +701,7 @@ def setUp(self): self.need_allclose = True self.rtol = 1e-7 - self.run_pretrain_file = "llm/llama/run_pretrain.py" + self.run_pretrain_file = "llm/run_pretrain.py" self.filelists = [ "config.json", "master_weights-00001-of-00002.safetensors", @@ -1132,24 +1132,6 @@ def rerun(self, train_args): np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol) -@pytest.mark.skipif(True, reason="Skip for None CE") -class TestUnifiedCheckpointOnN1C8EnableAll(TestUnifiedCheckpointBase): - def setUp(self): - super().setUp() - for config_key in self.configs: - self.configs[config_key]["unified_checkpoint"] = 1 - self.configs[config_key]["unified_checkpoint_config"] = "enable_all_options" - - self.need_allclose = True - self.rtol = 1e-7 - - def runfirst(self, train_args): - self.run_n1c8(self.run_pretrain_file, **train_args) - - def rerun(self, train_args): - self.run_n1c8(self.run_pretrain_file, **train_args) - - @pytest.mark.skipif(True, reason="Skip for None CE") class TestUnifiedCheckpointOnN1C8SaveLoadSpeed(TestUnifiedCheckpointFull): def setUp(self): From 21c7dd99ea6259dd664e7447cc937ffdf897fe8a Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 15 Aug 2024 16:01:14 +0800 Subject: [PATCH 2/4] update loader --- paddlenlp/trainer/trainer.py | 109 +++++++++++++++-------------------- 1 file changed, 47 insertions(+), 62 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index ef1fe2aa35c2..c077b5d8e552 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1405,8 +1405,8 @@ def get_train_dataloader(self): if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): train_dataset = self._remove_unused_columns(train_dataset, 
description="training") - if self._is_iterable_dataset(train_dataset): - if self.args.dataset_world_size > 1: + if is_iterable_dataset: # For iterable dataset + if self.args.dataset_world_size > 1 and train_dataset is not None: train_dataset = IterableDatasetShard( train_dataset, batch_size=self.args.per_device_train_batch_size, @@ -1416,6 +1416,7 @@ def get_train_dataloader(self): ) if self.args.distributed_dataloader: + logger.info("Training using DistDataLoader.") return DistDataLoader( train_dataset, batch_size=self.args.per_device_train_batch_size, @@ -1430,22 +1431,12 @@ def get_train_dataloader(self): collate_fn=self.data_collator, num_workers=self.args.dataloader_num_workers, ) - - train_sampler = self._get_train_sampler() - - if self.args.distributed_dataloader: - logger.info("Training using DistDataLoader.") - - if self.args.distributed_dataloader: - return DistDataLoader( - train_dataset, - batch_sampler=train_sampler if not is_iterable_dataset else None, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - is_iterable_dataset=is_iterable_dataset, - ) else: - return DataLoader( + train_sampler = self._get_train_sampler() + _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader + if self.args.distributed_dataloader: + logger.info("Training using DistDataLoader.") + return _DataLoader( train_dataset, batch_sampler=train_sampler, collate_fn=self.data_collator, @@ -1504,8 +1495,8 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset): eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") - if self._is_iterable_dataset(eval_dataset): - if self.args.dataset_world_size > 1: + if is_iterable_dataset: + if self.args.dataset_world_size > 1 and eval_dataset is not None: eval_dataset = IterableDatasetShard( eval_dataset, batch_size=self.args.per_device_eval_batch_size, @@ -1530,27 +1521,24 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa collate_fn=self.data_collator, num_workers=0, ) - - eval_sampler = self._get_eval_sampler(eval_dataset) - - if self.args.distributed_dataloader: - logger.info("Eval using DistDataLoader.") - - return DistDataLoader( - eval_dataset, - batch_sampler=eval_sampler if not is_iterable_dataset else None, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - eval=True, - is_iterable_dataset=is_iterable_dataset, - ) else: - return DataLoader( - eval_dataset, - batch_sampler=eval_sampler, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - ) + eval_sampler = self._get_eval_sampler(eval_dataset) + if self.args.distributed_dataloader: + logger.info("Eval using DistDataLoader.") + return DistDataLoader( + eval_dataset, + batch_sampler=eval_sampler, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + eval=True, + ) + else: + return DataLoader( + eval_dataset, + batch_sampler=eval_sampler, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + ) def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: """ @@ -1575,8 +1563,8 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset): test_dataset = self._remove_unused_columns(test_dataset, description="test") - if 
self._is_iterable_dataset(test_dataset): - if self.args.dataset_world_size > 1: + if is_iterable_dataset: + if self.args.dataset_world_size > 1 and test_dataset is not None: test_dataset = IterableDatasetShard( test_dataset, batch_size=self.args.per_device_eval_batch_size, @@ -1601,28 +1589,25 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: collate_fn=self.data_collator, # _get_collator_with_removed_columns num_workers=self.args.dataloader_num_workers, ) - - test_sampler = self._get_eval_sampler(test_dataset) - - if self.args.distributed_dataloader: - logger.info("Test using DistDataLoader.") - - # We use the same batch_size as for eval. - return DistDataLoader( - test_dataset, - batch_sampler=test_sampler if not is_iterable_dataset else None, - collate_fn=self.data_collator, - drop_last=self.args.dataloader_drop_last, - eval=True, - is_iterable_dataset=is_iterable_dataset, - ) else: - return DataLoader( - test_dataset, - batch_sampler=test_sampler, - collate_fn=self.data_collator, - drop_last=self.args.dataloader_drop_last, - ) + test_sampler = self._get_eval_sampler(test_dataset) + if self.args.distributed_dataloader: + logger.info("Test using DistDataLoader.") + # We use the same batch_size as for eval. + return DistDataLoader( + test_dataset, + batch_sampler=test_sampler, + collate_fn=self.data_collator, + drop_last=self.args.dataloader_drop_last, + eval=True, + ) + else: + return DataLoader( + test_dataset, + batch_sampler=test_sampler, + collate_fn=self.data_collator, + drop_last=self.args.dataloader_drop_last, + ) def create_optimizer_and_scheduler(self, num_training_steps: int): """ From 65ab0ad68055eb54d9e2b12556285f95d769c011 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 15 Aug 2024 17:27:21 +0800 Subject: [PATCH 3/4] update dd --- paddlenlp/trainer/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index c077b5d8e552..ff6d79e8dd0f 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1399,7 +1399,7 @@ def get_train_dataloader(self): train_dataset = self.train_dataset if self.args.distributed_dataloader: - is_iterable_dataset = self._is_iterable_dataset_dd(train_dataset) + is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset) else: is_iterable_dataset = self._is_iterable_dataset(train_dataset) if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): @@ -1489,7 +1489,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset if self.args.distributed_dataloader: - is_iterable_dataset = self._is_iterable_dataset_dd(eval_dataset) + is_iterable_dataset = self._is_iterable_dataset_distributed(eval_dataset) else: is_iterable_dataset = self._is_iterable_dataset(eval_dataset) if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset): @@ -1557,7 +1557,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: raise ValueError("We don't need test_dataset when should_load_dataset is False.") if self.args.distributed_dataloader: - is_iterable_dataset = self._is_iterable_dataset_dd(test_dataset) + is_iterable_dataset = self._is_iterable_dataset_distributed(test_dataset) else: is_iterable_dataset = self._is_iterable_dataset(test_dataset) if is_datasets_available() and test_dataset is not None and 
isinstance(test_dataset, datasets.Dataset): @@ -3220,7 +3220,7 @@ def _get_collator_with_removed_columns( def _is_iterable_dataset(self, dataset): return isinstance(dataset, paddle.io.IterableDataset) - def _is_iterable_dataset_dd(self, dataset): + def _is_iterable_dataset_distributed(self, dataset): # For distributed dataloaer. is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1]) if dist.get_world_size() > 1: From fd9ffba19bec9bac1da7d8fe775f6d688f71c365 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Thu, 15 Aug 2024 20:25:53 +0800 Subject: [PATCH 4/4] update dataloader --- paddlenlp/data/dist_dataloader.py | 6 +- paddlenlp/trainer/trainer.py | 121 +++++++++++++----------------- 2 files changed, 55 insertions(+), 72 deletions(-) diff --git a/paddlenlp/data/dist_dataloader.py b/paddlenlp/data/dist_dataloader.py index 9480a5cb4a61..a6330ce1fe08 100644 --- a/paddlenlp/data/dist_dataloader.py +++ b/paddlenlp/data/dist_dataloader.py @@ -61,10 +61,12 @@ def __init__( timeout=0, worker_init_fn=None, persistent_workers=False, - eval=False, - is_iterable_dataset=False, + **kwargs, ): + eval = kwargs.pop("eval", False) + is_iterable_dataset = kwargs.pop("is_iterable_dataset", False) + if dataset is None: dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset() logger.info("rank has no data, use Dummpy dataset") diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index ff6d79e8dd0f..782776fdd46e 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -1404,6 +1404,7 @@ def get_train_dataloader(self): is_iterable_dataset = self._is_iterable_dataset(train_dataset) if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset): train_dataset = self._remove_unused_columns(train_dataset, description="training") + _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader if is_iterable_dataset: # For iterable dataset if self.args.dataset_world_size > 1 and train_dataset is not None: @@ -1417,23 +1418,18 @@ def get_train_dataloader(self): if self.args.distributed_dataloader: logger.info("Training using DistDataLoader.") - return DistDataLoader( - train_dataset, - batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - is_iterable_dataset=True, - ) + additional_configs = {"is_iterable_dataset": True} else: - return DataLoader( - train_dataset, - batch_size=self.args.per_device_train_batch_size, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - ) + additional_configs = {} + return _DataLoader( + train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + **additional_configs, + ) else: train_sampler = self._get_train_sampler() - _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader if self.args.distributed_dataloader: logger.info("Training using DistDataLoader.") return _DataLoader( @@ -1494,6 +1490,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa is_iterable_dataset = self._is_iterable_dataset(eval_dataset) if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset): eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + _DataLoader = DistDataLoader if self.args.distributed_dataloader 
else DataLoader if is_iterable_dataset: if self.args.dataset_world_size > 1 and eval_dataset is not None: @@ -1504,41 +1501,32 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa num_processes=self.args.dataset_world_size, process_index=self.args.dataset_rank, ) - if self.args.distributed_dataloader: - return DistDataLoader( - eval_dataset, - batch_size=self.args.per_device_eval_batch_size, - collate_fn=self.data_collator, - num_workers=0, - eval=True, - is_iterable_dataset=True, - ) + logger.info("Eval using DistDataLoader.") + additional_configs = {"eval": True, "is_iterable_dataset": True} else: - return DataLoader( - eval_dataset, - batch_size=self.args.per_device_eval_batch_size, - collate_fn=self.data_collator, - num_workers=0, - ) + additional_configs = {} + return _DataLoader( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + collate_fn=self.data_collator, + num_workers=0, + **additional_configs, + ) else: eval_sampler = self._get_eval_sampler(eval_dataset) if self.args.distributed_dataloader: logger.info("Eval using DistDataLoader.") - return DistDataLoader( - eval_dataset, - batch_sampler=eval_sampler, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - eval=True, - ) + additional_configs = {"eval": True} else: - return DataLoader( - eval_dataset, - batch_sampler=eval_sampler, - collate_fn=self.data_collator, - num_workers=self.args.dataloader_num_workers, - ) + additional_configs = {} + return _DataLoader( + eval_dataset, + batch_sampler=eval_sampler, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + **additional_configs, + ) def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: """ @@ -1562,6 +1550,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: is_iterable_dataset = self._is_iterable_dataset(test_dataset) if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset): test_dataset = self._remove_unused_columns(test_dataset, description="test") + _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader if is_iterable_dataset: if self.args.dataset_world_size > 1 and test_dataset is not None: @@ -1574,40 +1563,32 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: ) if self.args.distributed_dataloader: - return DistDataLoader( - test_dataset, - batch_size=self.args.per_device_eval_batch_size * self.world_size, - collate_fn=self.data_collator, # _get_collator_with_removed_columns - num_workers=self.args.dataloader_num_workers, - eval=True, - is_iterable_dataset=True, - ) + logger.info("Test using DistDataLoader.") + additional_config = {"eval": True, "is_iterable_dataset": True} else: - return DataLoader( - test_dataset, - batch_size=self.args.per_device_eval_batch_size * self.world_size, - collate_fn=self.data_collator, # _get_collator_with_removed_columns - num_workers=self.args.dataloader_num_workers, - ) + additional_config = {} + return _DataLoader( + test_dataset, + batch_size=self.args.per_device_eval_batch_size * self.world_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + **additional_config, + ) else: test_sampler = self._get_eval_sampler(test_dataset) if self.args.distributed_dataloader: logger.info("Test using DistDataLoader.") - # We use the same batch_size as for eval. 
- return DistDataLoader( - test_dataset, - batch_sampler=test_sampler, - collate_fn=self.data_collator, - drop_last=self.args.dataloader_drop_last, - eval=True, - ) + additional_config = {"eval": True} else: - return DataLoader( - test_dataset, - batch_sampler=test_sampler, - collate_fn=self.data_collator, - drop_last=self.args.dataloader_drop_last, - ) + additional_config = {} + # We use the same batch_size as for eval. + return _DataLoader( + test_dataset, + batch_sampler=test_sampler, + collate_fn=self.data_collator, + drop_last=self.args.dataloader_drop_last, + **additional_config, + ) def create_optimizer_and_scheduler(self, num_training_steps: int): """
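
Note on the pattern above: the dataloader selection in these patches depends on every rank reaching the same answer to "is this dataset iterable?", including ranks whose local dataset is None (they must still construct the matching IterableDummyDataset inside DistDataLoader and skip the batch sampler the same way). The following is a minimal standalone sketch of that cross-rank check, mirroring _is_iterable_dataset_distributed from the patches; it is an illustration only, not the Trainer method itself, and the free-function form is invented for the example.

    # Sketch of the rank-consistent iterable-dataset check used on the
    # distributed dataloader path (illustrative, not the Trainer code).
    import paddle
    import paddle.distributed as dist

    def is_iterable_dataset_distributed(dataset):
        # Ranks that hold no shard see dataset=None and report False locally.
        local_flag = isinstance(dataset, paddle.io.IterableDataset)
        flag = paddle.to_tensor(int(local_flag)).reshape([1])
        if dist.get_world_size() > 1:
            # MAX-reduce: if any rank holds an iterable dataset, every rank
            # returns True, so all ranks agree on dummy dataset and sampler handling.
            dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        return bool(flag.item())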