-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sequence Parallel Support Overlap #62284
Merged
Merged
Changes from 7 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4a9981e
update sequence_parallel_utils.py
iosmers 252c4b5
add sp
iosmers ddc7dd9
add test
iosmers 7ec7cb5
add test
iosmers a2dcb60
add a.txt
iosmers 0290fb0
update shape
iosmers 8c95741
delete a.txt
iosmers 439ee99
add test
iosmers 99c39e6
add
iosmers File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -227,6 +227,162 @@ def is_fused_matmul_bias_supported(): | |
return False | ||
|
||
|
||
def is_fused_linear_param_grad_add_supported(): | ||
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm(): | ||
return hasattr(paddle._C_ops, 'fused_linear_param_grad_add') | ||
else: | ||
return False | ||
|
||
|
||
class SPInnerOverlapLinear(paddle.autograd.PyLayer): | ||
@staticmethod | ||
def forward( | ||
ctx, | ||
x, | ||
weight, | ||
bias, | ||
fuse_matmul_bias, | ||
sp_fused_linear_param_grad_add, | ||
model_parallel_group, | ||
): | ||
ctx.sp_fused_linear_param_grad_add = sp_fused_linear_param_grad_add | ||
ctx.model_parallel_group = model_parallel_group | ||
|
||
world_size = model_parallel_group.nranks | ||
is_mp = world_size > 1 | ||
if is_mp: | ||
input_parallel = all_gather(x) | ||
else: | ||
input_parallel = x | ||
|
||
ctx.save_for_backward(x, weight, bias, input_parallel) | ||
if not fuse_matmul_bias: | ||
output = paddle._C_ops.linear(input_parallel, weight, bias) | ||
else: | ||
output = paddle._legacy_C_ops.fused_gemm_epilogue( | ||
input_parallel, weight, bias | ||
) | ||
return output | ||
|
||
@staticmethod | ||
def backward(ctx, dy): | ||
x, weight, bias, input_parallel = ctx.saved_tensor() | ||
parallelism = ctx.model_parallel_group.nranks | ||
|
||
if dy.dtype == weight.dtype: | ||
dinput_parallel = paddle.matmul(dy, weight, transpose_y=True) | ||
else: | ||
dinput_parallel = paddle.matmul( | ||
dy, paddle.cast(weight, dtype=dy.dtype), transpose_y=True | ||
) | ||
|
||
assert ( | ||
dinput_parallel.shape[0] % parallelism == 0 | ||
), "Input sequence length {} can't be divided exactly by sequence parallelism {}".format( | ||
dinput_parallel.shape[0], parallelism | ||
) | ||
|
||
dx_shape = dinput_parallel.shape | ||
dx_shape[0] = dx_shape[0] // parallelism | ||
dx = paddle.empty(shape=dx_shape, dtype=dinput_parallel.dtype) | ||
hcg = fleet.get_hybrid_communicate_group() | ||
group = hcg.get_model_parallel_group() | ||
task = dist.stream.reduce_scatter( | ||
dx, | ||
dinput_parallel, | ||
op=dist.ReduceOp.SUM, | ||
group=group, | ||
sync_op=False, | ||
) | ||
|
||
if ctx.sp_fused_linear_param_grad_add: | ||
if not is_fused_linear_param_grad_add_supported(): | ||
raise NotImplementedError( | ||
"You set sp_fused_linear_param_grad_add=True, " | ||
"however, the paddle you are using not support this operation. " | ||
"Please unset fused_linear_param_grad_add or use paddle compiled " | ||
"with cuda 11.6 or higher." | ||
) | ||
if bias is None: | ||
if hasattr(weight, "main_grad"): | ||
( | ||
weight.main_grad, | ||
_, | ||
) = paddle._C_ops.fused_linear_param_grad_add( | ||
input_parallel, dy, weight.main_grad, None, True, False | ||
) | ||
task.wait() | ||
return dx, None | ||
else: | ||
if weight.grad is not None: | ||
( | ||
weight.grad, | ||
_, | ||
) = paddle._C_ops.fused_linear_param_grad_add( | ||
input_parallel, dy, weight.grad, None, False, False | ||
) | ||
task.wait() | ||
return dx, None | ||
else: | ||
( | ||
dw, | ||
_, | ||
) = paddle._C_ops.fused_linear_param_grad_add( | ||
input_parallel, dy, None, None, False, False | ||
) | ||
task.wait() | ||
return dx, dw | ||
|
||
if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"): | ||
( | ||
weight.main_grad, | ||
bias.main_grad, | ||
) = paddle._C_ops.fused_linear_param_grad_add( | ||
input_parallel, | ||
dy, | ||
weight.main_grad, | ||
bias.main_grad, | ||
True, | ||
True, | ||
) | ||
task.wait() | ||
return dx, None, None | ||
else: | ||
if weight.grad is not None: | ||
assert bias.grad is not None | ||
( | ||
weight.grad, | ||
bias.grad, | ||
) = paddle._C_ops.fused_linear_param_grad_add( | ||
input_parallel, dy, weight.grad, bias.grad, False, True | ||
) | ||
task.wait() | ||
return dx, None, None | ||
else: | ||
( | ||
dw, | ||
dbias, | ||
) = paddle._C_ops.fused_linear_param_grad_add( | ||
input_parallel, dy, None, None, False, True | ||
) | ||
task.wait() | ||
return dx, dw, dbias | ||
else: | ||
FeixLiu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dy = dy.reshape([-1, dy.shape[-1]]) | ||
dw = paddle.matmul( | ||
input_parallel.reshape([-1, input_parallel.shape[-1]]), | ||
dy, | ||
transpose_x=True, | ||
) | ||
if bias is None: | ||
task.wait() | ||
return dx, dw | ||
else: | ||
dbias = paddle.sum(dy, axis=0) | ||
task.wait() | ||
return dx, dw, dbias | ||
|
||
|
||
class ColumnSequenceParallelLinear(Layer): | ||
def __init__( | ||
self, | ||
|
@@ -285,6 +441,7 @@ def __init__( | |
) | ||
|
||
self.weight.is_distributed = True if self.is_mp else False | ||
self.fuse_matmul_bias = fuse_matmul_bias | ||
|
||
if has_bias: | ||
# initialize bias to zero like Megatron | ||
|
@@ -312,18 +469,28 @@ def __init__( | |
|
||
self.linear = fused_linear | ||
|
||
# sp_configs = fleet.fleet._user_defined_strategy.hybrid_configs["sp_configs"] | ||
# self.sp_asyn_reduce_scatter = self.is_mp and sp_configs.sp_asyn_reduce_scatter | ||
|
||
# self.sp_fused_linear_param_grad_add = ( | ||
# self.is_mp | ||
# and sp_configs.sp_asyn_reduce_scatter | ||
# and sp_configs.sp_fused_linear_param_grad_add | ||
# ) | ||
|
||
self.sp_asyn_reduce_scatter = True | ||
self.sp_fused_linear_param_grad_add = True | ||
|
||
def forward(self, x): | ||
# sequence parallelism is same as model parallelism | ||
# if sequence parallel is true, input shape is [s, b, h] | ||
# else input shape is [b, s, h] | ||
if self.is_mp: | ||
input_parallel = AllGatherOp.apply(x) | ||
else: | ||
input_parallel = x | ||
output = self.linear( | ||
input_parallel, self.weight, self.bias, name=self._name | ||
# sequence parallelism is same as model parallelis, if sequence parallel is true, input shape is [s, b, h],else input shape is [b, s, h] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
return SPInnerOverlapLinear.apply( | ||
x, | ||
self.weight, | ||
self.bias, | ||
self.fuse_matmul_bias, | ||
self.sp_fused_linear_param_grad_add, | ||
self.model_parallel_group, | ||
) | ||
return output | ||
|
||
|
||
class MPScale(PyLayer): | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是不是直接assert is_mp就行了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改