【Comm】switch c_allgather in fluid to all_gather in python call (Paddl…
liym27 authored and inaomIIsfarell committed Jul 31, 2024
1 parent e2d8052 commit a934541
Showing 19 changed files with 125 additions and 56 deletions.
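The rename touches both the op type and how its inputs and outputs are keyed when the static-graph op is emitted. A minimal before/after sketch, distilled from the hunks below and not itself part of the diff (block, x_var, out_var, and group stand in for whatever the caller already has; attrs are abbreviated):

# before this commit: fluid-style op with capitalized I/O keys
block.append_op(
    type='c_allgather',
    inputs={'X': [x_var]},
    outputs={'Out': [out_var]},
    attrs={'ring_id': group.id, 'nranks': group.nranks, 'use_calc_stream': True},
)

# after this commit: the op is named all_gather and uses lowercase keys
block.append_op(
    type='all_gather',
    inputs={'x': [x_var]},
    outputs={'out': [out_var]},
    attrs={'ring_id': group.id, 'nranks': group.nranks, 'use_calc_stream': True},
)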
@@ -28,7 +28,7 @@
"send_v2",
"recv_v2",
"c_broadcast",
"c_allgather",
"all_gather",
"c_allreduce_sum",
"c_identity",
]
@@ -14,6 +14,10 @@

import math

import numpy as np

import paddle

from .base_cost import CommOpCost, register_op_cost


@@ -81,7 +85,7 @@ def calc_time_tree(self):

@register_op_cost
class AllgatherOpCost(CommOpCost):
OP_TYPE = "c_allgather"
OP_TYPE = "all_gather"

def __init__(self, op=None, op_desc=None, comm_context=None):
super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)
@@ -105,6 +109,38 @@ def calc_time_ring(self):
)
return time

@property
def comm_count(self):
from ..reshard import get_var_with_recursion

if self._comm_count is None:
dtype = None
shape = None
if self.op is not None:
vars = self.op.block.vars
try:
var_name = self.op.input("x")[0]
except:
var_name = self.op.output("out")[0]
var = get_var_with_recursion(
var_name, self.op.block, self.op.block.program
)
dtype = var.dtype
shape = var.shape
elif self.op_desc is not None:
dtype = self.op_desc["inputs"]["X"][0][0]
shape = self.op_desc["inputs"]["X"][0][1]

factor = None
if dtype == paddle.float32 or dtype == paddle.int32:
factor = 4
else:
raise ValueError(f"Unsupported comm dtype {dtype}")
comm_count = int(np.prod(shape)) * factor
self._comm_count = comm_count

return self._comm_count


@register_op_cost
class BroadcastOpCost(CommOpCost):
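The comm_count property added to AllgatherOpCost above sizes the communicated payload in bytes: the element count of the gathered variable times the element size (4 bytes for float32/int32 in this snippet; other dtypes raise). A small worked example, illustrative only and not part of the diff:

import numpy as np

shape, factor = (1024, 512), 4             # float32 -> 4 bytes per element
comm_count = int(np.prod(shape)) * factor  # 1024 * 512 * 4
assert comm_count == 2 * 1024 * 1024       # 2 MiB payload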
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/static/mapper.py
@@ -35,7 +35,7 @@ def is_collective_comm_op(op):
"c_reduce_max",
"c_reduce_prod",
"c_broadcast",
"c_allgather",
"all_gather",
]
if op.type in comm_list:
return True
@@ -100,7 +100,7 @@ def get_comm_volume(comm_op, src_rank, tgt_rank):
tensor_bytes = tensor_size * get_dtype_bytes(tensor.dtype)
if "c_allreduce" in comm_op_type:
comm_volume = 2 * tensor_bytes
elif "c_allgather" in comm_op_type:
elif "all_gather" in comm_op_type:
comm_volume = tensor_bytes
elif "c_broadcast" in comm_op_type:
if comm_op.attr("root") == src_rank:
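For the renamed op, get_comm_volume above keeps the same accounting: an all_gather is charged the tensor's byte size, while a c_allreduce is charged twice that. Illustrative arithmetic only, not part of the diff:

tensor_bytes = 1024 * 1024 * 4          # float32 tensor with 1M elements
allreduce_volume = 2 * tensor_bytes     # "c_allreduce" branch: 8 MiB
allgather_volume = tensor_bytes         # "all_gather" branch: 4 MiB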
@@ -15,6 +15,7 @@
import copy

from paddle.common_ops_import import check_dtype, check_variable_and_dtype
from paddle.distributed.passes.pass_utils import AutoParallelStreamType
from paddle.framework import core
from paddle.static import Operator

@@ -200,10 +201,10 @@ def forward(ctx, *args, **kwargs):
X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'norm'
)

# 2. insert c_allgather op
# create c_allgather output var
# 2. insert all_gather op
# create all_gather output var
allgather_out = main_block.create_var(
name=".".join(["c_allgather", X_var.name]),
name=".".join(["all_gather", X_var.name]),
dtype=X_var.dtype,
shape=X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
@@ -220,18 +221,18 @@
ctx.set_tensor_dist_attr_for_program(
allgather_out, allgather_out_dist_attr
)
c_allgather_op = main_block.append_op(
type='c_allgather',
inputs={'X': [X_var]},
outputs={'Out': [allgather_out]},
all_gather_op = main_block.append_op(
type='all_gather',
inputs={'x': [X_var]},
outputs={'out': [allgather_out]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'nranks': group.nranks,
'op_role': src_op.attr('op_role'),
},
)
# set c_allgather op dist_attr
# set all_gather op dist_attr
allgather_op_dist_attr = OperatorDistAttr()
allgather_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allgather_op_dist_attr.chunk_id = op_dist_attr.chunk_id
@@ -241,7 +242,10 @@
allgather_op_dist_attr.set_output_dims_mapping(
allgather_out.name, allgather_out_dist_attr.dims_mapping
)
ctx.set_op_dist_attr_for_program(c_allgather_op, allgather_op_dist_attr)
allgather_op_dist_attr.execution_stream = (
AutoParallelStreamType.CALC_STREAM.value
)
ctx.set_op_dist_attr_for_program(all_gather_op, allgather_op_dist_attr)

# 3. copy p_norm op desc and reset input name
# rename input
@@ -290,10 +294,10 @@ def backward(ctx, *args, **kwargs):

# 1. copy p_norm_grad op and reset input name and output name
new_kwargs = copy.deepcopy(kwargs)
new_kwargs['X'] = [".".join(["c_allgather", X_var.name])]
new_kwargs['X'] = [".".join(["all_gather", X_var.name])]
new_X_var = main_block._var_recursive(new_kwargs['X'][0])
new_X_grad = main_block.create_var(
name=".".join(["c_allgather", X_grad_var.name]),
name=".".join(["all_gather", X_grad_var.name]),
dtype=X_grad_var.dtype,
shape=new_X_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/static/pir_pass.py
@@ -264,7 +264,7 @@ def remove_other_rank_input_output_pass(dist_program):


# Note: this is the pass in the dense program
comm_ops = ["pd_op.c_allreduce_sum", "pd_op.c_allgather"]
comm_ops = ["pd_op.c_allreduce_sum", "pd_op.all_gather"]


def remove_unuseful_comm_op_pass(program):
15 changes: 9 additions & 6 deletions python/paddle/distributed/auto_parallel/static/reshard.py
@@ -19,6 +19,7 @@

import paddle
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.passes.pass_utils import AutoParallelStreamType
from paddle.framework import LayerHelper, OpProtoHolder, Program, core
from paddle.utils import unique_name

@@ -667,8 +668,8 @@ def insert_allgather_op(
group = new_process_group(ranks)
idx_offset = 0

# insert c_allgather op
op_type = 'c_allgather'
# insert all_gather op
op_type = 'all_gather'
# to avoid name conflict with framework
helper = LayerHelper(op_type + "@RESHARD", **locals())
insert_operation = (
@@ -690,16 +691,18 @@
allgather_op = insert_operation(
idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [allgather_out]},
inputs={'x': [tensor]},
outputs={'out': [allgather_out]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'nranks': group.nranks,
'op_role': op_role,
},
)
allgather_op._set_attr('op_namescope', "/auto_parallel/reshard")
allgather_op.dist_attr.execution_stream = (
AutoParallelStreamType.CALC_STREAM.value
)
idx_offset += 1

# insert split op
@@ -3202,7 +3205,7 @@ def _get_idx(comm_ranks, group_ranks):
group_ranks = op_desc.group
shape = op_desc.shape
allgather_desc = build_comm_desc(
"c_allgather", group_ranks, dtype, shape
"all_gather", group_ranks, dtype, shape
)
split_inputs_shape = []
for idx, dim in enumerate(shape):
@@ -14,6 +14,7 @@


import paddle
from paddle.distributed.passes.pass_utils import AutoParallelStreamType

from ..process_group import new_process_group
from .base_reshard_func import ReshardFunction, is_replicated, is_shard
@@ -129,8 +130,8 @@ def reshard_s_to_r_with_padding(
num_of_process = len(src_mesh.process_ids)

group = new_process_group(sorted(src_mesh.process_ids))
allgather_value = paddle._C_ops.c_allgather(
src_value, group.id, num_of_process, True
allgather_value = paddle._C_ops.all_gather(
src_value, group.id, num_of_process
)
allgather_type = self.infer_allgather_dist_type(src_value, split_axis)
allgather_value.set_type(allgather_type)
@@ -146,6 +147,9 @@
src_mesh, [src_dist_attr], [new_dist_attr]
)
)
allgather_value.get_defining_op().set_execution_stream(
AutoParallelStreamType.CALC_STREAM.value
)

if split_axis != 0 or padding_num != 0:
allgather_op = allgather_value.get_defining_op()
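In the reshard function above, the old c_allgather call took use_calc_stream as a trailing positional flag; the new all_gather call drops it, and the stream choice is instead attached to the op that defines the result. Restated in isolation and not part of the diff (src_value, group, and num_of_process come from the surrounding function):

import paddle
from paddle.distributed.passes.pass_utils import AutoParallelStreamType

# before: stream choice passed as the trailing flag
out = paddle._C_ops.c_allgather(src_value, group.id, num_of_process, True)

# after: plain call, then pin the defining op to the calc stream
out = paddle._C_ops.all_gather(src_value, group.id, num_of_process)
out.get_defining_op().set_execution_stream(
    AutoParallelStreamType.CALC_STREAM.value
)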
23 changes: 18 additions & 5 deletions python/paddle/distributed/auto_parallel/static/utils.py
@@ -2428,19 +2428,32 @@ def update_grad_var_to_var(program, strategy, grad_var_to_var):
"cast",
"c_concat",
"concat",
"c_allgather",
"slice",
"all_gather",
]
if op.desc.type() in reshard_op_types:
input_names = op.desc.input_names()
if "X" in input_names or "Input" in input_names:
if (
"X" in input_names
or "Input" in input_names
or "x" in input_names
):
inputs = (
op.desc.input("X")
if "X" in input_names
else op.desc.input("Input")
else (
op.desc.input("Input")
if "Input" in input_names
else op.desc.input("x")
)
)
output_names = op.desc.output_names()
if "Out" in output_names or "out" in output_names:
outputs = (
op.desc.output("Out")
if "Out" in output_names
else op.desc.output("out")
)
if "Out" in op.desc.output_names():
outputs = op.desc.output("Out")
if inputs[0] in grad_var_to_var.keys():
for output in outputs:
grad_var_to_var[output] = grad_var_to_var[inputs[0]]
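The nested conditionals above exist because the renamed op declares lowercase keys ('x'/'out') while the legacy ops keep 'X'/'Input'/'Out'; the lookup simply takes the first key the op actually declares. The same logic as a compact helper sketch (the helper name is hypothetical, not part of the diff):

def first_declared_io(op_desc, candidates, kind="input"):
    # return the values of the first input/output key the op declares
    names = op_desc.input_names() if kind == "input" else op_desc.output_names()
    getter = op_desc.input if kind == "input" else op_desc.output
    for key in candidates:
        if key in names:
            return getter(key)
    return []

# inputs = first_declared_io(op.desc, ("X", "Input", "x"))
# outputs = first_declared_io(op.desc, ("Out", "out"), kind="output")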
@@ -61,8 +61,9 @@
"elementwise_add_grad",
"c_allreduce_sum",
"scale",
"c_allgather",
"all_gather",
"matmul_v2_grad",
"all_gather",
],
},
{ # DP + MP
@@ -82,8 +83,9 @@
"scale",
"c_allreduce_sum",
"scale",
"c_allgather",
"all_gather",
"matmul_v2_grad",
"all_gather",
],
},
# amp_level == 'o1'
@@ -97,7 +99,8 @@
"elementwise_add_grad",
"c_allreduce_sum",
"scale",
"c_allgather",
"all_gather",
"all_gather",
"matmul_v2_grad",
],
},
@@ -118,8 +121,9 @@
"scale",
"c_allreduce_sum",
"scale",
"c_allgather",
"all_gather",
"matmul_v2_grad",
"all_gather",
],
},
]
@@ -631,7 +635,7 @@ def _transform_backward(
to_delete_grad_of_param = []
if is_first_rank:
if is_sp:
# place the comm_op(c_allgather) before the elementwise_add_grad
# place the comm_op(all_gather) before the elementwise_add_grad
for segment in reversed(backward_segments):
add_grad_op = global_block.ops[segment[0]]
matmul_grad_op = global_block.ops[segment[-1] - 1]
8 changes: 6 additions & 2 deletions python/paddle/distributed/passes/auto_parallel_pipeline.py
@@ -369,7 +369,7 @@ def _task_stream(self):

if not is_after_send_op or not is_after_recv_op:
if self._cur_pp_stage == self._pp_stages - 1:
# NOTE: the c_sync_calc_stream about c_allgather cannot be removed
# NOTE: the c_sync_calc_stream about all_gather cannot be removed
if (
op.type == "c_sync_calc_stream"
and src_block.ops[i + 1].type == "send_v2"
@@ -380,7 +380,11 @@
# HACKCODE: the varname of send_v2 op, cast op should be recorded for brpc comm
if (
op.type
not in ["recv_2", "assign", "c_allgather"]
not in [
"recv_2",
"assign",
"all_gather",
]
and op.has_attr('op_namescope')
and "/auto_parallel/reshard"
in op.attr('op_namescope')
1 change: 1 addition & 0 deletions python/paddle/distributed/passes/fuse_all_reduce.py
@@ -152,6 +152,7 @@ def filter_all_collective_op_indices(block):
"c_allreduce_min",
"c_allgather",
"c_broadcast",
"all_gather",
}

match_op_indices = []
6 changes: 3 additions & 3 deletions test/auto_parallel/hybrid_strategy/pir_reshard_nd_mesh.py
@@ -203,7 +203,7 @@ def run_ss_to_ss_case(self):
all_gather_ops = []
slice_ops = []
for i, op in enumerate(new_ops):
if op.name() == "pd_op.c_allgather":
if op.name() == "pd_op.all_gather":
all_gather_ops.append(op)
elif op.name() == "pd_op.slice":
slice_ops.append(op)
@@ -279,10 +279,10 @@ def run_ps_to_ps_case(self):
ops = dist_program.global_block().ops
op_names = [op.name() for op in ops]
assert "pd_op.c_allreduce_sum" in op_names
assert "pd_op.c_allgather" in op_names
assert "pd_op.all_gather" in op_names
assert "pd_op.slice" in op_names

allgather_op = ops[op_names.index("pd_op.c_allgather")]
allgather_op = ops[op_names.index("pd_op.all_gather")]
allreduce_sum_op = ops[op_names.index("pd_op.c_allreduce_sum")]
slice_op = ops[op_names.index("pd_op.slice")]
