[Auto Parallel] Move reduce to opt stage (#62157)
* move reduce to opt stage

* set op_role for reduce op

* update

* fix

* add debug info

* add debug info

* skip reduce ops whose input name contains @RENAME

* remove debug info

* update

* move scale op to opt stage

* add dp_gradient_sync_after_accumulate as a strategy

* fix

* add notes
AndSonder authored Mar 5, 2024
1 parent f38e19b commit 23e0355
Showing 3 changed files with 81 additions and 4 deletions.
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -105,6 +105,9 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
set_field_default_config(
GRADIENT_MERGE, "dp_gradient_sync_after_accumulate", False
)

#########################################
# pipeline configuration
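The new gradient_merge field registered above is read through the auto-parallel strategy object like the existing enable/k_steps/avg fields. A minimal sketch of how a user would turn it on, assuming the paddle.distributed.fleet.auto Strategy API (the surrounding Engine/training setup is omitted and is not part of this diff):

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.gradient_merge.enable = True
strategy.gradient_merge.k_steps = 4
strategy.gradient_merge.avg = True
# New in this commit: keep gradient accumulation purely local and defer the
# data-parallel reduce/scale to the optimizer stage, once every k_steps.
strategy.gradient_merge.dp_gradient_sync_after_accumulate = True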
@@ -416,6 +416,12 @@ def _apply_post_optimization(
)
dp_pass.apply([main_program], [startup_program], self._pass_context)

dp_gradient_sync_after_accumulate = (
self._strategy.gradient_merge.dp_gradient_sync_after_accumulate
)
if dp_gradient_sync_after_accumulate:
global_params_grads = params_grads

if self._strategy.sharding.enable:
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
@@ -485,7 +491,10 @@ def _apply_post_optimization(
if self.is_train and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
if dp_gradient_sync_after_accumulate:
config["params_grads"] = global_params_grads
else:
config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass", config
)
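Taken together, the two hunks above implement a simple pattern: keep a reference to the original params_grads before later passes (e.g. sharding) swap in their own list, and hand that original list to the gradient-merge pass when the data-parallel sync is deferred. A condensed sketch of the resulting control flow, not the full method body (the final apply call is assumed to follow the same pattern as the dp_pass.apply call shown earlier):

# Condensed from _apply_post_optimization; comments mark elided passes.
dp_gradient_sync_after_accumulate = (
    self._strategy.gradient_merge.dp_gradient_sync_after_accumulate
)
if dp_gradient_sync_after_accumulate:
    # Keep the pre-sharding (param, grad) pairs for the gradient-merge pass.
    global_params_grads = params_grads

# ... sharding and the other post-optimization passes run here ...

if self.is_train and self._strategy.gradient_merge.enable:
    config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
    config["dist_context"] = self._dist_context
    config["params_grads"] = (
        global_params_grads
        if dp_gradient_sync_after_accumulate
        else params_grads
    )
    auto_parallel_gradient_merge_pass = new_pass(
        "auto_parallel_gradient_merge_pass", config
    )
    auto_parallel_gradient_merge_pass.apply(
        [main_program], [startup_program], self._pass_context
    )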
71 changes: 68 additions & 3 deletions python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -16,6 +16,10 @@

import paddle
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.operators.common import (
is_data_parallel_reduce_op,
is_data_parallel_scale_op,
)
from paddle.distributed.auto_parallel.static.process_group import (
get_world_process_group,
)
@@ -260,6 +264,51 @@ def _append_gradient_merge_backward_op(
return new_params_grads, grad_to_gradient_merge


def _move_reduce_to_optimizer_ops_block(
main_program, optimize_ops_block, params_grads
):
main_block = main_program.global_block()
removed_op_idx = []
params_grads_name = [grad.name for _, grad in params_grads]

for idx, op in list(enumerate(main_block.ops)):
if is_data_parallel_reduce_op(op):
op_input_names = op.desc.input_arg_names()
# NOTE(sonder): When "@RENAME@" is in the input name, it means that the op has been renamed.
# Such input names are caused by the shared-parameter policy.
# Gradient merge should accumulate gradients from the ops that have not been renamed.
if "@RENAME" in op_input_names[0]:
continue

reduce_op_desc = optimize_ops_block.desc._insert_op(
len(removed_op_idx)
)
reduce_op_desc.copy_from(op.desc)
reduce_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
removed_op_idx.append(idx)

if op.type in ["c_allreduce_sum", "c_reduce_sum"]:
scale_index = idx + 1
while scale_index < len(main_block.ops):
if is_data_parallel_scale_op(main_block.ops[scale_index]):
scale_op_desc = optimize_ops_block.desc._insert_op(
len(removed_op_idx)
)
scale_op_desc.copy_from(
main_block.ops[scale_index].desc
)
scale_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
removed_op_idx.append(scale_index)
break
scale_index += 1

for idx in removed_op_idx[::-1]:
main_block._remove_op(idx, sync=False)

main_block._sync_with_cpp()
return optimize_ops_block


def _create_cond_block_and_update_optimizer(
main_program,
cond_var,
@@ -390,7 +439,13 @@ def true_apply_gradient():


def parse_program(
main_program, startup_program, params_grads, k_steps, avg, dist_context
main_program,
startup_program,
params_grads,
k_steps,
avg,
dist_context,
dp_gradient_sync_after_accumulate,
):
# 1 remove optimizer_op from main_program
optimize_ops_block = _remove_and_get_optimizer_op(
@@ -405,10 +460,16 @@ def parse_program(
main_program, startup_program, params_grads, dist_context
)

# 3 create gradient_merge_cond
if dp_gradient_sync_after_accumulate:
# 3 move reduce op to optimizer_ops_block
optimize_ops_block = _move_reduce_to_optimizer_ops_block(
main_program, optimize_ops_block, params_grads
)

# 4 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)

# 4 create ConditionalBlock and append gradient merge optimizer ops
# 5 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(
main_program,
cond_var,
@@ -444,6 +505,9 @@ def _apply_single_impl(self, main_program, startup_program, context):
avg = self.get_attr("avg", False)
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
dp_gradient_sync_after_accumulate = self.get_attr(
"dp_gradient_sync_after_accumulate", False
)
with paddle.static.program_guard(main_program, startup_program):
parse_program(
main_program,
@@ -452,6 +516,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
k_steps,
avg,
dist_context,
dp_gradient_sync_after_accumulate,
)

main_program._sync_with_cpp()
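At a high level, the pass now leaves only local gradient accumulation in the main block and copies the data-parallel c_allreduce_sum / c_reduce_sum (plus its trailing scale op) into the optimizer ops block, which the gradient-merge conditional block executes once every k_steps. A toy illustration of the op ordering for one gradient, using plain strings rather than real Paddle programs (operand details omitted):

# Before: the gradient is synchronized across data-parallel ranks every micro-step.
main_ops_before = [
    "matmul_grad",      # produce the local gradient
    "c_allreduce_sum",  # DP sync on every micro-step
    "scale",            # average across ranks
    "elementwise_add",  # accumulate into the gradient-merge variable
]
optimizer_ops_before = ["adam"]

# After, with dp_gradient_sync_after_accumulate=True: only local accumulation
# per micro-step; reduce + scale run alongside the optimizer, once per k_steps.
main_ops_after = [
    "matmul_grad",
    "elementwise_add",
]
optimizer_ops_after = ["c_allreduce_sum", "scale", "adam"]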
