[AutoParallel] extract split matmul_grad_op to pass_utils #62737

Merged · 6 commits · Mar 18, 2024
Changes from all commits
197 changes: 28 additions & 169 deletions python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py
@@ -17,10 +17,9 @@

from ..auto_parallel.static.utils import (
get_logger,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from .pass_base import PassBase, register_pass
from .pass_utils import AutoParallelStreamType
from .pass_utils import split_matmul_grad_to_matmul

logger = get_logger(logging.INFO)

@@ -84,44 +83,6 @@ def _get_all_matmul_grad_and_allreduce_pairs(self, block):
matmul_grad_id_to_allreduce_id[i] = j
return matmul_grad_id_to_allreduce_id

def _insert_reshape_op(self, block, index, x, shape, op_role, out=None):
var_x = block.var(x[0])
x_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(var_x)

if out is None:
out = block.create_var(
name=f"{x[0]}@reshape.out",
dtype=var_x.dtype,
persistable=False,
)
self.dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr)

x_shape = block.create_var(
name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype
)
self.dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr)

reshape_op = block._insert_op_without_sync(
index=index,
type="reshape2",
inputs={"X": x},
outputs={"Out": out, "XShape": x_shape},
attrs={
"shape": shape,
"op_role": op_role,
'op_namescope': self.op_namescope,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
reshape_op,
process_mesh=x_dist_attr.process_mesh,
ref_mapping=x_dist_attr.dims_mapping,
ctx=self.dist_context,
chunk_id=x_dist_attr.chunk_id,
)

return out

def _split_matmul_grad_and_multi_streaming_allreduce(
self, block, matmul_grad_id_to_allreduce_id
):
@@ -133,20 +94,15 @@ def _split_matmul_grad_and_multi_streaming_allreduce(
matmul_grad_op = ops[matmul_grad_id]
allreduce_op = ops[allreduce_id]

# NOTE(Sonder): Why move those operations to the back of matmul_v2?
# When using amp_master_grad, the cast operation is inserted after matmul_grad.
# However, when employing allreduce_matmul_grad_overlapping, the matmul_grad is
# split into two matmul operations. In this case, some operations would access
# uninitialized tensors. Therefore, we move the cast operation to the back of the
# second matmul operation to avoid this problem.
# NOTE(Sonder): When there are ops between matmul_grad and allreduce, we should check whether
# these ops rely on the output of the intermediate ops. If so, we should not split the matmul_grad.
# Otherwise, the output of the intermediate ops will get wrong results.
skip_overlapping = False
moved_ops_idx = []
moved_ops_output = []
matmul_grad_output = matmul_grad_op.output('Y@GRAD')[0]

for idx in range(matmul_grad_id + 1, allreduce_id):
if matmul_grad_output in ops[idx].desc.input_arg_names():
moved_ops_idx.append(idx)
moved_ops_output.extend(ops[idx].desc.output_arg_names())
else:
for input_name in ops[idx].desc.input_arg_names():
@@ -156,137 +112,40 @@
if skip_overlapping:
continue

for i, idx in enumerate(moved_ops_idx):
op = ops[idx]
dist_attr = self.dist_context.get_op_dist_attr_for_program(op)

op_inputs = op.desc.input_names()
op_outputs = op.desc.output_names()

op_inputs = {name: op.input(name) for name in op_inputs}
op_outputs = {name: op.output(name) for name in op_outputs}

op = block._insert_op_without_sync(
index=allreduce_id + 1 + i,
type=op.type,
inputs=op_inputs,
outputs=op_outputs,
attrs=op.all_attrs(),
)

self.dist_context.set_op_dist_attr_for_program(op, dist_attr)

for i, idx in enumerate(moved_ops_idx):
block._remove_op(idx - i, sync=False)
allreduce_id -= 1

tran_x = matmul_grad_op.attr("trans_x")
assert (
not tran_x
), f"matmul_grad(id={matmul_grad_id}) with tran_x == True is not supported for column parallel linear backward overlapping"
tran_y = matmul_grad_op.attr("trans_y")
assert (
not tran_y
), f"matmul_grad(id={matmul_grad_id}) with tran_y == True is not supported for column parallel linear backward overlapping"

allreduce_op.dist_attr.execution_stream = (
AutoParallelStreamType.MP_STREAM.value
# matmul_grad_op => matmul_v2 + reshape + reshape + matmul_v2 + reshape
split_matmul_grad_to_matmul(
block, matmul_grad_id, self.dist_context, self.op_namescope
)

x = matmul_grad_op.input("X")
y = matmul_grad_op.input("Y")
out_grad = matmul_grad_op.input("Out@GRAD")
x_grad = matmul_grad_op.output("X@GRAD")
y_grad = matmul_grad_op.output("Y@GRAD")
op_role = matmul_grad_op.attr("op_role")

# NOTE(Ruibiao): Required OP scheduling order: matmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut).
# c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. Otherwise, after buffer_shared_inplace_pass
# adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then
# the matmul op before it cannot be overlapped.
var_x = block.var(x[0])
var_out_grad = block.var(out_grad[0])
var_y_grad = block.var(y_grad[0])

x_dims = var_x.shape
out_grad_dims = var_out_grad.shape
y_grad_dims = var_y_grad.shape

assert len(x_dims) == len(
out_grad_dims
), f"The rank of x must be equal to that of out_grad, but got x rank = {len(x_dims)} and out_grad rank = {len(out_grad_dims)}."
if len(x_dims) > 2:
assert (
x_dims[0:2] == out_grad_dims[0:2]
), f"The first two dimensions of x must be equal to that of out_grad, but got x_dims:{x_dims} and out_grad_dims:{out_grad_dims}."
new_x_dims = [x_dims[0] * x_dims[1]] + list(x_dims[2:])
new_out_grad_dims = [
out_grad_dims[0] * out_grad_dims[1]
] + list(out_grad_dims[2:])

# NOTE(Ruibiao): Why insert reshape op here?
# When the rank of input matrix is 3, MatmulGradKernel use reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad.
# If we directly append matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding.
# Therefore, we imitate MatmulGradKernel here by inserting reshape op before matmul.
new_x = self._insert_reshape_op(
block, allreduce_id + 1, x, new_x_dims, op_role
)
new_out_grad = self._insert_reshape_op(
block, allreduce_id + 2, out_grad, new_out_grad_dims, op_role
)
new_y_grad = block.create_var(
name=f"{y_grad[0]}@reshape.out",
dtype=var_y_grad.dtype,
persistable=False,
)
self.dist_context.set_tensor_dist_attr_for_program(
new_y_grad,
self.dist_context.get_tensor_dist_attr_for_program(var_y_grad),
)

matmul_grad_dist_attr = (
self.dist_context.get_op_dist_attr_for_program(matmul_grad_op)
)
matmul_op = block._insert_op_without_sync(
index=allreduce_id + 3,
type="matmul_v2",
inputs={"X": new_x, "Y": new_out_grad},
outputs={"Out": new_y_grad},
attrs={
"trans_x": True,
"trans_y": False,
"op_role": op_role,
'op_namescope': self.op_namescope,
},
)
self.dist_context.set_op_dist_attr_for_program(
matmul_op, matmul_grad_dist_attr
)

self._insert_reshape_op(
block,
allreduce_id + 4,
[new_y_grad.name],
y_grad_dims,
op_role,
y_grad,
allreduce_op_dist_attr = (
self.dist_context.get_op_dist_attr_for_program(allreduce_op)
)

matmul_op = block._insert_op_without_sync(
index=matmul_grad_id + 1,
type="matmul_v2",
inputs={"X": out_grad, "Y": y},
outputs={"Out": x_grad},
attrs={
"trans_x": False,
"trans_y": True,
"op_role": op_role,
'op_namescope': self.op_namescope,
},
allreduce_op_inputs = allreduce_op.desc.input_names()
allreduce_op_outputs = allreduce_op.desc.output_names()

allreduce_op_inputs = {
name: allreduce_op.input(name) for name in allreduce_op_inputs
}
allreduce_op_outputs = {
name: allreduce_op.output(name) for name in allreduce_op_outputs
}

allreduce_op = block._insert_op_without_sync(
index=allreduce_id + 1,
type=allreduce_op.type,
inputs=allreduce_op_inputs,
outputs=allreduce_op_outputs,
attrs=allreduce_op.all_attrs(),
)
self.dist_context.set_op_dist_attr_for_program(
matmul_op, matmul_grad_dist_attr
allreduce_op, allreduce_op_dist_attr
)
# Remove the original allreduce op
block._remove_op(allreduce_id + 5, sync=False)

block._remove_op(matmul_grad_id, sync=False)
block._sync_with_cpp()
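
For orientation, the transformation named in the `matmul_grad_op => matmul_v2 + reshape + reshape + matmul_v2 + reshape` comment can be sketched numerically. The NumPy sketch below is illustrative only — shapes, variable names, and tolerances are assumptions, and it is not the pass_utils.split_matmul_grad_to_matmul implementation. For trans_x == trans_y == False, x_grad comes from a matmul with trans_y=True, while y_grad folds the leading batch/sequence dimensions of x and out_grad (the two reshape steps) before a single matmul with trans_x=True, matching the FoldInitDims rationale in the removed NOTE(Ruibiao) comment:

import numpy as np

# Assumed illustrative shapes: x: [b, s, k], y: [k, n], out = x @ y: [b, s, n]
b, s, k, n = 2, 4, 8, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((b, s, k)).astype("float32")
y = rng.standard_normal((k, n)).astype("float32")
out_grad = rng.standard_normal((b, s, n)).astype("float32")

# First matmul_v2 (trans_y=True): x_grad keeps the original rank.
x_grad = out_grad @ y.T                      # [b, s, k]

# reshape + reshape: fold the first two dims so a single GEMM computes y_grad
# instead of a batched GEMM (the FoldInitDims argument from the removed NOTE).
x_2d = x.reshape(b * s, k)                   # [b*s, k]
out_grad_2d = out_grad.reshape(b * s, n)     # [b*s, n]

# Second matmul_v2 (trans_x=True), then a final reshape back to y's shape
# (a no-op here because y is already 2-D).
y_grad = x_2d.T @ out_grad_2d                # [k, n]

# Cross-check against what the fused matmul_grad would produce.
assert np.allclose(x_grad, np.einsum("bsn,kn->bsk", out_grad, y), atol=1e-4)
assert np.allclose(y_grad, np.einsum("bsk,bsn->kn", x, out_grad), atol=1e-4)

In the pass itself, the c_allreduce_sum of x_grad is re-inserted between these two matmuls (the scheduling order stated in the removed NOTE(Ruibiao) comment), so the communication can overlap with the second GEMM on a separate stream.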
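
The dependency check that gates the overlap (the loop around moved_ops_idx and skip_overlapping, kept by this PR) can also be summarized standalone. This is a simplified sketch under assumed interfaces — `ops` here is a plain list of objects exposing input_arg_names()/output_arg_names() rather than Paddle op descs, and the inner check reconstructs a part of the loop that is elided by the hunk boundary — so it only mirrors the intent visible in the diff:

def can_overlap(ops, matmul_grad_id, allreduce_id, matmul_grad_output):
    """Return True if the ops between matmul_grad and allreduce can be moved
    behind the allreduce without breaking a data dependency."""
    moved_ops_output = []
    for idx in range(matmul_grad_id + 1, allreduce_id):
        if matmul_grad_output in ops[idx].input_arg_names():
            # Consumers of y_grad are the ops that would be moved back.
            moved_ops_output.extend(ops[idx].output_arg_names())
        else:
            # Any other op in the window must not read a moved op's output;
            # otherwise moving would reorder a real dependency, so the pass
            # skips overlapping for this matmul_grad/allreduce pair.
            for input_name in ops[idx].input_arg_names():
                if input_name in moved_ops_output:
                    return False
    return True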