Skip to content

Commit

Permalink
[Cherry-Pick][AutoParallel] auto_parallel cherry-pick to release2.4 (#…
Browse files Browse the repository at this point in the history
…47145)

* [Auto Parallel] Make Engine class callable (#46416)

* [Auto Parallel] Improve the user-defined fetches and logging

* [Auto Parallel] Make Engine class callable

* [Auto Parallel] Update the data loading of tuner

* Print IPS in auto parallel Engine (#46554)

* [AutoParallel] fix dist_split (#46505)

* [AutoParallel] fix dist_split

* add unittest

* update cmakelist

* [AutoParallel] fix sharding (#46572)

* [AutoParallel] fix process_mesh (#46583)

* [AutoParallel] fix reshard when train with eval (#46605)

* [AutoParallel] fix reshard when train with eval

* fix mppp

* [AutoParallel] fix amp when predict (#46637)

* [Auto Parallel]Update comp cost and completion for gpt auto search (#46387)

* update comp cost and completion for gpt auto search

* add unittest

* [Auto Parallel] Fix bugs caused by the inconsistent outputs of Engine API (#46633)

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Improve the fine-grained APIs (#46552)

* [Auto Parallel] Support different dataloaders

* [Auto Parallel] Add num_shards config for dataset

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Add the prepare API and replace __call__ with run

* [Auto Parallel] Improve the private implementations of Engine

* [Auto Parallel] Set capacity of dataloader for opt tuning

* [Auto Parallel] [WIP] Change the fine-grained API

* [Auto Parallel] Improve APIs to support different user cases

* [Auto Parallel] Add removed config

* [Auto Parallel] Add imports

* [Auto Parallel] Fix bugs for to_static

* [Auto Parallel] Remove unnecessary imports

* bugfix (#46921)

* [Auto Parallel] Fix the bug for None labels (#46987)

* [AutoParallel] adapt for gpt-gen (#46771)

* for gpt-gen

* fix reshard

* adapt assign and shape op

* add dist_assign & unittest

* add conditional block unittest

* rename unittest

* [Auto Parallel] Fix the bug of completion (#47056)

* [Auto Parallel] Fix the bug for None labels

* [Auto Parallel] Fix the completion bug

* [AutoParallel] add callbacks (#47014)

* [AutoParallel] add callbacks

* fix unittest

* fix dist_context

* fix engine

* fix cmakelist

* fix unittest's returns

* fix cmakelist

* [Auto Parallel] Add cost interface (#47043)

* add cost interface

* update interface and add unittest

* update unittest

* update interface

* [Auto Parallel]Add parallel tuner (#46189)

* add parallel tuner

* add unittest

* fix unittest

* set timeout of unittest

* set unittest timeout

* fix auto_mode setting

* update unittest

* sync from develop and update unittest

* remove unused import

* update unittest

* update cmakelist

* add unittests

Co-authored-by: Yulong Ao <aoyulong@baidu.com>
Co-authored-by: Ruibiao Chen <chenruibiao@baidu.com>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
  • Loading branch information
5 people authored Oct 19, 2022
1 parent 23f2a4e commit 90b3179
Show file tree
Hide file tree
Showing 57 changed files with 4,239 additions and 606 deletions.
226 changes: 226 additions & 0 deletions python/paddle/distributed/auto_parallel/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import paddle
from paddle.hapi.callbacks import ProgBarLogger, ModelCheckpoint, LRScheduler, CallbackList, Callback
from .interface import CollectionNames, get_collection


def config_callbacks(callbacks=None,
                     engine=None,
                     batch_size=None,
                     epochs=None,
                     steps=None,
                     log_freq=2,
                     verbose=2,
                     save_freq=1,
                     save_dir=None,
                     metrics=None,
                     acc_step=1,
                     mode='train'):
    """Assemble the ``CallbackList`` used by the auto-parallel Engine.

    User-provided callbacks are kept; missing standard callbacks
    (progress logger, LR scheduler, checkpointer, profiler, history)
    are filled in with their auto-parallel variants, and any plain hapi
    callbacks are upgraded to the corresponding ``*Auto`` subclasses.
    """
    cbks = callbacks or []
    if not isinstance(cbks, (list, tuple)):
        cbks = [cbks]

    def _contains(cls):
        # Checks the *current* binding of cbks each time it is called.
        return any(isinstance(cb, cls) for cb in cbks)

    if verbose and not _contains(ProgBarLogger):
        cbks = [ProgBarLoggerAuto(log_freq, verbose=verbose)] + cbks

    if not _contains(LRScheduler):
        # Prepended after the logger so it ends up first in the list.
        cbks = [LRSchedulerAuto()] + cbks

    if not _contains(ModelCheckpoint):
        cbks = cbks + [ModelCheckpointAuto(save_freq, save_dir)]

    if verbose == 3 and not _contains(Profiler):
        # The profiler is only attached at the most verbose level.
        cbks = cbks + [Profiler(timer_only=True)]

    if not _contains(History):
        cbks = cbks + [History()]

    # Upgrade plain hapi callbacks to their auto-parallel counterparts,
    # preserving the user's configuration.
    for idx, cb in enumerate(cbks):
        if isinstance(cb, ProgBarLogger):
            cbks[idx] = ProgBarLoggerAuto(cb.log_freq, cb.verbose)
        if isinstance(cb, LRScheduler):
            cbks[idx] = LRSchedulerAuto(cb.by_step, cb.by_epoch)
        if isinstance(cb, ModelCheckpoint):
            cbks[idx] = ModelCheckpointAuto(cb.save_freq, cb.save_dir)

    cbk_list = CallbackList(cbks)
    cbk_list.set_model(engine)
    # Metrics are irrelevant in test mode.
    if mode == 'test':
        metrics = []
    else:
        metrics = metrics or []
    params = {
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps,
        'verbose': verbose,
        'metrics': metrics,
        'acc_step': acc_step,
    }
    cbk_list.set_params(params)
    return cbk_list


class ProgBarLoggerAuto(ProgBarLogger):
    """Progress-bar logger adapted for the auto-parallel Engine.

    Differences from the base ``ProgBarLogger``:
      * printing is unconditional (``_is_print`` always returns True;
        rank filtering is presumably handled elsewhere — see Engine);
      * user-defined fetches from the LOGGING collection and extra
        ``outputs`` entries are appended to the printed values;
      * at ``verbose == 3`` timing statistics (avg reader/batch cost and
        IPS) are printed and the timers are reset afterwards.
    """

    def __init__(self, log_freq=1, verbose=2):
        super(ProgBarLoggerAuto, self).__init__(log_freq, verbose)

    def _is_print(self):
        # Always print, unlike the base class which restricts by rank.
        return True

    def _updates(self, logs, mode):
        """Push the values for `mode` ('train'/'eval'/'test') to the progbar."""
        values = []
        metrics = getattr(self, '%s_metrics' % (mode))
        progbar = getattr(self, '%s_progbar' % (mode))
        steps = getattr(self, '%s_step' % (mode))

        for k in metrics:
            if k in logs:
                values.append((k, logs[k]))

        if 'lr' in logs:
            values.append(('lr', logs['lr']))

        # Append any user-registered logging fetches that were evaluated.
        fetches_logs = logs.get('fetches', {})
        collect_logging = get_collection(CollectionNames.LOGGING)
        for name, var in collect_logging:
            k = name or var.name
            if k in fetches_logs:
                values.append((k, fetches_logs[k]))

        out_logs = logs.get('outputs', {})
        for k in out_logs:
            values.append((k, out_logs[k]))

        if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
            timer = getattr(self, '_%s_timer' % (mode))
            cnt = timer['count'] if timer['count'] > 0 else 1.0
            samples = timer['samples'] if timer['samples'] > 0 else 1.0
            # Guard against a zero total time (e.g. the timers were just
            # reset), which previously raised ZeroDivisionError.
            total_time = timer['data_time'] + timer['batch_time']
            values.append(
                ('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt)))
            values.append(
                ('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt)))
            values.append(
                ('ips', "%.5f samples/sec" %
                 (samples / total_time if total_time > 0 else 0.0)))
            # Reset the timers so the next window starts fresh.
            timer['count'] = 0
            timer['samples'] = 0
            timer['data_time'] = 0.
            timer['batch_time'] = 0.

        progbar.update(steps, values)

    def on_eval_batch_end(self, step, logs=None):
        logs = logs or {}
        self.eval_step += 1
        samples = self.params['batch_size']
        self.evaled_samples += samples

        self._eval_timer['batch_time'] += (
            time.time() - self._eval_timer['batch_data_end_time'])
        self._eval_timer['count'] += 1
        # Reuse `samples` computed above (previously re-read redundantly).
        self._eval_timer['samples'] += samples

        if self._is_print() and self.eval_step % self.log_freq == 0:
            if self.eval_steps is None or self.eval_step < self.eval_steps:
                self._updates(logs, 'eval')

        self._eval_timer['batch_start_time'] = time.time()


class LRSchedulerAuto(LRScheduler):
    """LR scheduler callback that respects gradient accumulation.

    When stepping by batch, the learning-rate scheduler is only advanced
    once every ``acc_step`` micro-batches, i.e. on real optimizer steps.
    """

    def __init__(self, by_step=True, by_epoch=False):
        super(LRSchedulerAuto, self).__init__(by_step, by_epoch)

    def on_epoch_begin(self, epoch=None, logs=None):
        # Refresh the accumulation factor and reset the per-epoch counter.
        self.acc_step = self.params["acc_step"]
        self.epoch = epoch
        self.train_step = 0

    def on_train_batch_end(self, step, logs=None):
        self.train_step += 1

        # Only advance the scheduler on accumulation boundaries.
        if not (self.by_step and self.train_step % self.acc_step == 0):
            return
        optimizer = self.model._optimizer
        if not optimizer:
            return
        lr = getattr(optimizer, '_learning_rate', None)
        if isinstance(lr, paddle.optimizer.lr.LRScheduler):
            lr.step()


class History(Callback):
    """Records the per-epoch ``logs`` into ``self.history``.

    After training, ``history`` maps each logged key to the list of its
    values across epochs, and the callback attaches itself to the engine
    as ``model.history`` so users can retrieve it.
    """

    def __init__(self):
        # Initialize base Callback state (model/params) before adding our
        # own attributes; the original skipped this, leaving the base
        # attributes unset until set_model()/set_params() was called.
        super(History, self).__init__()
        self.history = {}

    def on_train_begin(self, logs=None):
        self.epoch = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch.append(epoch)
        # Accumulate each logged value under its key, one entry per epoch.
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        # Expose the history on the engine for post-training inspection.
        self.model.history = self


class Profiler(Callback):
    """Wraps ``paddle.profiler.Profiler`` as a training callback.

    Starts profiling at the beginning of training, records one profiler
    step per batch (printing per-step info), and stops/summarizes at the
    end of training. Extra args are forwarded to the profiler.
    """

    def __init__(self, *args, **kwargs):
        # Initialize base Callback state; the original omitted this call.
        super(Profiler, self).__init__()
        self.prof = paddle.profiler.Profiler(*args, **kwargs)

    def on_epoch_begin(self, epoch=None, logs=None):
        self.epoch = epoch
        self.train_step = 0
        self.batch_size = self.params["batch_size"]
        self.steps = self.params['steps']

    def on_train_begin(self, logs=None):
        self.prof.start()

    def on_train_batch_end(self, step, logs=None):
        self.train_step += 1
        # Advance the profiler, accounting for the samples in this batch.
        self.prof.step(num_samples=self.batch_size)
        print("step {}:{}".format(self.train_step,
                                  self.prof.step_info(unit='samples')))

    def on_train_end(self, logs=None):
        self.prof.stop()
        self.prof.summary()


class ModelCheckpointAuto(ModelCheckpoint):
    """Checkpoint callback for the auto-parallel Engine.

    Saves via ``engine.save()`` every ``save_freq`` epochs (under
    ``<save_dir>/epoch<N>``) and once at the end of training (under
    ``<save_dir>/final``).
    """

    def __init__(self, *args, **kwargs):
        super(ModelCheckpointAuto, self).__init__(*args, **kwargs)

    def _is_save(self):
        # Saving needs both an attached engine and a target directory.
        return self.model and self.save_dir

    def _save_to(self, path):
        # Shared save routine: log the absolute location, then delegate
        # to the engine's save().
        print('save checkpoint at {}'.format(os.path.abspath(path)))
        self.model.save(path)

    def on_epoch_end(self, epoch, logs=None):
        if not self._is_save():
            return
        if (self.epoch + 1) % self.save_freq != 0:
            return
        self._save_to('{}/epoch{}'.format(self.save_dir, epoch))

    def on_train_end(self, logs=None):
        if self._is_save():
            self._save_to('{}/final'.format(self.save_dir))
86 changes: 83 additions & 3 deletions python/paddle/distributed/auto_parallel/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from paddle.fluid import core
from paddle.fluid import framework

from .utils import print_program_with_dist_attr, is_gradient_clip_op
from .utils import is_gradient_clip_op, __not_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
Expand Down Expand Up @@ -142,6 +142,7 @@ class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
self._has_prepared = False

def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
Expand Down Expand Up @@ -366,7 +367,14 @@ def _update_dims_mapping_between_graphs(self):
def _update_dims_mapping_for_special(self):
# Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it
op_nodes = self._dist_context._serial_ordered_op_nodes
# NOTE: this list may be changed if Paddle changes the existing rules.
related_reader_ops = [
"create_py_reader", "create_double_buffer_reader", "read"
]
for op_node in op_nodes:
if op_node.op() is not None \
and op_node.op().type() in related_reader_ops:
continue
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
Expand Down Expand Up @@ -406,6 +414,7 @@ def _update_dims_mapping(self):
reach_fix_point = False
else:
reach_fix_point = True
# NOTE: this will be removed after changing the reshard rule
self._update_dims_mapping_for_special()

def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
Expand Down Expand Up @@ -494,14 +503,14 @@ def _find_nodes_related_to_cond(source_node):
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
Expand Down Expand Up @@ -719,6 +728,8 @@ def _update_process_mesh(self):
self._update_process_mesh_between_graphs()

def _prepare(self):
if self._has_prepared:
return
self._while_op_nodes = {}
self._array_nodes = {}
self._node_pairs_between_graphs = []
Expand All @@ -732,6 +743,8 @@ def _prepare(self):
if self._array_nodes.get(array_var_name, None) is None:
self._array_nodes[array_var_name] = []
self._array_nodes[array_var_name].append(node)
# Add the array input node
self._array_nodes[array_var_name].append(node.inputs[0])
if node.op().type() == "write_to_array":
array_var_name = node.op().output("Out")[0]
if self._array_nodes.get(array_var_name, None) is None:
Expand All @@ -752,6 +765,7 @@ def _prepare(self):
and after_node.var().name() == node.var().name():
self._node_pairs_between_graphs.append(
(after_node, node))
self._has_prepared = True

def complete_forward_annotation(self, serial_main_program=None):
""" Complete annotation for the partial annotated serial_main_program.
Expand Down Expand Up @@ -899,6 +913,72 @@ def _update_dist_attr_for_dp(self):
else:
dist_op.dist_attr = original_op_dist_attr

def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context._serial_main_program = serial_main_program

self._dist_context.initialize()

self._prepare()

has_set_dist_attr = set()

all_nodes = self._dist_context.serial_ordered_nodes
for node in all_nodes:
if node.is_op():
if node.op().type() in ["while"]:
continue
dist_op = self._dist_context.get_dist_op_for_graph(node)
op_dist_attr = dist_op.dist_attr
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var() is not None:
# Skip the non-leaf var node
if len(tensor_node.inputs) != 0:
continue
tensor_desc = tensor_node.var()
tensor_name = tensor_desc.name()
tensor = dist_op.get_serial_input(tensor_name)
# Use the first op to set the tensor dist attr
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name) if tensor.is_parameter else [
-1 for i in tensor_desc.shape()
]
has_set_dist_attr.add(tensor_name)
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
tensor_name = tensor_node.var().name()
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
has_set_dist_attr.add(tensor_name)

self._update_process_mesh_for_specials()

self._update_process_mesh_between_graphs()

self._update_dims_mapping_for_special()

self._update_dims_mapping_between_graphs()

# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()

# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()

self._dist_context.validate_dist_attr_for_program()

def _complete_high_order_grad_annotation(self, serial_main_program=None):
"""
NOTE:
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,10 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)

#########################################
# dataset configuration
#########################################
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
Loading

0 comments on commit 90b3179

Please sign in to comment.