[Auto Parallel] fix enable_delay_scale_loss for static auto parallel … #68525

Merged
22 changes: 20 additions & 2 deletions python/paddle/distributed/auto_parallel/api.py
@@ -31,6 +31,7 @@
EagerParamBase,
Variable,
default_main_program,
in_dygraph_mode,
in_pir_mode,
use_pir_api,
)
@@ -1000,7 +1001,7 @@ def get_placement_with_sharding(param, sharding_mesh_axis):


class _ShardOptimizer(Optimizer):
def __init__(self, optimizer, shard_fn=None):
def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
assert (
optimizer is not None
), "The argument `optimizer` cannot be empty."
@@ -1025,6 +1026,7 @@ def __init__(self, optimizer, shard_fn=None):
self._shard_fn = shard_fn
self._sharding_mesh_axis = None
self._sharding_degree = None
self.gradient_accumulation_steps = gradient_accumulation_steps

if isinstance(
self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
@@ -1246,6 +1248,21 @@ def state_dict(self):
return self._inner_opt.state_dict()

def _append_optimize_op(self, block, param_and_grad):
if (
in_auto_parallel_align_mode() # In align mode, we use enable_delay_scale_loss by default
and in_dygraph_mode()
and param_and_grad[1].is_dist()
):
placements = param_and_grad[1].placements
meshs = param_and_grad[1].process_mesh
grad = param_and_grad[1]

for i in range(len(placements) - 1, -1, -1):
if isinstance(placements[i], dist.Partial):
placements[i] = dist.Replicate()
grad = dist.reshard(grad, meshs, placements)
grad /= self.gradient_accumulation_steps
param_and_grad = (param_and_grad[0], grad)
return self._inner_opt._append_optimize_op(block, param_and_grad)

def __getattr__(self, item):
@@ -1596,6 +1613,7 @@ def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor:
def shard_optimizer(
optimizer: Optimizer,
shard_fn: Callable[[str, Tensor, Tensor], Tensor] | None = None,
gradient_accumulation_steps: int = 1,
) -> _ShardOptimizer:
"""

@@ -1640,7 +1658,7 @@ def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py

"""
return _ShardOptimizer(optimizer, shard_fn)
return _ShardOptimizer(optimizer, shard_fn, gradient_accumulation_steps)


def shard_scaler(scaler: GradScaler) -> GradScaler:
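A minimal usage sketch for the new gradient_accumulation_steps argument (the layer, optimizer, and step count below are illustrative assumptions, not code from this diff; the mesh setup, shard_tensor calls, and training loop are elided). In align mode under dygraph, the _append_optimize_op branch above reshards Partial gradients to Replicate and divides them by this step count right before the inner optimizer update.

import paddle
import paddle.distributed as dist

layer = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())

# Hand the optimizer to auto parallel together with the accumulation step
# count, so the delayed 1/k scaling divides the merged gradient exactly once.
opt = dist.shard_optimizer(opt, gradient_accumulation_steps=4)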
105 changes: 89 additions & 16 deletions python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -636,6 +636,94 @@ def parse_program(
return grad_to_gradient_merge


def _find_trival_optimizer_ops(block):
Contributor (review comment): Identifying optimizer ops here by name string alone will be easy to miss cases in the future; consider maintaining a single fixed opt_op_name_list for this later.

optimizer_ops = []
for op in block.ops:
if "adam" in op.name() or "sgd" in op.name():
optimizer_ops.append(op)
return optimizer_ops
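
Following the review comment on _find_trival_optimizer_ops above, a hypothetical sketch of the fixed-list approach (the _OPT_OP_NAMES tuple, its entries, and the helper name are assumptions for illustration, not part of this PR):

# Keep a single authoritative list of optimizer op names so new optimizers
# only need to be registered in one place (op name strings are illustrative).
_OPT_OP_NAMES = ("pd_op.adam_", "pd_op.adamw_", "pd_op.sgd_")


def _find_optimizer_ops_from_list(block):
    # Exact-name matching avoids accidental substring hits as more ops are added.
    return [op for op in block.ops if op.name() in _OPT_OP_NAMES]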


def _get_prev_op(block, optimizer_op):
found = False
for op in reversed(block.ops):
if found:
return op
if op.id == optimizer_op.id:
found = True
return None


def _insert_scale_op_after(target_value, optimizer_op, scale, bias=0.0):
scaled_grad = paddle._C_ops.scale_(target_value, scale, bias, False)

scale_op = scaled_grad.get_defining_op()
scale_op.op_role = int(OpRole.Optimize)

full_op = scale_op.operand_source(1).get_defining_op()
assert (
full_op.name() == "pd_op.full"
), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
full_op.op_role = int(OpRole.Optimize)

if "adam" in optimizer_op.name():
optimizer_op.operand(1).set_source(scaled_grad)
elif "sgd" in optimizer_op.name():
optimizer_op.operand(2).set_source(scaled_grad)


def _append_scale_op_before_comm(block, new_params_to_grads, k_steps):
for op in reversed(block.ops):
if op.op_role == int(OpRole.Backward):
paddle.pir.set_insertion_point_after(op)
break
for _, new_grad in new_params_to_grads:
new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)

scale_op = new_grad.get_defining_op()
scale_op.op_role = int(OpRole.Optimize)

full_op = scale_op.operand_source(1).get_defining_op()
assert (
full_op.name() == "pd_op.full"
), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
full_op.op_role = int(OpRole.Optimize)
paddle.pir.set_insertion_point_to_block_end(block)


def _append_scale_op_after_comm(block, optimizer_ops, k_steps):
for optimizer_op in optimizer_ops:
target_value = None
if "adam" in optimizer_op.name(): # adam and adamw are included
target_value = optimizer_op.operand_source(1)
elif "sgd" in optimizer_op.name():
target_value = optimizer_op.operand_source(2)
else:
raise NotImplementedError(
f"We yet support adamw, adam and sgd, but got {optimizer_op.name()}"
)
assert (
target_value is not None
), "target_value is not expected to be None"
insertion_point = target_value.get_defining_op()
if insertion_point is None:
# target_value is a gradient_merge_var, which has no defining op,
# so we find the op before optimizer_op and insert the scale op after it.
insertion_point = _get_prev_op(block, optimizer_op)
paddle.pir.set_insertion_point_after(insertion_point)
_insert_scale_op_after(target_value, optimizer_op, 1.0 / k_steps)
paddle.pir.set_insertion_point_to_block_end(block)


def _pir_append_scale_op(program, new_params_to_grads, k_steps):
block = program.global_block()
optimizer_ops = _find_trival_optimizer_ops(block)
if len(optimizer_ops) > 0:
_append_scale_op_after_comm(block, optimizer_ops, k_steps)
else:
_append_scale_op_before_comm(block, new_params_to_grads, k_steps)


def _pir_parse_program(
main_program,
startup_program,
@@ -657,22 +745,7 @@ def _pir_parse_program(

# step3: append scale op
if avg:
main_block = main_program.global_block()
for op in reversed(main_block.ops):
if op.op_role == int(OpRole.Backward):
paddle.pir.set_insertion_point_after(op)
break
for _, new_grad in new_params_to_grads:
new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)

scale_op = new_grad.get_defining_op()
scale_op.op_role = int(OpRole.Optimize)

full_op = scale_op.operand_source(1).get_defining_op()
assert (
full_op.name() == "pd_op.full"
), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
full_op.op_role = int(OpRole.Optimize)
_pir_append_scale_op(main_program, new_params_to_grads, k_steps)


@register_pass("auto_parallel_gradient_merge_pass")
Expand Down