
Commit

[Auto Parallel] fix enable_delay_scale_loss for static auto parallel && fix sharding degree (PaddlePaddle#68525)
zhangyuqin1998 committed Oct 28, 2024
1 parent f210620 commit f668f2b
Showing 2 changed files with 109 additions and 18 deletions.
22 changes: 20 additions & 2 deletions python/paddle/distributed/auto_parallel/api.py
@@ -31,6 +31,7 @@
EagerParamBase,
Variable,
default_main_program,
in_dygraph_mode,
in_pir_mode,
use_pir_api,
)
@@ -1000,7 +1001,7 @@ def get_placement_with_sharding(param, sharding_mesh_axis):


class _ShardOptimizer(Optimizer):
def __init__(self, optimizer, shard_fn=None):
def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
assert (
optimizer is not None
), "The argument `optimizer` cannot be empty."
@@ -1025,6 +1026,7 @@ def __init__(self, optimizer, shard_fn=None):
self._shard_fn = shard_fn
self._sharding_mesh_axis = None
self._sharding_degree = None
self.gradient_accumulation_steps = gradient_accumulation_steps

if isinstance(
self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
@@ -1246,6 +1248,21 @@ def state_dict(self):
return self._inner_opt.state_dict()

def _append_optimize_op(self, block, param_and_grad):
if (
in_auto_parallel_align_mode() # In align mode, we use enable_delay_scale_loss by default
and in_dygraph_mode()
and param_and_grad[1].is_dist()
):
placements = param_and_grad[1].placements
meshs = param_and_grad[1].process_mesh
grad = param_and_grad[1]

for i in range(len(placements) - 1, -1, -1):
if isinstance(placements[i], dist.Partial):
placements[i] = dist.Replicate()
grad = dist.reshard(grad, meshs, placements)
grad /= self.gradient_accumulation_steps
param_and_grad = (param_and_grad[0], grad)
return self._inner_opt._append_optimize_op(block, param_and_grad)

def __getattr__(self, item):
@@ -1596,6 +1613,7 @@ def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor:
def shard_optimizer(
optimizer: Optimizer,
shard_fn: Callable[[str, Tensor, Tensor], Tensor] | None = None,
gradient_accumulation_steps: int = 1,
) -> _ShardOptimizer:
"""
@@ -1640,7 +1658,7 @@ def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""
return _ShardOptimizer(optimizer, shard_fn)
return _ShardOptimizer(optimizer, shard_fn, gradient_accumulation_steps)


def shard_scaler(scaler: GradScaler) -> GradScaler:
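For context, a minimal usage sketch of the new `gradient_accumulation_steps` argument follows. It is not part of the commit: the toy model, the AdamW choice, and the accumulation depth of 4 are illustrative assumptions, and as in the docstring example above it is meant to be launched on multiple GPUs.

import paddle
import paddle.distributed as dist

layer = paddle.nn.Linear(8, 8)  # toy model, illustrative only
opt = paddle.optimizer.AdamW(parameters=layer.parameters())

# gradient_accumulation_steps=4 is an assumed value; the new argument defaults to 1.
# Per this commit, in align mode the wrapped optimizer divides each gradient by
# gradient_accumulation_steps before the parameter update.
opt = dist.shard_optimizer(opt, shard_fn=None, gradient_accumulation_steps=4)

for _ in range(4):  # accumulate gradients over 4 micro-batches
    loss = layer(paddle.rand([8, 8])).mean()
    loss.backward()
opt.step()          # single optimizer update after accumulation
opt.clear_grad()
# Run with: python -m paddle.distributed.launch --gpus=0,1 {test_case}.py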
105 changes: 89 additions & 16 deletions python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -636,6 +636,94 @@ def parse_program(
return grad_to_gradient_merge


def _find_trival_optimizer_ops(block):
optimizer_ops = []
for op in block.ops:
if "adam" in op.name() or "sgd" in op.name():
optimizer_ops.append(op)
return optimizer_ops


def _get_prev_op(block, optimizer_op):
found = False
for op in reversed(block.ops):
if found:
return op
if op.id == optimizer_op.id:
found = True
return None


def _insert_scale_op_after(target_value, optimizer_op, scale, bias=0.0):
scaled_grad = paddle._C_ops.scale_(target_value, scale, bias, False)

scale_op = scaled_grad.get_defining_op()
scale_op.op_role = int(OpRole.Optimize)

full_op = scale_op.operand_source(1).get_defining_op()
assert (
full_op.name() == "pd_op.full"
), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
full_op.op_role = int(OpRole.Optimize)

if "adam" in optimizer_op.name():
optimizer_op.operand(1).set_source(scaled_grad)
elif "sgd" in optimizer_op.name():
optimizer_op.operand(2).set_source(scaled_grad)


def _append_scale_op_before_comm(block, new_params_to_grads, k_steps):
for op in reversed(block.ops):
if op.op_role == int(OpRole.Backward):
paddle.pir.set_insertion_point_after(op)
break
for _, new_grad in new_params_to_grads:
new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)

scale_op = new_grad.get_defining_op()
scale_op.op_role = int(OpRole.Optimize)

full_op = scale_op.operand_source(1).get_defining_op()
assert (
full_op.name() == "pd_op.full"
), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
full_op.op_role = int(OpRole.Optimize)
paddle.pir.set_insertion_point_to_block_end(block)


def _append_scale_op_after_comm(block, optimizer_ops, k_steps):
for optimizer_op in optimizer_ops:
target_value = None
if "adam" in optimizer_op.name(): # adam and adamw are included
target_value = optimizer_op.operand_source(1)
elif "sgd" in optimizer_op.name():
target_value = optimizer_op.operand_source(2)
else:
raise NotImplementedError(
f"We yet support adamw, adam and sgd, but got {optimizer_op.name()}"
)
assert (
target_value is not None
), "target_value is not expected to be None"
insertion_point = target_value.get_defining_op()
if insertion_point is None:
# target_value is a gradient_merge_var, which has no defining op,
# so we find the op preceding optimizer_op and insert the scale op after it.
insertion_point = _get_prev_op(block, optimizer_op)
paddle.pir.set_insertion_point_after(insertion_point)
_insert_scale_op_after(target_value, optimizer_op, 1.0 / k_steps)
paddle.pir.set_insertion_point_to_block_end(block)


def _pir_append_scale_op(program, new_params_to_grads, k_steps):
block = program.global_block()
optimizer_ops = _find_trival_optimizer_ops(block)
if len(optimizer_ops) > 0:
_append_scale_op_after_comm(block, optimizer_ops, k_steps)
else:
_append_scale_op_before_comm(block, new_params_to_grads, k_steps)


def _pir_parse_program(
main_program,
startup_program,
@@ -657,22 +745,7 @@ def _pir_parse_program(

# step3: append scale op
if avg:
main_block = main_program.global_block()
for op in reversed(main_block.ops):
if op.op_role == int(OpRole.Backward):
paddle.pir.set_insertion_point_after(op)
break
for _, new_grad in new_params_to_grads:
new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)

scale_op = new_grad.get_defining_op()
scale_op.op_role = int(OpRole.Optimize)

full_op = scale_op.operand_source(1).get_defining_op()
assert (
full_op.name() == "pd_op.full"
), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
full_op.op_role = int(OpRole.Optimize)
_pir_append_scale_op(main_program, new_params_to_grads, k_steps)


@register_pass("auto_parallel_gradient_merge_pass")
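A plain-Python sanity check (illustrative, not Paddle code) of the arithmetic the refactored pass relies on: with k_steps accumulation steps, scaling the merged (summed) gradient by 1.0 / k_steps is exactly the mean of the per-step gradients, and assuming the gradient communication is a linear sum-style reduction, the scale gives the same result whether it is appended before the communication (on the merged gradient) or after it (on the optimizer's gradient operand). The per-rank gradient values below are made up for the check.

k_steps = 4
rank0_grads = [0.5, -1.25, 2.0, 0.75]  # assumed per-micro-batch gradients on rank 0
rank1_grads = [1.0, 0.25, -0.5, 1.25]  # assumed per-micro-batch gradients on rank 1

# Scale each rank's merged gradient before the reduction ...
scaled_then_reduced = sum(rank0_grads) / k_steps + sum(rank1_grads) / k_steps
# ... or reduce first and scale right before the optimizer update.
reduced_then_scaled = (sum(rank0_grads) + sum(rank1_grads)) / k_steps

assert abs(scaled_then_reduced - reduced_then_scaled) < 1e-12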
