add chunk_id (#62884)
heavyrain-lzy authored Mar 21, 2024
1 parent 765c669 · commit c937d8d
Showing 1 changed file with 8 additions and 1 deletion.
python/paddle/distributed/passes/pass_utils.py (8 additions, 1 deletion)
@@ -794,6 +794,7 @@ def _insert_reshape_op(
     x,
     shape,
     op_role,
+    chunk_id,
     dist_context,
     out=None,
     op_namescope="/",
@@ -829,7 +830,7 @@ def _insert_reshape_op(
         process_mesh=x_dist_attr.process_mesh,
         ref_mapping=x_dist_attr.dims_mapping,
         ctx=dist_context,
-        chunk_id=x_dist_attr.chunk_id,
+        chunk_id=chunk_id,
     )
 
     return out
@@ -881,12 +882,16 @@ def split_matmul_grad_to_matmul(
     # When the rank of input matrix is 3, MatmulGradKernel use reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad.
     # If we directly append matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding.
     # Therefore, we imitate MatmulGradKernel here by inserting reshape op before matmul.
+    chunk_id = dist_context.get_op_dist_attr_for_program(
+        matmul_grad_op
+    ).chunk_id
     new_x = _insert_reshape_op(
         block,
         matmul_grad_id + 1,
         x,
         new_x_dims,
         op_role,
+        chunk_id=chunk_id,
         dist_context=dist_context,
         op_namescope=op_namescope,
     )
@@ -896,6 +901,7 @@ def split_matmul_grad_to_matmul(
         out_grad,
         new_out_grad_dims,
         op_role,
+        chunk_id=chunk_id,
         dist_context=dist_context,
         op_namescope=op_namescope,
     )
@@ -934,6 +940,7 @@ def split_matmul_grad_to_matmul(
         [new_y_grad.name],
         y_grad_dims,
         op_role,
+        chunk_id=chunk_id,
         dist_context=dist_context,
         out=y_grad,
         op_namescope=op_namescope,
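In short, the commit threads an explicit chunk_id through _insert_reshape_op: instead of inheriting x_dist_attr.chunk_id from the tensor being reshaped, the helper now receives the chunk_id of the matmul_grad op that split_matmul_grad_to_matmul is splitting, so every reshape op the pass inserts carries the same chunk_id as the op it was derived from.

The in-diff comments justify the reshape-before-matmul pattern by the cost of blas.BatchedGEMM versus a single blas.Matmul after dimension folding. The following self-contained NumPy sketch (illustrative only; the shapes and variable names are assumptions, not part of the patch or of Paddle's kernels) shows that folding the first two dimensions preserves the y_grad result while replacing B small GEMMs plus a reduction with one larger GEMM:

    import numpy as np

    # Hypothetical shapes: B batches of [M, K] x [K, N] matmuls, so x is
    # rank-3 and y_grad has shape [K, N].
    B, M, K, N = 4, 8, 16, 32
    rng = np.random.default_rng(0)
    x = rng.standard_normal((B, M, K))
    out_grad = rng.standard_normal((B, M, N))

    # Unfolded route: a batched GEMM produces B partial [K, N] results
    # that must then be reduced over the batch dimension.
    y_grad_batched = np.matmul(x.transpose(0, 2, 1), out_grad).sum(axis=0)

    # Folded route: collapse the first two dimensions (as FoldInitDims
    # does in matmul_grad_kernel_impl.h) and issue one
    # [K, B*M] @ [B*M, N] GEMM.
    y_grad_folded = x.reshape(B * M, K).T @ out_grad.reshape(B * M, N)

    assert np.allclose(y_grad_batched, y_grad_folded)

The single folded GEMM exposes one large contiguous problem to the BLAS library, which per the source comments is cheaper than the batched call, and is why the pass imitates MatmulGradKernel rather than appending a matmul op directly.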
