Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoParallel] Recompute Pass #38920

Merged
merged 5 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions python/paddle/distributed/auto_parallel/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,28 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
# TODO to add attribute for moment var
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if op.type == "clip_by_norm":

param_grad = vars[op.input("X")[0]]
param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param_grad)
assert param_grad_dist_attr is not None
ref_process_mesh = param_grad_dist_attr.process_mesh
ref_dims_mapping = param_grad_dist_attr.dims_mapping

out = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = ref_process_mesh
out_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(out,
out_dist_attr)

op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dist_attr(param_grad.name,
param_grad_dist_attr)
op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)

if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert len(op.input(
Expand Down
14 changes: 13 additions & 1 deletion python/paddle/distributed/auto_parallel/dist_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
"process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]

_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
_g_op_dist_attr_field_keys = [
"process_mesh", "impl_type", "impl_idx", "is_recompute"
]

_g_op_input_suffix = "@input"

Expand Down Expand Up @@ -178,6 +180,7 @@ def __init__(self):
self._inputs_dist_attrs = {}
self._outputs_dist_attrs = {}
self._is_annotated = {}
self._is_recompute = False

@property
def process_mesh(self):
Expand Down Expand Up @@ -214,6 +217,15 @@ def impl_idx(self, impl_idx):
if impl_idx is not None:
self._impl_idx = impl_idx

@property
def is_recompute(self):
return self._is_recompute

@is_recompute.setter
def is_recompute(self, is_recompute):
assert isinstance(is_recompute, bool)
self._is_recompute = is_recompute

@property
def inputs_dist_attrs(self):
return self._inputs_dist_attrs
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def get_tensor_dist_attr_for_program(self, serial_tensor):
else:
return None

def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None

def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
self.add_dist_tensor_for_program(dist_tensor)
Expand All @@ -192,6 +199,13 @@ def get_op_dist_attr_for_program(self, serial_op):
else:
return None

def get_op_dist_attr_for_program_with_id(self, op_id):
dist_op = self._dist_ops_for_program.get(op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None

def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
self.add_dist_op_for_program(dist_op)
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/dist_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def _init_default_dist_attr(self):
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
if self._dist_attr.is_recompute is None:
self._dist_attr.is_recompute = False

def _filter_dist_attr(self, dist_attr):
if dist_attr is None:
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):


def is_parameter_related(varname, block):
if ".subprog_" in varname:
varname = varname[:varname.index(".subprog_")]
if ".cast_fp" in varname:
varname = varname[:varname.index(".cast_fp")]
assert block.has_var(varname)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ def backward(ctx, *args, **kwargs):
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: When amp and recompute pass are effective at the same time,
# if a parameter is casted and recomputed, the 'parameter@GARD' can not
# be found in the grad_op's output.
if "subprog_" in varname:
varname = varname[:varname.index(".subprog_")]

assert len(
backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# param initialization sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def forward(ctx, *args, **kwargs):
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

Expand Down Expand Up @@ -968,7 +968,7 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

Expand Down Expand Up @@ -1383,7 +1383,7 @@ def forward(ctx, *args, **kwargs):
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

Expand Down Expand Up @@ -1666,7 +1666,7 @@ def forward(ctx, *args, **kwargs):
allreduce_op_dist_attr)

# init param sync
if Weight_var.is_parameter:
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def backward(ctx, *args, **kwargs):
assert 'Out' in kwargs, "output [{}] is not given".format('Out')
assert 'LossScaling' in kwargs, "output [{}] is not given".format(
'LossScaling')
assert 'OutGoodSteps' in kwargs, "input [{}] is not given".format(
assert 'OutGoodSteps' in kwargs, "output [{}] is not given".format(
'OutGoodSteps')
assert 'OutBadSteps' in kwargs, "input [{}] is not given".format(
assert 'OutBadSteps' in kwargs, "output [{}] is not given".format(
'OutBadSteps')

assert len(kwargs['FoundInfinite']) == 1, \
Expand Down
17 changes: 10 additions & 7 deletions python/paddle/distributed/auto_parallel/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _remove_distributed_attrs(self, main_program):
op._remove_attr(attr_name)

def _apply_pre_optimization_passed(self, main_program, startup_program,
loss, params_grads):
loss, params_grads, no_grad_set):
# apply amp pass
if self._dist_strategy.amp:
config = copy.deepcopy(self._dist_strategy.amp_configs)
Expand All @@ -111,11 +111,14 @@ def _apply_pre_optimization_passed(self, main_program, startup_program,

# apply recompute pass
if self._dist_strategy.recompute:
auto_parallel_recompute_pass = new_pass(
"auto_parallel_recompute_pass",
self._dist_strategy.recompute_configs)
auto_parallel_recompute_pass.apply(main_program, startup_program,
self._pass_context)
config = copy.deepcopy(self._dist_strategy.recompute_configs)
config["dist_context"] = self._dist_context
config["no_grad_set"] = copy.deepcopy(no_grad_set)
config["loss"] = loss
auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
config)
auto_parallel_recompute_pass.apply(
[main_program], [startup_program], self._pass_context)

def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks):
Expand Down Expand Up @@ -190,7 +193,7 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
# serial forward pass
self._apply_pre_optimization_passed(completed_main_program,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename _apply_pre_optimization_passed to _apply_pre_optimization_passes and _apply_post_optimization_passed to _apply_post_optimization_passes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

serial_startup_program, serial_loss,
params_grads)
params_grads, self._no_grad_set)
# Logical partition
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
Expand Down
11 changes: 6 additions & 5 deletions python/paddle/distributed/auto_parallel/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
from .operators.common import BACKWARD_ONLY_DIST_OPS

__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
Expand Down Expand Up @@ -200,7 +200,8 @@ def partition_main_program(self, serial_main_program, params_and_grads):
serial_output_varname] = new_varname

