[Cherry-Pick][AutoParallel] auto_parallel cherry-pick to release2.4 #47128

Closed
226 changes: 226 additions & 0 deletions python/paddle/distributed/auto_parallel/callbacks.py
@@ -0,0 +1,226 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import paddle
from paddle.hapi.callbacks import ProgBarLogger, ModelCheckpoint, LRScheduler, CallbackList, Callback
from .interface import CollectionNames, get_collection


def config_callbacks(callbacks=None,
                     engine=None,
                     batch_size=None,
                     epochs=None,
                     steps=None,
                     log_freq=2,
                     verbose=2,
                     save_freq=1,
                     save_dir=None,
                     metrics=None,
                     acc_step=1,
                     mode='train'):
    cbks = callbacks or []
    cbks = cbks if isinstance(cbks, (list, tuple)) else [cbks]

    if not any(isinstance(k, ProgBarLogger) for k in cbks) and verbose:
        cbks = [ProgBarLoggerAuto(log_freq, verbose=verbose)] + cbks

    if not any(isinstance(k, LRScheduler) for k in cbks):
        cbks = [LRSchedulerAuto()] + cbks

    if not any(isinstance(k, ModelCheckpoint) for k in cbks):
        cbks = cbks + [ModelCheckpointAuto(save_freq, save_dir)]

    if not any(isinstance(k, Profiler) for k in cbks) and verbose == 3:
        cbks = cbks + [Profiler(timer_only=True)]

    if not any(isinstance(k, History) for k in cbks):
        cbks = cbks + [History()]

    for i, k in enumerate(cbks):
        if isinstance(k, ProgBarLogger):
            cbks[i] = ProgBarLoggerAuto(k.log_freq, k.verbose)
        if isinstance(k, LRScheduler):
            cbks[i] = LRSchedulerAuto(k.by_step, k.by_epoch)
        if isinstance(k, ModelCheckpoint):
            cbks[i] = ModelCheckpointAuto(k.save_freq, k.save_dir)

    cbk_list = CallbackList(cbks)
    cbk_list.set_model(engine)
    metrics = metrics or [] if mode != 'test' else []
    params = {
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps,
        'verbose': verbose,
        'metrics': metrics,
        'acc_step': acc_step,
    }
    cbk_list.set_params(params)
    return cbk_list


class ProgBarLoggerAuto(ProgBarLogger):

    def __init__(self, log_freq=1, verbose=2):
        super(ProgBarLoggerAuto, self).__init__(log_freq, verbose)

    def _is_print(self):
        return True

    def _updates(self, logs, mode):
        values = []
        metrics = getattr(self, '%s_metrics' % (mode))
        progbar = getattr(self, '%s_progbar' % (mode))
        steps = getattr(self, '%s_step' % (mode))

        for k in metrics:
            if k in logs:
                values.append((k, logs[k]))

        if 'lr' in logs:
            values.append(('lr', logs['lr']))

        fetches_logs = logs.get('fetches', {})
        collect_logging = get_collection(CollectionNames.LOGGING)
        for name, var in collect_logging:
            k = name or var.name
            if k in fetches_logs:
                values.append((k, fetches_logs[k]))

        out_logs = logs.get('outputs', {})
        for k in out_logs:
            values.append((k, out_logs[k]))

        if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
            timer = getattr(self, '_%s_timer' % (mode))
            cnt = timer['count'] if timer['count'] > 0 else 1.0
            samples = timer['samples'] if timer['samples'] > 0 else 1.0
            values.append(
                ('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt)))
            values.append(
                ('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt)))
            values.append(
                ('ips', "%.5f samples/sec" %
                 (samples / (timer['data_time'] + timer['batch_time']))))
            timer['count'] = 0
            timer['samples'] = 0
            timer['data_time'] = 0.
            timer['batch_time'] = 0.

        progbar.update(steps, values)

    def on_eval_batch_end(self, step, logs=None):
        logs = logs or {}
        self.eval_step += 1
        samples = self.params['batch_size']
        self.evaled_samples += samples

        self._eval_timer['batch_time'] += (
            time.time() - self._eval_timer['batch_data_end_time'])
        self._eval_timer['count'] += 1
        samples = self.params['batch_size']
        self._eval_timer['samples'] += samples

        if self._is_print() and self.eval_step % self.log_freq == 0:
            if self.eval_steps is None or self.eval_step < self.eval_steps:
                self._updates(logs, 'eval')

        self._eval_timer['batch_start_time'] = time.time()


class LRSchedulerAuto(LRScheduler):

    def __init__(self, by_step=True, by_epoch=False):
        super(LRSchedulerAuto, self).__init__(by_step, by_epoch)

    def on_epoch_begin(self, epoch=None, logs=None):
        self.acc_step = self.params["acc_step"]
        self.epoch = epoch
        self.train_step = 0

    def on_train_batch_end(self, step, logs=None):
        self.train_step += 1

        if self.by_step and self.train_step % self.acc_step == 0:
            if self.model._optimizer and \
                hasattr(self.model._optimizer, '_learning_rate') and \
                isinstance(self.model._optimizer._learning_rate,
                           paddle.optimizer.lr.LRScheduler):
                self.model._optimizer._learning_rate.step()


class History(Callback):

    def __init__(self):
        self.history = {}

    def on_train_begin(self, logs=None):
        self.epoch = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        self.model.history = self


class Profiler(Callback):

    def __init__(self, *args, **kwargs):
        self.prof = paddle.profiler.Profiler(*args, **kwargs)

    def on_epoch_begin(self, epoch=None, logs=None):
        self.epoch = epoch
        self.train_step = 0
        self.batch_size = self.params["batch_size"]
        self.steps = self.params['steps']

    def on_train_begin(self, logs=None):
        self.prof.start()

    def on_train_batch_end(self, step, logs=None):
        self.train_step += 1
        self.prof.step(num_samples=self.batch_size)
        print("step {}:{}".format(self.train_step,
                                  self.prof.step_info(unit='samples')))

    def on_train_end(self, logs=None):
        self.prof.stop()
        self.prof.summary()


class ModelCheckpointAuto(ModelCheckpoint):

    def __init__(self, *args, **kwargs):
        super(ModelCheckpointAuto, self).__init__(*args, **kwargs)

    def _is_save(self):
        return self.model and self.save_dir

    def on_epoch_end(self, epoch, logs=None):
        if self._is_save() and (self.epoch + 1) % self.save_freq == 0:
            path = '{}/epoch{}'.format(self.save_dir, epoch)
            print('save checkpoint at {}'.format(os.path.abspath(path)))
            self.model.save(path)

    def on_train_end(self, logs=None):
        if self._is_save():
            path = '{}/final'.format(self.save_dir)
            print('save checkpoint at {}'.format(os.path.abspath(path)))
            self.model.save(path)
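
For context, config_callbacks is the assembly point for the classes above: it fills in the *Auto defaults for anything the caller did not supply and wraps them in a CallbackList bound to the engine. A minimal sketch of calling it directly (the auto-parallel Engine normally does this internally; the concrete numbers below are illustrative only):

from paddle.distributed.auto_parallel.callbacks import config_callbacks

cbk_list = config_callbacks(batch_size=64,
                            epochs=2,
                            steps=100,
                            verbose=2,      # verbose=3 would also attach Profiler
                            save_dir='./ckpt',
                            mode='train')

# Inspect which callbacks were installed; this assumes hapi's CallbackList
# keeps them in a .callbacks attribute.
for cbk in cbk_list.callbacks:
    print(type(cbk).__name__)
# Expected order: LRSchedulerAuto, ProgBarLoggerAuto,
# ModelCheckpointAuto, History.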
86 changes: 83 additions & 3 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -19,7 +19,7 @@
from paddle.fluid import core
from paddle.fluid import framework

-from .utils import print_program_with_dist_attr, is_gradient_clip_op
+from .utils import is_gradient_clip_op, __not_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
@@ -142,6 +142,7 @@ class Completer:
    def __init__(self, dist_context):
        assert dist_context is not None
        self._dist_context = dist_context
        self._has_prepared = False

    def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
        changed = False
@@ -366,7 +367,14 @@ def _update_dims_mapping_between_graphs(self):
    def _update_dims_mapping_for_special(self):
        # Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it
        op_nodes = self._dist_context._serial_ordered_op_nodes
        # NOTE: this list may be changed if Paddle changes the existing rules.
        related_reader_ops = [
            "create_py_reader", "create_double_buffer_reader", "read"
        ]
        for op_node in op_nodes:
            if op_node.op() is not None \
                and op_node.op().type() in related_reader_ops:
                continue
            op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
            for tensor_node in op_node.outputs:
                if tensor_node.is_var() and tensor_node.var() is not None:
@@ -406,6 +414,7 @@ def _update_dims_mapping(self):
                reach_fix_point = False
            else:
                reach_fix_point = True
        # NOTE: this will be removed after changing the reshard rule
        self._update_dims_mapping_for_special()

    def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
@@ -494,14 +503,14 @@ def _find_nodes_related_to_cond(source_node):
                for tensor_node in node.inputs:
                    if tensor_node.is_var() and tensor_node.var(
                    ) is not None:
-                        if tensor_node.var().type() == core.VarDesc.VarType.READER \
+                        if tensor_node.var().type() in __not_shape_var_type__ \
                            or len(tensor_node.var().shape()) != 1:
                            flag = False
                            break
                for tensor_node in node.outputs:
                    if tensor_node.is_var() and tensor_node.var(
                    ) is not None:
-                        if tensor_node.var().type() == core.VarDesc.VarType.READER \
+                        if tensor_node.var().type() in __not_shape_var_type__ \
                            or len(tensor_node.var().shape()) != 1:
                            flag = False
                            break
@@ -719,6 +728,8 @@ def _update_process_mesh(self):
        self._update_process_mesh_between_graphs()

    def _prepare(self):
        if self._has_prepared:
            return
        self._while_op_nodes = {}
        self._array_nodes = {}
        self._node_pairs_between_graphs = []
@@ -732,6 +743,8 @@ def _prepare(self):
                if self._array_nodes.get(array_var_name, None) is None:
                    self._array_nodes[array_var_name] = []
                self._array_nodes[array_var_name].append(node)
                # Add the array input node
                self._array_nodes[array_var_name].append(node.inputs[0])
            if node.op().type() == "write_to_array":
                array_var_name = node.op().output("Out")[0]
                if self._array_nodes.get(array_var_name, None) is None:
@@ -752,6 +765,7 @@ def _prepare(self):
                        and after_node.var().name() == node.var().name():
                        self._node_pairs_between_graphs.append(
                            (after_node, node))
        self._has_prepared = True

    def complete_forward_annotation(self, serial_main_program=None):
        """ Complete annotation for the partial annotated serial_main_program.
@@ -899,6 +913,72 @@ def _update_dist_attr_for_dp(self):
            else:
                dist_op.dist_attr = original_op_dist_attr

    def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
        if serial_main_program is None:
            serial_main_program = self._dist_context.serial_main_program
        else:
            self._dist_context._serial_main_program = serial_main_program

        self._dist_context.initialize()

        self._prepare()

        has_set_dist_attr = set()

        all_nodes = self._dist_context.serial_ordered_nodes
        for node in all_nodes:
            if node.is_op():
                if node.op().type() in ["while"]:
                    continue
                dist_op = self._dist_context.get_dist_op_for_graph(node)
                op_dist_attr = dist_op.dist_attr
                for tensor_node in node.inputs:
                    if tensor_node.is_var() and tensor_node.var() is not None:
                        # Skip the non-leaf var node
                        if len(tensor_node.inputs) != 0:
                            continue
                        tensor_desc = tensor_node.var()
                        tensor_name = tensor_desc.name()
                        tensor = dist_op.get_serial_input(tensor_name)
                        # Use the first op to set the tensor dist attr
                        if tensor_name in has_set_dist_attr:
                            continue
                        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                        tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
                            tensor_name) if tensor.is_parameter else [
                                -1 for i in tensor_desc.shape()
                            ]
                        has_set_dist_attr.add(tensor_name)
                for tensor_node in node.outputs:
                    if tensor_node.is_var() and tensor_node.var() is not None:
                        tensor_name = tensor_node.var().name()
                        if tensor_name in has_set_dist_attr:
                            continue
                        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                        tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
                            tensor_name)
                        has_set_dist_attr.add(tensor_name)

        self._update_process_mesh_for_specials()

        self._update_process_mesh_between_graphs()

        self._update_dims_mapping_for_special()

        self._update_dims_mapping_between_graphs()

        # Copy the corresponding distributed attribute from graph to serial_main_program
        self._dist_context.copy_dist_attr_from_graph_to_program()

        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()

        self._dist_context.validate_dist_attr_for_program()

    def _complete_high_order_grad_annotation(self, serial_main_program=None):
        """
        NOTE:
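
The new _complete_tensor_dist_attr_by_op pass boils down to a "first op wins" rule: each tensor takes its process_mesh from the first op that touches it, parameter inputs inherit that op's input dims_mapping, other leaf inputs fall back to fully replicated ([-1, ...]), and outputs take the producing op's output dims_mapping. A toy, Paddle-free sketch of that rule (all names and the data layout are made up for illustration):

def complete_by_op(ops, tensors):
    """ops: list of {'mesh', 'inputs', 'outputs', 'dims_mapping'} dicts;
    tensors: name -> {'shape': tuple, 'is_parameter': bool}."""
    done = set()
    attrs = {}
    for op in ops:
        for name in op['inputs']:
            if name in done:
                continue  # the first op to see a tensor decides its attr
            t = tensors[name]
            mapping = (op['dims_mapping'][name] if t['is_parameter']
                       else [-1] * len(t['shape']))  # replicate non-parameters
            attrs[name] = {'mesh': op['mesh'], 'dims_mapping': mapping}
            done.add(name)
        for name in op['outputs']:
            if name not in done:
                attrs[name] = {'mesh': op['mesh'],
                               'dims_mapping': op['dims_mapping'][name]}
                done.add(name)
    return attrs


# A two-op toy graph: matmul(x, w) -> y, relu(y) -> z.
tensors = {'x': {'shape': (8, 4), 'is_parameter': False},
           'w': {'shape': (4, 4), 'is_parameter': True},
           'y': {'shape': (8, 4), 'is_parameter': False},
           'z': {'shape': (8, 4), 'is_parameter': False}}
ops = [{'mesh': [0, 1], 'inputs': ['x', 'w'], 'outputs': ['y'],
        'dims_mapping': {'x': [0, -1], 'w': [-1, 0], 'y': [0, -1]}},
       {'mesh': [0, 1], 'inputs': ['y'], 'outputs': ['z'],
        'dims_mapping': {'y': [0, -1], 'z': [0, -1]}}]
print(complete_by_op(ops, tensors))
# x -> replicated [-1, -1]; w -> [-1, 0] (parameter); y, z -> from producer op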
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -116,3 +116,10 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)

#########################################
# dataset configuration
#########################################
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
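
These two calls register defaults for a new "dataset" strategy section, following the same registry pattern as the categories above. A self-contained sketch of that pattern (the real constants.py keeps a comparable module-level registry; the dict name and the print are illustrative only):

from collections import defaultdict

# Illustrative re-implementation of the default-config registry; the real
# module's internals may differ.
_g_default_config = defaultdict(dict)

def set_field_default_config(category, field, default_value):
    _g_default_config[category][field] = default_value

DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)   # sharded dataset off by default
set_field_default_config(DATASET, "num_shards", 1)   # single shard unless overridden

print(dict(_g_default_config["dataset"]))
# {'enable': False, 'num_shards': 1}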