[AutoParallel] Recompute Pass (#38920)
* [AutoParallel] Recompute Pass

* update unittest

* reshard for amp

* add comment
zhaoyinglia authored Jan 18, 2022
1 parent 4aa91fd commit 3084573
Showing 18 changed files with 569 additions and 39 deletions.
22 changes: 22 additions & 0 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -822,6 +822,28 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
# TODO to add attribute for moment var
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if op.type == "clip_by_norm":

param_grad = vars[op.input("X")[0]]
param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param_grad)
assert param_grad_dist_attr is not None
ref_process_mesh = param_grad_dist_attr.process_mesh
ref_dims_mapping = param_grad_dist_attr.dims_mapping

out = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = ref_process_mesh
out_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(out,
out_dist_attr)

op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dist_attr(param_grad.name,
param_grad_dist_attr)
op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)

if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert len(op.input(
14 changes: 13 additions & 1 deletion python/paddle/distributed/auto_parallel/dist_attribute.py
@@ -21,7 +21,9 @@
"process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]

_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
_g_op_dist_attr_field_keys = [
"process_mesh", "impl_type", "impl_idx", "is_recompute"
]

_g_op_input_suffix = "@input"

@@ -178,6 +180,7 @@ def __init__(self):
self._inputs_dist_attrs = {}
self._outputs_dist_attrs = {}
self._is_annotated = {}
self._is_recompute = False

@property
def process_mesh(self):
@@ -214,6 +217,15 @@ def impl_idx(self, impl_idx):
if impl_idx is not None:
self._impl_idx = impl_idx

@property
def is_recompute(self):
return self._is_recompute

@is_recompute.setter
def is_recompute(self, is_recompute):
assert isinstance(is_recompute, bool)
self._is_recompute = is_recompute

@property
def inputs_dist_attrs(self):
return self._inputs_dist_attrs
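
A minimal usage sketch, not part of this diff, using only the accessors shown above: the flag defaults to False and the setter rejects non-bool values.

from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute

op_dist_attr = OperatorDistributedAttribute()
assert op_dist_attr.is_recompute is False   # default set in __init__
op_dist_attr.is_recompute = True            # setter asserts isinstance(..., bool)
assert op_dist_attr.is_recompute
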
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -166,6 +166,13 @@ def get_tensor_dist_attr_for_program(self, serial_tensor):
else:
return None

def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None

def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
self.add_dist_tensor_for_program(dist_tensor)
@@ -192,6 +199,13 @@ def get_op_dist_attr_for_program(self, serial_op):
else:
return None

def get_op_dist_attr_for_program_with_id(self, op_id):
dist_op = self._dist_ops_for_program.get(op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None

def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
self.add_dist_op_for_program(dist_op)
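
A hedged sketch of how a pass might use these id-based lookups; the helper name is illustrative, and keying by the op's desc id mirrors the existing get_op_dist_attr_for_program accessor (an assumption, not shown in this hunk):

def copy_dist_attr_to_cloned_op(dist_context, origin_op_id, cloned_op):
    # Recompute clones forward ops into the backward region; the clone can
    # inherit the dist attr of its source op, which is only known by id here.
    dist_attr = dist_context.get_op_dist_attr_for_program_with_id(origin_op_id)
    if dist_attr is not None:
        dist_context.set_op_dist_attr_for_program(cloned_op, dist_attr)
    return dist_attr
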
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/dist_op.py
@@ -99,6 +99,8 @@ def _init_default_dist_attr(self):
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
if self._dist_attr.is_recompute is None:
self._dist_attr.is_recompute = False

def _filter_dist_attr(self, dist_attr):
if dist_attr is None:
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -118,6 +118,8 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):


def is_parameter_related(varname, block):
if ".subprog_" in varname:
varname = varname[:varname.index(".subprog_")]
if ".cast_fp" in varname:
varname = varname[:varname.index(".cast_fp")]
assert block.has_var(varname)
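
A self-contained illustration of the name normalization above; the concrete suffixes (".subprog_<k>" appended by recompute, ".cast_fp16"/".cast_fp32" appended by amp) and the example variable names are assumptions:

def _strip_pass_suffixes(varname):
    # Strip the recompute and amp suffixes so the block lookup hits the original var.
    for marker in (".subprog_", ".cast_fp"):
        if marker in varname:
            varname = varname[:varname.index(marker)]
    return varname

assert _strip_pass_suffixes("linear_0.w_0.subprog_1") == "linear_0.w_0"
assert _strip_pass_suffixes("linear_0.w_0.cast_fp16") == "linear_0.w_0"
assert _strip_pass_suffixes("linear_0.w_0.cast_fp16.subprog_0") == "linear_0.w_0"
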
(another changed file)
@@ -216,6 +216,12 @@ def backward(ctx, *args, **kwargs):
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: When the amp and recompute passes take effect at the same time,
# if a parameter is both cast and recomputed, 'parameter@GRAD' cannot
# be found in the grad_op's output.
if "subprog_" in varname:
varname = varname[:varname.index(".subprog_")]

assert len(
backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
(another changed file)
@@ -283,7 +283,7 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# param initialization sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
(another changed file)
@@ -680,7 +680,7 @@ def forward(ctx, *args, **kwargs):
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

@@ -968,7 +968,7 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

@@ -1383,7 +1383,7 @@ def forward(ctx, *args, **kwargs):
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

@@ -1666,7 +1666,7 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

(another changed file)
@@ -83,9 +83,9 @@ def backward(ctx, *args, **kwargs):
assert 'Out' in kwargs, "output [{}] is not given".format('Out')
assert 'LossScaling' in kwargs, "output [{}] is not given".format(
'LossScaling')
assert 'OutGoodSteps' in kwargs, "input [{}] is not given".format(
assert 'OutGoodSteps' in kwargs, "output [{}] is not given".format(
'OutGoodSteps')
assert 'OutBadSteps' in kwargs, "input [{}] is not given".format(
assert 'OutBadSteps' in kwargs, "output [{}] is not given".format(
'OutBadSteps')

assert len(kwargs['FoundInfinite']) == 1, \
25 changes: 14 additions & 11 deletions python/paddle/distributed/auto_parallel/parallelizer.py
@@ -97,8 +97,8 @@ def _remove_distributed_attrs(self, main_program):
if suffix in attr_name:
op._remove_attr(attr_name)

def _apply_pre_optimization_passed(self, main_program, startup_program,
loss, params_grads):
def _apply_pre_optimization_passes(self, main_program, startup_program,
loss, params_grads, no_grad_set):
# apply amp pass
if self._dist_strategy.amp:
config = copy.deepcopy(self._dist_strategy.amp_configs)
@@ -111,11 +111,14 @@ def _apply_pre_optimization_passed(self, main_program, startup_program,

# apply recompute pass
if self._dist_strategy.recompute:
auto_parallel_recompute_pass = new_pass(
"auto_parallel_recompute_pass",
self._dist_strategy.recompute_configs)
auto_parallel_recompute_pass.apply(main_program, startup_program,
self._pass_context)
config = copy.deepcopy(self._dist_strategy.recompute_configs)
config["dist_context"] = self._dist_context
config["no_grad_set"] = copy.deepcopy(no_grad_set)
config["loss"] = loss
auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
config)
auto_parallel_recompute_pass.apply(
[main_program], [startup_program], self._pass_context)

def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks):
@@ -144,7 +147,7 @@ def _apply_optimize(self, main_program, startup_program, params_grads):

return optimize_ops

def _apply_post_optimization_passed(self, main_program, startup_program,
def _apply_post_optimization_passes(self, main_program, startup_program,
rank, params_grads):

if self._dist_strategy.sharding:
@@ -188,9 +191,9 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
self._parameter_list, self._no_grad_set, self._callbacks)

# serial forward pass
self._apply_pre_optimization_passed(completed_main_program,
self._apply_pre_optimization_passes(completed_main_program,
serial_startup_program, serial_loss,
params_grads)
params_grads, self._no_grad_set)
# Logical partition
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
@@ -207,7 +210,7 @@

reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)

self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
g_process_group_map = None
if not relaunch_phase:
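
For orientation, a minimal user-side sketch (not part of this commit) of the strategy flags that make _apply_pre_optimization_passes build and apply the new recompute pass; the semi_auto entry point and the checkpoint names are assumptions drawn from the auto-parallel workflow, not from this diff:

import paddle.distributed.fleet as fleet

dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True      # route execution through AutoParallelizer
dist_strategy.recompute = True      # enables the auto_parallel_recompute pass
dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]}  # placeholder names
fleet.init(is_collective=True, strategy=dist_strategy)
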
11 changes: 6 additions & 5 deletions python/paddle/distributed/auto_parallel/partitioner.py
@@ -25,7 +25,7 @@
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
from .operators.common import BACKWARD_ONLY_DIST_OPS

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
@@ -200,7 +200,8 @@ def partition_main_program(self, serial_main_program, params_and_grads):
serial_output_varname] = new_varname

# partition op
if is_forward_op(op):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if is_forward_op(op) or op_dist_attr.is_recompute:
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_forward_impl = _get_dist_op_forward_implement(
op, self._dist_context)
@@ -380,9 +381,9 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
# NOTE: trick for dist ops that only have a backward implementation
if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
assert op_dist_attr.impl_idx >= 0
return get_distributed_operator_impl_container(
backward_op.type).get_impl(op_dist_attr.impl_idx)
dist_op = get_distributed_operator_impl_container(backward_op.type)
if dist_op and op_dist_attr.impl_idx >= 0:
return dist_op.get_impl(op_dist_attr.impl_idx)

dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/reshard.py
@@ -26,6 +26,9 @@
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process_group import new_process_group, ProcessGroup, _g_process_group_map

# NOTE: If an op's type is in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']


class AllGatherOpDesc:
"""
@@ -966,6 +969,17 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]

def _is_special_op(op):
global _g_special_ops
if op.type in _g_special_ops:
return True
return False

if _is_special_op(op):
idx += 1
continue

dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None:
idx_offset = 0
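
A short, self-contained sketch of the skip rule above; the helper here takes an op type string instead of an op object so it can run stand-alone:

_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']

def _is_special_op_type(op_type):
    # Loss-scaling ops inserted by the amp pass are left untouched by reshard.
    return op_type in _g_special_ops

assert _is_special_op_type('update_loss_scaling')
assert not _is_special_op_type('matmul_v2')
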
9 changes: 2 additions & 7 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -1005,8 +1005,8 @@ def set_grad_var_shape(program, dist_context):
assert op_dist_attr is not None

for var_name in op.output_arg_names:

assert "@GRAD" in var_name
if "@GRAD" not in var_name:
continue
forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast"
@@ -1076,11 +1076,6 @@ def is_backward_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)


def is_recompute_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == 9


def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
1 change: 1 addition & 0 deletions python/paddle/distributed/passes/__init__.py
@@ -17,6 +17,7 @@
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_recompute import *
from .cpp_pass import *

__all__ = [
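
With the module registered here, the pass becomes constructible by name; a hedged sketch of the call pattern, mirroring the parallelizer hunk above (dist_context, loss, and the programs are placeholders, and PassContext is assumed to be exported alongside new_pass):

from paddle.distributed.passes import new_pass, PassContext

config = {
    "checkpoints": None,           # or the list from recompute_configs
    "dist_context": dist_context,  # the auto-parallel DistributedContext
    "no_grad_set": set(),
    "loss": loss,                  # serial loss variable
}
recompute_pass = new_pass("auto_parallel_recompute", config)
recompute_pass.apply([main_program], [startup_program], PassContext())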