[Auto Parallel] Fix bugs caused by the inconsistent outputs of Engine API #46633

Merged 4 commits on Oct 10, 2022
24 changes: 18 additions & 6 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -268,12 +268,24 @@ def _restore_serial_feed_vars(self):
     def _restore_serial_fetch_vars(self):
         for key, var_list in self._original_serial_fetch_vars.items():
             new_var_list = []
-            for var in var_list:
-                block_idx = var.block.idx
-                var_name = var.name
-                var = self._serial_main_program.blocks[
-                    block_idx]._var_recursive(var_name)
-                new_var_list.append(var)
+            # metrics is a list of list
+            if key == "metrics":
+                for inner_var_list in var_list:
+                    new_inner_var_list = []
+                    for var in inner_var_list:
+                        block_idx = var.block.idx
+                        var_name = var.name
+                        var = self._serial_main_program.blocks[
+                            block_idx]._var_recursive(var_name)
+                        new_inner_var_list.append(var)
+                    new_var_list.append(new_inner_var_list)
+            else:
+                for var in var_list:
+                    block_idx = var.block.idx
+                    var_name = var.name
+                    var = self._serial_main_program.blocks[
+                        block_idx]._var_recursive(var_name)
+                    new_var_list.append(var)
             self._serial_fetch_vars[key] = new_var_list

     def _restore_serial_info(self, mode="to_backup"):
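Note on the dist_context.py change: _original_serial_fetch_vars["metrics"] now holds one inner list per metric (a metric's compute() may return several variables), so the restore path has to preserve that nesting while every other key stays a flat list. A minimal sketch of the shape handling, using plain strings in place of program variables (restore_fetch_vars and lookup are illustrative names, not Paddle APIs):

# Minimal sketch: restore a fetch-vars dict whose "metrics" entry is nested
# one level deeper than the other keys.
def restore_fetch_vars(original, lookup):
    restored = {}
    for key, var_list in original.items():
        if key == "metrics":  # list of lists: one inner list per metric
            restored[key] = [[lookup(v) for v in inner] for inner in var_list]
        else:  # flat list of variables
            restored[key] = [lookup(v) for v in var_list]
    return restored

# Example with strings standing in for program variables:
original = {"loss": ["loss_0"], "metrics": [["acc_top1", "acc_top5"]]}
print(restore_fetch_vars(original, lambda name: name.upper()))
# {'loss': ['LOSS_0'], 'metrics': [['ACC_TOP1', 'ACC_TOP5']]}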
241 changes: 139 additions & 102 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -214,9 +214,6 @@ def _prepare_feed(self, user_feeds=None, mode="train"):
"user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)
feeds = {}
# TODO: add inputs and labels feed dict
for name, var in get_collection(CollectionNames.FEEDS):
assert name is not None, "No name defined for feed var"
feeds[name] = var
if user_feeds is not None:
for name, var in user_feeds.items():
feeds[name] = var
@@ -227,42 +224,120 @@ def _prepare_fetch(self, user_fetches=None, mode="train"):
assert isinstance(user_fetches, list), \
"user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
fetch_names = []
fetch_new_names = []
fetch_sections = {}
cnt = 0
fetch_indices = []

def _process_section(section_name, var_list):
nonlocal cnt
section_start = cnt
def _process_fetch_group(group_name, var_list):
group_indices = []
for var in var_list:
new_name = None
# Rename the loss
if section_name == "loss":
new_name = "loss"
if isinstance(var, tuple):
assert len(var) == 2, "Length of tuple {} must be 2".format(
var)
new_name, var = var
if self._is_local_var(var) and var.name not in fetch_names:
fetch_names.append(var.name)
fetch_new_names.append(var.name)
cnt += 1
if self._is_local_var(var) and new_name is not None:
fetch_new_names[fetch_names.index(var.name)] = new_name
section_end = cnt
fetch_sections[section_name] = (section_start, section_end)

for name, var_list in self._fetch_vars[mode].items():
if name == "loss" and mode != "predict":
_process_section("loss", var_list)
if name == "metrics" and mode != "predict":
_process_section("metrics", var_list)
if name == "outputs" and mode == "predict":
_process_section("metrics", var_list)
var_list = (get_collection(CollectionNames.FETCHES)
or []) + (user_fetches or [])
_process_section("user_fetches", var_list)
return fetch_names, fetch_new_names, fetch_sections
# Remove duplicate var_names
if self._is_local_var(var):
var_name = _to_name_str(var)
if var_name not in fetch_names:
fetch_names.append(var_name)
group_indices.append(fetch_names.index(var_name))
fetch_indices.append(group_indices)

if mode != "predict":
_process_fetch_group("loss", self._fetch_vars[mode]["loss"])
if mode != "predict":
metrics = self._fetch_vars[mode]["metrics"]
for i, var_list in enumerate(metrics):
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
var_list = (user_fetches_collection or []) + (user_fetches or [])
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices
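
Note on the new _prepare_fetch: it returns a flat, de-duplicated fetch_names list for the executor plus fetch_indices, one group of positions per logical section (loss, each metric, outputs, user fetches), so later code can slice the executor results by group even when a variable appears in more than one group. A standalone sketch of that bookkeeping (illustrative only, not the Engine code):

# Sketch: de-duplicate fetch names while remembering, per group, where each
# variable landed in the flat fetch list.
def prepare_fetch(groups):
    fetch_names, fetch_indices = [], []
    for var_list in groups:
        group_indices = []
        for name in var_list:
            if name not in fetch_names:
                fetch_names.append(name)
            group_indices.append(fetch_names.index(name))
        fetch_indices.append(group_indices)
    return fetch_names, fetch_indices

# "loss_0" is fetched once but referenced by both the loss group and a user fetch.
names, indices = prepare_fetch([["loss_0"], ["acc_top1", "acc_top5"], ["loss_0"]])
print(names)    # ['loss_0', 'acc_top1', 'acc_top5']
print(indices)  # [[0], [1, 2], [0]]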

def _prepare_logger(self,
outs,
mode="train",
epoch=None,
step=None,
lr=None,
fetch_names=None,
fetch_indices=None,
profiler_log=""):
logs = "[{}] ".format(mode)
if epoch is not None:
logs += "epoch: {:d} ".format(epoch)
if step is not None:
logs += "step: {:d} ".format(step)
if lr is not None:
logs += "lr: {:5e} ".format(lr)
group_idx = 0
# logging loss
if mode != "predict":
loss_indices = fetch_indices[group_idx]
for idx in loss_indices:
logs += "loss: {:8f} ".format(outs[idx][0])
group_idx += 1
# logging metrics
if mode != "predict":
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
metric_out = []
for idx in metrics_indices:
metric_out.append(outs[idx])
if metric_out:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs += "{}: {:8f} ".format(metric.name()[i], res)
group_idx += 1
# Skip logging outputs
if mode == "predict":
group_idx += 1
# logging user fetches
fetches_logging = get_collection(CollectionNames.LOGGING)
for name, var in fetches_logging:
if var.name in fetch_names:
idx = fetch_names.index(var.name)
# Use the user defined name for logging
logs += "{}: {} ".format(name, outs[idx])
self._logger.info(logs)
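
For reference, a toy reconstruction of the log line _prepare_logger builds for a train step, using the same format strings as above (values are made up; metric and user-fetch entries are appended in the same way):

# Toy log-line assembly for one train step.
mode, epoch, step, lr, loss = "train", 0, 10, 1e-3, 0.25
logs = "[{}] ".format(mode)
logs += "epoch: {:d} ".format(epoch)
logs += "step: {:d} ".format(step)
logs += "lr: {:5e} ".format(lr)
logs += "loss: {:8f} ".format(loss)
print(logs)  # [train] epoch: 0 step: 10 lr: 1.000000e-03 loss: 0.250000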

def _prepare_history(self, outs, mode="train", fetch_indices=None):
history = {}
group_idx = 0
# store loss
if mode != "predict":
loss_indices = fetch_indices[group_idx]
loss_values = []
for idx in loss_indices:
loss_values.append(outs[idx][0])
history["loss"] = loss_values
group_idx += 1
# store metrics
if mode != "predict":
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
metric_out = []
for idx in metrics_indices:
metric_out.append(outs[idx])
if metric_out:
metric.update(*metric_out)
results = metric.accumulate()
history[tuple(metric.name())] = to_list(results)
group_idx += 1
# store outputs
if mode == "predict":
outputs_indices = fetch_indices[group_idx]
outputs_values = []
for idx in outputs_indices:
outputs_values.append(outs[idx])
history["outputs"] = outputs_values
group_idx += 1
# store user fetches
fetches_indices = fetch_indices[group_idx]
fetches_values = []
for idx in fetches_indices:
fetches_values.append(outs[idx])
history["fetches"] = fetches_values
return history
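
Note on _prepare_history: it regroups the flat executor outputs by the index groups from _prepare_fetch into a dict keyed by "loss", the metric names, "outputs" and "fetches". A toy illustration of the regrouping for the loss and user-fetch groups (made-up values; metric results additionally pass through metric.update/accumulate before being stored):

# Toy regrouping: outs[i] is the executor result for fetch_names[i].
outs = [[0.25], [7.0], [3.0]]       # loss, then two user fetches
fetch_indices = [[0], [1, 2]]       # group 0: loss, group 1: user fetches

history = {}
history["loss"] = [outs[i][0] for i in fetch_indices[0]]
history["fetches"] = [outs[i] for i in fetch_indices[1]]
print(history)  # {'loss': [0.25], 'fetches': [[7.0], [3.0]]}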

def _build(self, mode):
if _non_static_mode() or self._dygraph_mode:
@@ -311,7 +386,7 @@ def _build(self, mode):

if mode != "predict":
for metric in self._metrics:
metrics.extend(
metrics.append(
to_list(metric.compute(*(outputs + labels))))

default_ctx = get_default_distributed_context()
@@ -547,58 +622,20 @@ def __call__(self,
fetches=None,
mode="train"):
feed_dict = self._prepare_feed(feeds, mode)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
fetches, mode)
fetch_names, fetch_indices = self._prepare_fetch(fetches, mode)
try:
outs = self._executor.run(
self.main_program,
feed=feed_dict,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
pass
self._print_log(outs, self.mode, None, None, None, fetch_new_names,
fetch_sections)
return outs

# TODO: need a better to print the log
def _print_log(self,
outs,
mode="train",
epoch=None,
step=None,
lr=None,
fetch_new_names=None,
fetch_sections=None,
profiler_log=""):
prefix = "[{}] ".format(mode)
logs = {}
if epoch is not None:
logs["epoch: {:d} "] = epoch
if step is not None:
logs["step: {:d} "] = step
if lr is not None:
logs["lr: {:5e} "] = lr
if fetch_sections is not None:
assert fetch_new_names is not None
for section_name, section in fetch_sections.items():
section_start, section_end = section
if section_name == "metrics" and section_start < section_end:
metric_out = outs[section_start:section_end]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs[metric.name()[i] + ": {:8f} "] = res
elif section_name == "loss" and section_start < section_end:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {:8f} "] = outs[i][0]
else:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {} "] = outs[i]
string = prefix + ''.join(list(logs.keys())) + profiler_log
self._logger.info(string.format(*list(logs.values())))
self._prepare_logger(outs, self.mode, None, None, None, fetch_names,
fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)
return history

def fit(self,
train_data,
@@ -692,8 +729,7 @@ def fit(self,
epochs, steps_per_epoch,
collate_fn)

fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)
lr_scheduler = self._get_lr_scheduler(self.main_program)

with profiler.Profiler(timer_only=True) as prof:
@@ -702,7 +738,7 @@
try:
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
@@ -713,17 +749,19 @@

prof.step()

self._print_log(outs, self.mode, epoch, step, lr,
fetch_new_names, fetch_sections,
prof.step_info())
self._prepare_logger(outs, self.mode, epoch, step, lr,
fetch_names, fetch_indices,
prof.step_info())
history = self._prepare_history(outs, self.mode,
fetch_indices)

if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size,
valid_steps, collate_fn, callbacks)
self._switch_mode("train")
else:
self._reset_metrics()
return outs
return history

def evaluate(self,
valid_data,
@@ -793,23 +831,22 @@ def evaluate(self,
steps_per_epoch=steps,
collate_fn=collate_fn)

fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)

outputs = defaultdict(list)
for step, _ in enumerate(valid_dataloader):
try:
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
self._print_log(outs, self.mode, None, step, None, fetch_new_names,
fetch_sections)
self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)
self._reset_metrics()
return outputs
return history

def predict(self,
test_data,
@@ -876,22 +913,22 @@
steps_per_epoch=steps,
collate_fn=collate_fn)

fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
fetch_names, fetch_indices = self._prepare_fetch(mode=self.mode)

for step, _ in enumerate(test_dataloader):
try:
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
fetch_list=fetch_names,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
self._print_log(outs, self.mode, None, step, None, fetch_new_names,
fetch_sections)
self._prepare_logger(outs, self.mode, None, step, None, fetch_names,
fetch_indices)
history = self._prepare_history(outs, self.mode, fetch_indices)

return outs
return history

def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self.mode = 'train'
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/helper.py
@@ -139,7 +139,7 @@ def call_metrics(self, inputs):
         """
         outs = []
         for metric in self.metrics:
-            outs.extend(metric.compute(*inputs))
+            outs.append(to_list(metric.compute(*inputs)))

         return outs

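Note on the helper.py change: switching from extend to append(to_list(...)) keeps each metric's outputs grouped, matching the list-of-lists layout the rest of this PR expects for "metrics". A small sketch (the to_list here is a simplified stand-in, not the Paddle utility):

# append(to_list(...)) keeps one inner list per metric; extend flattens them.
def to_list(value):                   # simplified stand-in for paddle's to_list
    return value if isinstance(value, list) else [value]

computed = ["acc_top1", "acc_top5"]   # what one metric.compute() might return

outs_extend = []
outs_extend.extend(computed)          # old behaviour: flat list
outs_append = []
outs_append.append(to_list(computed)) # new behaviour: grouped per metric

print(outs_extend)  # ['acc_top1', 'acc_top5']
print(outs_append)  # [['acc_top1', 'acc_top5']]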