[Trainer] Fix distributed dataloader (#8932)
* fix ddloader, fix uc unittest

* update dataloader
DesmonDay authored Aug 16, 2024
1 parent 277fdb4 commit e8708ed
Showing 3 changed files with 92 additions and 93 deletions.
16 changes: 12 additions & 4 deletions paddlenlp/data/dist_dataloader.py
@@ -33,6 +33,11 @@ def __len__(self):
return 0


class IterableDummyDataset(paddle.io.IterableDataset):
def __iter__(self):
return None


class DistDataLoader(paddle.io.DataLoader):
"""
DistDataLoader is a wrapper of paddle.io.DataLoader.
@@ -56,11 +61,14 @@ def __init__(
timeout=0,
worker_init_fn=None,
persistent_workers=False,
eval=False,
**kwargs,
):

eval = kwargs.pop("eval", False)
is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)

if dataset is None:
dataset = DummyDataset()
dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
logger.info("rank has no data, use Dummpy dataset")

super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)
@@ -200,7 +208,7 @@ def __next__(self):
try:
data = next(self._dataloader_iter)
data = nested_copy_place(data, place=paddle.framework._current_expected_place())
except:
pass
except Exception as e:
logger.debug(e)
data = self._broadcast_data(data)
return data
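A minimal sketch of the constructor behaviour this hunk introduces; `resolve_dataset` is a hypothetical helper that only mirrors the `kwargs.pop` logic and is not part of the PaddleNLP API. Ranks that hold no data now get a placeholder whose style matches the datasets on the other ranks, so iterable and map-style loaders stay consistent across the group.

```python
import paddle

class DummyDataset(paddle.io.Dataset):
    def __len__(self):
        return 0

class IterableDummyDataset(paddle.io.IterableDataset):
    def __iter__(self):
        return iter([])  # empty iterator in this sketch; the diff returns None

def resolve_dataset(dataset=None, **kwargs):
    # Mirrors the new __init__ logic: the optional flags travel through **kwargs,
    # so existing call sites that never pass them keep working unchanged.
    eval_mode = kwargs.pop("eval", False)
    is_iterable = kwargs.pop("is_iterable_dataset", False)
    if dataset is None:
        dataset = IterableDummyDataset() if is_iterable else DummyDataset()
    return dataset, eval_mode

ds, eval_mode = resolve_dataset(None, eval=True, is_iterable_dataset=True)
print(type(ds).__name__, eval_mode)  # IterableDummyDataset True
```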
147 changes: 78 additions & 69 deletions paddlenlp/trainer/trainer.py
@@ -1404,12 +1404,16 @@ def get_train_dataloader(self):
raise ValueError("We don't need train_dataset when should_load_dataset is False.")

train_dataset = self.train_dataset
if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(train_dataset)
if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(train_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset: # For iterable dataset
if self.args.dataset_world_size > 1 and train_dataset is not None:
train_dataset = IterableDatasetShard(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
@@ -1418,24 +1422,28 @@ def get_train_dataloader(self):
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
additional_configs = {"is_iterable_dataset": True}
else:
additional_configs = {}
return _DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)
else:
train_sampler = self._get_train_sampler()
if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
return _DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)

train_sampler = self._get_train_sampler()

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")

return _DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)

def _get_eval_sampler(self, eval_dataset: Dataset):
if eval_dataset is None or not has_length(eval_dataset):
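For readers skimming the hunk above, here is a minimal stand-alone sketch of the consolidated selection logic; `build_train_loader`, `args`, and `dist_loader_cls` are hypothetical names, not Trainer API. One loader class is chosen up front, and the DistDataLoader-only keyword arguments are funnelled through a dict so the plain `DataLoader` path never sees them.

```python
from paddle.io import DataLoader

def build_train_loader(args, dataset, collator, sampler, dist_loader_cls, is_iterable):
    # Pick the loader class once, as the refactored method does.
    loader_cls = dist_loader_cls if args.distributed_dataloader else DataLoader
    if is_iterable:
        # Only DistDataLoader understands the extra flag.
        extra = {"is_iterable_dataset": True} if args.distributed_dataloader else {}
        return loader_cls(
            dataset,
            batch_size=args.per_device_train_batch_size,
            collate_fn=collator,
            num_workers=args.dataloader_num_workers,
            **extra,
        )
    # Map-style datasets keep using a batch sampler.
    return loader_cls(
        dataset,
        batch_sampler=sampler,
        collate_fn=collator,
        num_workers=args.dataloader_num_workers,
    )
```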
@@ -1482,54 +1490,48 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
raise ValueError("We don't need eval_dataset when should_load_dataset is False.")

eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(eval_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(eval_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and eval_dataset is not None:
eval_dataset = IterableDatasetShard(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.dataset_world_size,
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
return DistDataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
eval=True,
)
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True, "is_iterable_dataset": True}
else:
return DataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
)

eval_sampler = self._get_eval_sampler(eval_dataset)

if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")

return DistDataLoader(
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
eval=True,
num_workers=0,
**additional_configs,
)
else:
return DataLoader(
eval_sampler = self._get_eval_sampler(eval_dataset)
if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True}
else:
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)

def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
@@ -1548,11 +1550,16 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
if not self.args.should_load_dataset and test_dataset is not None:
raise ValueError("We don't need test_dataset when should_load_dataset is False.")

if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(test_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(test_dataset)
if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(test_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and test_dataset is not None:
test_dataset = IterableDatasetShard(
test_dataset,
batch_size=self.args.per_device_eval_batch_size,
Expand All @@ -1562,40 +1569,31 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
)

if self.args.distributed_dataloader:
return DistDataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
eval=True,
)
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True, "is_iterable_dataset": True}
else:
return DataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
)

test_sampler = self._get_eval_sampler(test_dataset)

if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")

# We use the same batch_size as for eval.
return DistDataLoader(
additional_config = {}
return _DataLoader(
test_dataset,
batch_sampler=test_sampler,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
eval=True,
num_workers=self.args.dataloader_num_workers,
**additional_config,
)
else:
return DataLoader(
test_sampler = self._get_eval_sampler(test_dataset)
if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True}
else:
additional_config = {}
# We use the same batch_size as for eval.
return _DataLoader(
test_dataset,
batch_sampler=test_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
**additional_config,
)

def create_optimizer_and_scheduler(self, num_training_steps: int):
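The eval and test paths above follow the same pattern; a small sketch with assumed values shows how the extra flags and the globally scaled test batch size fit together when `distributed_dataloader` is enabled (all numbers are illustrative):

```python
# Illustrative values only.
distributed_dataloader = True
per_device_eval_batch_size = 8
world_size = 4

# Extra kwargs are forwarded only when DistDataLoader is in use.
additional_config = {"eval": True, "is_iterable_dataset": True} if distributed_dataloader else {}

# The iterable test branch batches globally: 8 samples per device * 4 ranks = 32 per step.
test_batch_size = per_device_eval_batch_size * world_size
print(additional_config, test_batch_size)
```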
@@ -1700,6 +1698,8 @@ def _load_rng_state(self, checkpoint):

if self.args.use_hybrid_parallel:
if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
if self.args.tensor_parallel_degree <= 1:
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
)
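A small sketch of the new RNG guard with assumed checkpoint contents (key names follow the diff; the nested values are placeholders): when the current run has no tensor parallelism, the model-parallel generator state saved by a tensor-parallel run is dropped before `set_states_tracker` is called, so only applicable generators are restored.

```python
# Assumed checkpoint structure for illustration only.
checkpoint_rng_state = {
    "hybrid_parallel_rng_state_tracker": {
        "global_seed": {"seed": 123},
        "local_seed": {"seed": 456},
        "model_parallel_rng": {"seed": 789},  # written by a run with tensor parallelism
    }
}
tensor_parallel_degree = 1  # current run: no tensor parallelism

tracker_states = checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
if tensor_parallel_degree <= 1:
    # No model-parallel generator exists in this run, so discard its saved state.
    tracker_states.pop("model_parallel_rng", None)
print(sorted(tracker_states))  # ['global_seed', 'local_seed']
```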
@@ -3210,6 +3210,15 @@ def _get_collator_with_removed_columns(
def _is_iterable_dataset(self, dataset):
return isinstance(dataset, paddle.io.IterableDataset)

def _is_iterable_dataset_distributed(self, dataset):
# For distributed dataloader.
is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
if dist.get_world_size() > 1:
dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX)
if is_iterable_dataset_tensor.item() == 1:
return True
return False

def print_config(self, args=None, key=""):
"""
print config values
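To illustrate why `_is_iterable_dataset_distributed` reduces with MAX: every rank must take the same branch when building the loader, including ranks whose local dataset is None (on their own they would report a map-style dataset). A pure-Python sketch of the agreement rule, no collective required:

```python
def agree_on_iterable(per_rank_flags):
    # Stands in for dist.all_reduce(tensor, op=ReduceOp.MAX) over one 0/1 flag
    # per rank: the result is 1 as soon as any rank sees an IterableDataset,
    # so every rank constructs the same kind of loader.
    return max(per_rank_flags) == 1

print(agree_on_iterable([0, 0, 1, 0]))  # True  -> all ranks use the iterable path
print(agree_on_iterable([0, 0, 0, 0]))  # False -> all ranks use the map-style path
```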
22 changes: 2 additions & 20 deletions tests/trainer/test_unified_checkpoint.py
@@ -659,7 +659,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.run_pretrain_file = "llm/run_pretrain.py"

def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
@@ -701,7 +701,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.run_pretrain_file = "llm/run_pretrain.py"
self.filelists = [
"config.json",
"master_weights-00001-of-00002.safetensors",
Expand Down Expand Up @@ -1132,24 +1132,6 @@ def rerun(self, train_args):
np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN1C8EnableAll(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
for config_key in self.configs:
self.configs[config_key]["unified_checkpoint"] = 1
self.configs[config_key]["unified_checkpoint_config"] = "enable_all_options"

self.need_allclose = True
self.rtol = 1e-7

def runfirst(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN1C8SaveLoadSpeed(TestUnifiedCheckpointFull):
def setUp(self):
