【AutoParallelism】Add refined recompute support (PaddlePaddle#58421)
* add refined-recompute support

* fix bug in recompute_pass

* fix coverage
heavyrain-lzy authored and zeroRains committed Nov 8, 2023
1 parent 57d1ed2 commit e1a2f3c
Showing 5 changed files with 65 additions and 6 deletions.
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/__init__.py
@@ -18,6 +18,7 @@
from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
from .interface import exclude_ops_in_recompute
from .interface import fetch
from .random import parallel_manual_seed

48 changes: 45 additions & 3 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -219,15 +219,57 @@ def __call__(self, *args, **kwargs):

for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
op._set_attr(
'op_namescope', "/auto_parallel/rc_" + str(_g_recompute_idx)
)
if op.has_attr(
"op_namescope"
) and 'auto_parallel/exclude_rc' in op.attr("op_namescope"):
op._set_attr(
'op_namescope',
"/auto_parallel/rc_"
+ str(_g_recompute_idx)
+ "_exclude_rc",
)
else:
op._set_attr(
'op_namescope',
'/auto_parallel/rc_' + str(_g_recompute_idx),
)

return output

return RecomputeOperator(op)


def exclude_ops_in_recompute(run_function):
    """
    Exclude some operators in recompute segments.
    Args:
        run_function (callable): The callable function to be excluded.
    Returns:
        ExcludeOperator: The callable object.
    """

    class ExcludeOperator:
        def __init__(self, run_function):
            self._run_function = run_function

        def __call__(self, *args, **kwargs):
            default_prog = paddle.static.default_main_program()
            cur_block = default_prog.current_block()
            op_size = len(cur_block.ops)
            output = self._run_function(*args, **kwargs)
            new_op_size = len(cur_block.ops)

            for idx in range(op_size, new_op_size):
                op = cur_block.ops[idx]
                op._set_attr('op_namescope', "/auto_parallel/exclude_rc")

            return output

    return ExcludeOperator(run_function)


_g_collections = {}


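Both wrappers only tag the ops created while the wrapped callable runs, by rewriting their op_namescope attribute; the recompute pass reads those tags later. Below is a minimal, hypothetical sketch (not part of this change) of how the two interfaces compose in a static program: the function names, shapes, and printing are made up, and only the tagging behavior shown in the diff above is assumed.

# Hypothetical sketch: compose recompute with exclude_ops_in_recompute and
# inspect the op_namescope tags they leave on newly created ops.
import paddle
from paddle.distributed.auto_parallel.interface import (
    exclude_ops_in_recompute,
    recompute,
)

paddle.enable_static()


def core_attn(x):
    # Stands in for the computation we want to keep out of recompute.
    return paddle.nn.functional.softmax(x)


def block(x):
    h = paddle.matmul(x, x, transpose_y=True)
    # Ops from this call are first tagged "/auto_parallel/exclude_rc"; the
    # enclosing recompute wrapper then rewrites the tag to
    # "/auto_parallel/rc_<i>_exclude_rc".
    h = exclude_ops_in_recompute(core_attn)(h)
    return paddle.matmul(h, x)


main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    y = recompute(block)(x)
    for op in main.global_block().ops:
        ns = op.attr("op_namescope") if op.has_attr("op_namescope") else None
        print(op.type, ns)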
10 changes: 9 additions & 1 deletion python/paddle/distributed/auto_parallel/static/utils.py
@@ -1823,7 +1823,15 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank):


def is_recompute_op(op):
return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr(
return (
op.has_attr('op_namescope')
and "/auto_parallel/rc" in op.attr('op_namescope')
and 'exclude_rc' not in op.attr('op_namescope')
)


def is_recompute_exclude_op(op):
return op.has_attr('op_namescope') and 'exclude_rc' in op.attr(
'op_namescope'
)

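The two predicates depend only on the op_namescope string, so their behavior can be illustrated with a stub in place of a real static-graph Operator. The snippet below copies the predicate bodies from the diff and checks them against hand-written namescope values; the FakeOp class is purely illustrative and not part of the change.

# Illustrative check of the namescope predicates using a stub operator.
class FakeOp:
    def __init__(self, namescope=None):
        self._ns = namescope

    def has_attr(self, name):
        return name == 'op_namescope' and self._ns is not None

    def attr(self, name):
        return self._ns


def is_recompute_op(op):
    return (
        op.has_attr('op_namescope')
        and "/auto_parallel/rc" in op.attr('op_namescope')
        and 'exclude_rc' not in op.attr('op_namescope')
    )


def is_recompute_exclude_op(op):
    return op.has_attr('op_namescope') and 'exclude_rc' in op.attr(
        'op_namescope'
    )


# An op tagged by recompute only is a recompute op.
assert is_recompute_op(FakeOp("/auto_parallel/rc_0"))
# An op excluded inside a recompute segment is no longer a plain recompute op ...
assert not is_recompute_op(FakeOp("/auto_parallel/rc_0_exclude_rc"))
# ... but it is recognized as a recompute-excluded op.
assert is_recompute_exclude_op(FakeOp("/auto_parallel/rc_0_exclude_rc"))
assert not is_recompute_exclude_op(FakeOp("/auto_parallel/rc_0"))
# Untagged ops match neither predicate.
assert not is_recompute_op(FakeOp()) and not is_recompute_exclude_op(FakeOp())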
8 changes: 7 additions & 1 deletion python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -31,6 +31,7 @@
get_loss_op,
insert_dependencies_for_two_ops,
is_backward_op,
is_recompute_exclude_op,
is_recompute_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_dist_op_desc_original_id,
@@ -80,9 +81,13 @@ def build_states(self):

if not is_recompute_op(op):
self._checkpoints.extend(op.output_arg_names)
continue
if not is_recompute_exclude_op(op):
continue

seg_name = op.attr('op_namescope')
seg_name = (
seg_name if '_exclude_rc' not in seg_name else seg_name[:-11]
)
if seg_name not in self.seg_op_deps:
self.seg_op_deps[seg_name] = [i]
else:
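The slice seg_name[:-11] drops the trailing "_exclude_rc" suffix (11 characters), so an excluded op still lands in the same rc_<i> segment as its neighbors while its outputs have already been added to the checkpoints above. A quick sanity check of that length assumption:

# "_exclude_rc" is exactly 11 characters, so seg_name[:-11] maps an excluded
# op's scope back to the enclosing recompute segment name.
assert len("_exclude_rc") == 11
assert "/auto_parallel/rc_3_exclude_rc"[:-11] == "/auto_parallel/rc_3"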
@@ -317,6 +322,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
)

# 3. get vars that should be hold in memory
# list of var_names
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
4 changes: 3 additions & 1 deletion test/legacy_test/auto_parallel_gpt_model.py
@@ -266,7 +266,9 @@ def forward(
if self.use_new_recompute and self.recompute_granularity == "core_attn":
out, weights = auto.recompute(self.core_attn)(q, k, v, attn_mask)
else:
out, weights = self.core_attn(q, k, v, attn_mask)
out, weights = auto.exclude_ops_in_recompute(self.core_attn)(
q, k, v, attn_mask
)

# project to output
out = self.out_proj(out)
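In the test model the branch on recompute_granularity now covers both paths: with "core_attn" only the core attention is recomputed, and otherwise the core attention is wrapped in exclude_ops_in_recompute so that, if an enclosing scope is recomputed, its activations are kept rather than replayed. A small, hypothetical helper mirroring that dispatch (the helper name and signature are not from the PR):

# Hypothetical helper mirroring the dispatch pattern in the test model above.
# "auto" is the auto-parallel interface namespace; core_attn is any callable.
def run_core_attn(auto, core_attn, use_new_recompute, recompute_granularity,
                  q, k, v, attn_mask):
    if use_new_recompute and recompute_granularity == "core_attn":
        # Refined recompute: recompute only the core attention.
        return auto.recompute(core_attn)(q, k, v, attn_mask)
    # Otherwise exclude the core attention from any enclosing recompute
    # segment, holding its outputs in memory instead.
    return auto.exclude_ops_in_recompute(core_attn)(q, k, v, attn_mask)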