# partition op
if is_forward_op(op):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if is_forward_op(op) or op_dist_attr.is_recompute:
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_forward_impl = _get_dist_op_forward_implement(
op, self._dist_context)
Expand Down Expand Up @@ -380,9 +381,9 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
# NOTE trick for dist ops that only have backward implement
if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
assert op_dist_attr.impl_idx >= 0
return get_distributed_operator_impl_container(
backward_op.type).get_impl(op_dist_attr.impl_idx)
dist_op = get_distributed_operator_impl_container(backward_op.type)
if dist_op and op_dist_attr.impl_idx >= 0:
return dist_op.get_impl(op_dist_attr.impl_idx)

dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process_group import new_process_group, ProcessGroup, _g_process_group_map

# NOTE: If op in SPECIAL_OPS, it will not be resharded.
SPECIAL_OPS = ['check_finite_and_unscale', 'update_loss_scaling']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The global variable should use _g_xxxx. Please rename SPECIAL_OPS to _g_special_ops.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.



class AllGatherOpDesc:
"""
Expand Down Expand Up @@ -966,6 +969,17 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]

def _is_special_op(op):
global SPECIAL_OPS
if op.type in SPECIAL_OPS:
return True
return False

if _is_special_op(op):
idx += 1
continue

dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None:
idx_offset = 0
Expand Down
9 changes: 2 additions & 7 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,8 +1005,8 @@ def set_grad_var_shape(program, dist_context):
assert op_dist_attr is not None

for var_name in op.output_arg_names:

assert "@GRAD" in var_name
if "@GRAD" not in var_name:
continue
forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast"
Expand Down Expand Up @@ -1076,11 +1076,6 @@ def is_backward_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)


def is_recompute_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == 9


def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
Expand Down
1 change: 1 addition & 0 deletions python/paddle/distributed/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_recompute import *
from .cpp_pass import *

__all__ = [
Expand Down
Loading