
update auto_engine use iterable dataloader (#6862)
* fix auto engine with iterable dataloader in new executor

* fix dataloader batch_size
zhaoyinglia authored Aug 30, 2023
1 parent 3655108 commit 36ebbf5
Showing 2 changed files with 29 additions and 17 deletions.
@@ -12,7 +12,7 @@ Engine:
num_train_epochs: 1
accumulate_steps:
logging_freq: 1
- eval_freq: 1
+ eval_freq: 500
eval_iters: 10
test_iters:
mix_precision:
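Note: the config change above lowers evaluation overhead by evaluating once
every 500 training steps instead of once per step. A minimal sketch of how a
step-based eval_freq and eval_iters are typically consumed (assumed control
flow, not this engine's exact code):

    # Hedged sketch: evaluate every eval_freq steps, eval_iters batches per pass.
    eval_freq, eval_iters, max_steps = 500, 10, 2000
    for step in range(1, max_steps + 1):
        # ... one training step ...
        if step % eval_freq == 0:
            for _ in range(eval_iters):
                pass  # ... one evaluation batch ...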
model_zoo/gpt-3/ppfleetx/core/engine/auto_engine.py (44 changes: 28 additions & 16 deletions)
@@ -129,6 +129,21 @@ def __init__(self, configs, module=None, mode="train"):
self.nvprof_start = configs.get("Profiler_auto", {}).get("nvprof_start", -1)
self.nvprof_end = configs.get("Profiler_auto", {}).get("nvprof_end", -1)

+    def _validate_batch(self, batch):
+        if self._pp_degree > 1 or self._accumulate_steps == 1:
+            batches = batch
+        else:
+            feed_names = []
+            split_batches = []
+            for n, b in batch[0].items():
+                feed_names.append(n)
+                split_batches.append(np.split(np.array(b), self._accumulate_steps, 0))
+            batches = []
+            for i in range(len(split_batches[0])):
+                micro_batch = [split_batch[i] for split_batch in split_batches]
+                batches.append(dict(zip(feed_names, micro_batch)))
+        return batches
+
def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loader=None):

train_losses = []
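The new _validate_batch helper splits one global batch into accumulate_steps
equal micro-batches along axis 0 whenever pipeline parallelism is off and
gradient accumulation is on (np.split requires the batch size to divide evenly
by accumulate_steps). A standalone sketch of the same split, with made-up
shapes and field names:

    import numpy as np

    accumulate_steps = 4
    # One global batch of 8 samples; the field names are illustrative only.
    batch = [{"input_ids": np.arange(8 * 16).reshape(8, 16), "labels": np.arange(8)}]

    feed_names, split_batches = [], []
    for name, value in batch[0].items():
        feed_names.append(name)
        # Split each feed into accumulate_steps pieces along the batch axis.
        split_batches.append(np.split(np.array(value), accumulate_steps, 0))

    micro_batches = [
        dict(zip(feed_names, [s[i] for s in split_batches]))
        for i in range(accumulate_steps)
    ]
    assert micro_batches[0]["input_ids"].shape == (2, 16)  # 8 samples / 4 steps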
@@ -137,7 +152,7 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loader=None):

total_train_batch = self._max_steps if self._run_mode == "step" else len(train_data_loader)
total_train_step = self._max_steps if self._run_mode == "step" else total_train_batch * self._num_train_epochs
-        total_eval_batch = valid_data_loader._steps if valid_data_loader is not None else 0
+        total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
valid_data_loader = valid_data_loader if valid_data_loader is not None else None
eval_finished_step = 0

@@ -148,20 +163,17 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loader=None):
if step < self._load_recovery["step"]:
continue

+            batches = self._validate_batch(batch)
+
fetch_list = None
if self._strategy.amp.enable:
# fetch_list = ["find_infinite_scale.tmp_0", "loss_scaling_0"]
fetch_list = []

-            if self._pp_degree == 1 and self._accumulate_steps > 1:  # gradient merge
-                local_steps = self._accumulate_steps
-            else:
-                local_steps = 1
-
final_loss = None
-            for _ in range(local_steps):
+            for micro_batch in batches:
with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
-                    outs = self._auto_engine.run(batch, fetch_list=fetch_list, mode="train")
+                    outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
# pp: some devices don't have loss in outs
if "loss" in outs:
if final_loss is None:
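With the iterable loader the step loop now receives the whole global batch, so
gradient accumulation iterates the micro-batches explicitly rather than calling
run() local_steps times and relying on the generator-fed pipeline to supply
data internally. Splitting into equal micro-batches is numerically safe for a
mean-reduced loss, since the mean of per-micro-batch mean gradients equals the
full-batch mean gradient; a toy check:

    import numpy as np

    np.random.seed(0)
    grads = np.random.randn(8, 4)              # toy per-sample gradients
    full_batch = grads.mean(axis=0)            # one step on the full batch
    micro = np.mean(
        [m.mean(axis=0) for m in np.split(grads, 4, axis=0)], axis=0
    )                                          # four accumulated micro-steps
    assert np.allclose(full_batch, micro)      # same update direction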
@@ -255,22 +267,24 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):

train_data_loader, valid_data_loader = None, None
if train_dataset:
-            train_data_loader = self._auto_engine.dataloader_from_generator(
+            train_data_loader = self._auto_engine.dataloader(
dataset=train_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=train_dataset.collate_fn,
+                num_workers=1,
sample_split=train_dataset.sample_split,
mode="train",
)
if valid_dataset and self._eval_freq <= self._max_steps:
-            valid_data_loader = self._auto_engine.dataloader_from_generator(
+            valid_data_loader = self._auto_engine.dataloader(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
+                num_workers=1,
sample_split=valid_dataset.sample_split,
mode="eval",
)
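fit(), evaluate(), and predict() all switch from dataloader_from_generator to
dataloader and pass num_workers=1. The training loop above relies on just two
properties of the returned object, len() and plain for-iteration; a stand-in
sketch of that contract (a toy class, not Paddle's implementation):

    from typing import Sequence

    class IterableLoader:
        # Toy stand-in for the iterable-loader contract used by the engine.
        def __init__(self, dataset: Sequence, batch_size: int):
            self.dataset, self.batch_size = dataset, batch_size

        def __len__(self):                 # enables len(valid_data_loader)
            return len(self.dataset) // self.batch_size

        def __iter__(self):                # enables enumerate(train_data_loader)
            for i in range(len(self)):
                yield self.dataset[i * self.batch_size:(i + 1) * self.batch_size]

    loader = IterableLoader(list(range(100)), batch_size=10)
    assert len(loader) == 10               # no private _steps attribute needed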
@@ -307,23 +321,20 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):
)
)

-        # from-generator dataloder need to do this to exit normally
-        if valid_data_loader:
-            valid_data_loader._inner_dataloader.reset()
-
if self.profiler:
self._profiler_done()

def evaluate(self, epoch=1, valid_dataset=None):

valid_data_loader = None
if valid_dataset:
-            valid_data_loader = self._auto_engine.dataloader_from_generator(
+            valid_data_loader = self._auto_engine.dataloader(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
+                num_workers=1,
sample_split=valid_dataset.sample_split,
mode="eval",
)
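This also explains the block deleted from fit() above: a from-generator loader
pushes data through an internal pipeline that had to be reset to shut down
cleanly after a partial pass, while an iterable loader is an ordinary Python
iterable that simply stops when the loop ends. Illustration with a plain
generator (not Paddle code):

    def batches(n=5):
        for i in range(n):
            yield i

    it = batches()
    next(it)     # consume part of the stream, then abandon it
    it.close()   # optional; a discarded generator is finalized automatically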
@@ -377,12 +388,13 @@ def predict(self, epoch=1, test_dataset=None):

test_data_loader = None
if test_dataset:
-            test_data_loader = self._auto_engine.dataloader_from_generator(
+            test_data_loader = self._auto_engine.dataloader(
dataset=test_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=test_dataset.collate_fn,
+                num_workers=1,
sample_split=test_dataset.sample_split,
mode="predict",
)
