Add mp_all_reduce asynchronize overlap. (#55662)
* [WIP] Add mp_all_reduce asynchronize overlap.

* Fix some problems.

* Fix dw compute bug, and use a temporary solution to achieve overlap.

* Use fused_linear_param_grad_add to compute dw.

* Reformat ColumnParallel _overlap_linear. Use environment flags to
control the following behaviors (a usage sketch follows the changed-files summary below):
1. export Flags_mp_aysnc_allreduce=True to turn on mp async all_reduce.
2. export Flags_skip_mp_c_identity=True to skip two c_identity operators
   in dygraph mode.
3. export Flags_fused_linear_param_grad_add=True to enable fused_linear_param_grad_add
   in ColumnParallel backward with mp async all_reduce.

* Polish code.

* Remove useless communication API.

* Fix some problems in mp_async_all_reduce and skip_c_identity.

* Add test cases.

* Remove environment variable Flags_fused_linear_param_grad_add in test case.

* Reset error threshold.

* Reset threshold in test case.

* Add useful log. Remove useless test cases.
GhostScreaming authored Aug 16, 2023
1 parent a8981be commit 6b1dfb5
Showing 2 changed files with 214 additions and 27 deletions.
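Usage sketch (not part of this commit, for orientation only): the three flags introduced here are plain environment variables read at runtime, so they can be set before the distributed strategy is built. The topology values, layer sizes, and tensor shapes below are illustrative assumptions, and the script is assumed to be launched with paddle.distributed.launch on at least two GPUs.

import os

# Turn on all_reduce(dx) / matmul(dw) overlap in ColumnParallelLinear backward.
os.environ["Flags_mp_aysnc_allreduce"] = "True"
# Optionally skip the c_identity ops in dygraph mode (only effective with the flag above).
os.environ["Flags_skip_mp_c_identity"] = "True"
# Optionally fuse the dw/dbias computation; requires a paddle build with CUDA 11.6 or higher.
os.environ["Flags_fused_linear_param_grad_add"] = "True"

import paddle  # noqa: E402
from paddle.distributed import fleet  # noqa: E402

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# ColumnParallelLinear layers built after this point honor the flags.
linear = fleet.meta_parallel.ColumnParallelLinear(
    1024, 4096, has_bias=True, gather_output=False
)
out = linear(paddle.randn([8, 1024]))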
161 changes: 152 additions & 9 deletions python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -17,8 +17,10 @@
from paddle.fluid import core
from paddle.nn import functional as F

from ....communication.reduce import ReduceOp, _get_reduce_op
from ...base import topology as tp
from . import mp_ops
from .mp_ops import _get_mp_env_flag
from .random import get_rng_state_tracker

__all__ = []
@@ -32,6 +34,13 @@ def is_fused_matmul_bias_supported():
return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')


def is_fused_linear_param_grad_add_supported():
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
return hasattr(paddle._C_ops, 'fused_linear_param_grad_add')
else:
return False


class VocabParallelEmbedding(paddle.nn.Layer):
"""Embedding mp parallelized in the vocabulary dimension.
this class is used for splitting embedding in mp group.
@@ -295,7 +304,8 @@ def __init__(

self.linear = F.linear

if fuse_matmul_bias:
self.fuse_matmul_bias = fuse_matmul_bias
if self.fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
"You set fuse_matmul_bias=True in ColumnParallelLinear, "
@@ -309,16 +319,149 @@

def forward(self, x):
# use inner api to process identity
if self.is_mp:
input_parallel = mp_ops._c_identity(
x, group=self.model_parallel_group
)

def _overlap_linear():
fuse_matmul_bias = self.fuse_matmul_bias

class InnerOverlapLinear(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, x, weight, bias):
ctx.save_for_backward(x, weight, bias)
if (
_get_mp_env_flag("Flags_mp_aysnc_allreduce")
and _get_mp_env_flag("Flags_skip_mp_c_identity")
) is False:
x = paddle._legacy_C_ops.c_identity(
x,
'use_calc_stream',
True,
'ring_id',
self.model_parallel_group.id,
'use_model_parallel',
True,
)
if not fuse_matmul_bias:
return paddle._C_ops.linear(x, weight, bias)
else:
return paddle._legacy_C_ops.fused_gemm_epilogue(
x, weight, bias
)

@staticmethod
def backward(ctx, dy):
x, weight, bias = ctx.saved_tensor()
dx = paddle.matmul(dy, weight, transpose_y=True)
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
task = self.model_parallel_group.process_group.all_reduce(
dx, op_type, sync_op=False
)
# TODO(GhostScreaming): remove it in future.
tmp = paddle.ones([512])

if _get_mp_env_flag("Flags_fused_linear_param_grad_add"):
if not is_fused_linear_param_grad_add_supported():
raise NotImplementedError(
"You set environment variable Flags_fused_linear_param_grad_add=True, "
"however, the paddle you are using not support this operation. "
"Please unset Flags_fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)

if bias is None:
if hasattr(weight, "main_grad"):
(
weight.main_grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.main_grad, None, True, False
)
task.wait()
return dx, None
else:
if weight.grad is not None:
(
weight.grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, None, False, False
)
task.wait()
return dx, None
else:
(
dw,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, False
)
task.wait()
return dx, dw

if hasattr(weight, "main_grad") and hasattr(
bias, "main_grad"
):
(
weight.main_grad,
bias.main_grad,
) = paddle._C_ops.fused_linear_param_grad_add(
x,
dy,
weight.main_grad,
bias.main_grad,
True,
True,
)
task.wait()
return dx, None, None
else:
if weight.grad is not None:
assert bias.grad is not None
(
weight.grad,
bias.grad,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, bias.grad, False, True
)
task.wait()
return dx, None, None
else:
(
dw,
dbias,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, True
)
task.wait()
return dx, dw, dbias
else:
dw = paddle.matmul(
x.reshape([-1, x.shape[-1]]),
dy.reshape([-1, dy.shape[-1]]),
transpose_x=True,
)
if bias is None:
task.wait()
return dx, dw
else:
dbias = paddle.sum(dy, axis=0)
task.wait()
return dx, dw, dbias

return InnerOverlapLinear.apply(x, self.weight, self.bias)

if _get_mp_env_flag("Flags_mp_aysnc_allreduce"):
output_parallel = _overlap_linear()
else:
input_parallel = x
if self.is_mp:
input_parallel = mp_ops._c_identity(
x, group=self.model_parallel_group
)
else:
input_parallel = x

output_parallel = self.linear(
input_parallel, self.weight, self.bias, name=self._name
)
output_parallel = self.linear(
input_parallel, self.weight, self.bias, name=self._name
)

if self.gather_output and self.is_mp:
output = mp_ops._c_concat(
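For readers who want the core idea in isolation, here is a minimal, standalone sketch (not part of this commit) of the communication/computation overlap that InnerOverlapLinear.backward implements: the all_reduce of dx is issued as a non-blocking collective, dw is computed while it is in flight, and the result is awaited only when dx is needed. It assumes a multi-GPU launch via paddle.distributed.launch; shapes are illustrative.

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# Toy shapes; in ColumnParallelLinear these correspond to the activation,
# the per-rank weight shard, and the upstream gradient.
x = paddle.randn([8, 1024])
w = paddle.randn([1024, 2048])
dy = paddle.randn([8, 2048])

# dx must be summed across the model-parallel group; launch the all_reduce
# asynchronously so the calc stream is free to keep working.
dx = paddle.matmul(dy, w, transpose_y=True)
task = dist.all_reduce(dx, op=dist.ReduceOp.SUM, sync_op=False)

# Overlap: compute the weight gradient while the all_reduce is in flight.
dw = paddle.matmul(x, dy, transpose_x=True)

# Block only when the reduced dx is actually required downstream.
task.wait()

The committed code additionally routes the dw/dbias computation through fused_linear_param_grad_add when the corresponding flag is set, accumulating directly into weight.grad or weight.main_grad instead of returning a fresh gradient tensor.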
80 changes: 62 additions & 18 deletions python/paddle/distributed/fleet/layers/mpu/mp_ops.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddle import _legacy_C_ops
from paddle.distributed import collective
@@ -22,6 +24,38 @@

from ....communication.reduce import ReduceOp, _get_reduce_op

_first_get_mp_env_flag = True


def _get_mp_env_flag(flag):
global _first_get_mp_env_flag
if _first_get_mp_env_flag:
print(
"Flags_mp_aysnc_allreduce is {}, which is used to support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear.".format(
str(os.getenv("Flags_mp_aysnc_allreduce")).lower()
)
)
print(
"Flags_fused_linear_param_grad_add is {}, which is used to support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
str(os.getenv("Flags_fused_linear_param_grad_add")).lower()
)
)
print(
"Flags_skip_mp_c_identity is {}, which is used to support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.".format(
str(os.getenv("Flags_skip_mp_c_identity")).lower()
)
)
# Log the flag values only once per process.
_first_get_mp_env_flag = False
# Model parallel environment flag.
# Flags_mp_aysnc_allreduce: support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear
# Flags_fused_linear_param_grad_add: support fused_linear_param_grad_add in ColumnParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
# Flags_skip_mp_c_identity: support skip c_identity in ColumnParallelLinear and RowParallelLinear. Only works when Flags_mp_aysnc_allreduce is True.
assert flag in [
"Flags_mp_aysnc_allreduce",
"Flags_fused_linear_param_grad_add",
"Flags_skip_mp_c_identity",
], "Only support set Flags_mp_aysnc_allreduce (support all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear), Flags_fused_linear_param_grad_add (support fused_linear_param_grad_add in ColumnParallelLinear) and Flags_skip_mp_c_identity (support skip c_identity in ColumnParallelLinear with Flags_mp_aysnc_allreduce=True, and skip c_identity in RowParallelLinear)"
return str(os.getenv(flag)).lower() in ["true", "1"]


def _c_identity(tensor, group=None):
"""
@@ -45,15 +79,20 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
return _legacy_C_ops.c_identity(
tensor,
'use_calc_stream',
True,
'ring_id',
group.id,
'use_model_parallel',
True,
)
if _get_mp_env_flag(
"Flags_mp_aysnc_allreduce"
) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
return tensor
else:
return _legacy_C_ops.c_identity(
tensor,
'use_calc_stream',
True,
'ring_id',
group.id,
'use_model_parallel',
True,
)

@staticmethod
def backward(ctx, dy):
@@ -256,15 +295,20 @@ def forward(

@staticmethod
def backward(ctx, dy):
return _legacy_C_ops.c_identity(
dy,
'use_calc_stream',
True,
'ring_id',
ctx.ring_id,
'use_model_parallel',
True,
)
if _get_mp_env_flag(
"Flags_mp_aysnc_allreduce"
) and _get_mp_env_flag("Flags_skip_mp_c_identity"):
return dy
else:
return _legacy_C_ops.c_identity(
dy,
'use_calc_stream',
True,
'ring_id',
ctx.ring_id,
'use_model_parallel',
True,
)

return mp_allreduce_eager.apply(
tensor, group, use_calc_stream, use_model_parallel
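Finally, a small sketch (again not part of the commit) of the flag-parsing convention _get_mp_env_flag relies on: only the strings "true" or "1" (case-insensitive) enable a flag, and an unset variable reads as the string "None", which stays disabled. The helper name below is hypothetical.

import os


def _flag_enabled(name):
    # Mirrors the commit's check: str(os.getenv(name)).lower() in ["true", "1"].
    return str(os.getenv(name)).lower() in ["true", "1"]


os.environ["Flags_mp_aysnc_allreduce"] = "1"
print(_flag_enabled("Flags_mp_aysnc_allreduce"))   # True

os.environ["Flags_skip_mp_c_identity"] = "false"
print(_flag_enabled("Flags_skip_mp_c_identity"))   # False

# Assuming the variable is not set in the environment:
print(_flag_enabled("Flags_fused_linear_param_grad_add"))  # False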
