Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mp_all_reduce asynchronize overlap. #55662

Merged
merged 15 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions paddle/fluid/pybind/distributed_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,27 @@ void BindDistributed(py::module *m) {
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())

.def(
FeixLiu marked this conversation as resolved.
Show resolved Hide resolved
"all_reduce_on_comm_stream",
[](distributed::ProcessGroup &self,
py::handle py_tensor,
distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dense =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
auto in_dense = *p_dense;
auto *out_dense = p_dense.get();
distributed::AllreduceOptions opts{op};
return self.AllReduce(out_dense,
in_dense,
opts,
/*sync_op*/ false,
/*use_calc_stream*/ false);
},
py::arg("tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())

.def(
"all_to_all_on_calc_stream",
[](distributed::ProcessGroup &self,
Expand Down
161 changes: 152 additions & 9 deletions python/paddle/distributed/fleet/layers/mpu/mp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import paddle
from paddle.autograd import PyLayer
from paddle.fluid import core
from paddle.nn import functional as F

from ....communication.reduce import ReduceOp, _get_reduce_op
from ...base import topology as tp
from . import mp_ops
from .random import get_rng_state_tracker
Expand Down Expand Up @@ -295,7 +297,8 @@ def __init__(

self.linear = F.linear

if fuse_matmul_bias:
self.fuse_matmul_bias = fuse_matmul_bias
if self.fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
raise NotImplementedError(
"You set fuse_matmul_bias=True in ColumnParallelLinear, "
Expand All @@ -309,16 +312,156 @@ def __init__(

def forward(self, x):
# use inner api to process identity
if self.is_mp:
input_parallel = mp_ops._c_identity(
x, group=self.model_parallel_group
)

def _overlap_linear():
fuse_matmul_bias = self.fuse_matmul_bias

class InnerOverlapLinear(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, x, weight, bias):
ctx.save_for_backward(x, weight, bias)
if (
str(os.getenv("Flags_skip_mp_c_identity")).lower()
GhostScreaming marked this conversation as resolved.
Show resolved Hide resolved
!= "true"
):
x = paddle._legacy_C_ops.c_identity(
x,
'use_calc_stream',
True,
'ring_id',
self.model_parallel_group.id,
'use_model_parallel',
True,
)
if not fuse_matmul_bias:
return paddle._C_ops.linear(x, weight, bias)
else:
return paddle._legacy_C_ops.fused_gemm_epilogue(
x, weight, bias
)

@staticmethod
def backward(ctx, dy):
x, weight, bias = ctx.saved_tensors
GhostScreaming marked this conversation as resolved.
Show resolved Hide resolved
dx = paddle.matmul(dy, weight, transpose_y=True)
op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
task = self.model_parallel_group.process_group.all_reduce_on_comm_stream(
dx, op_type
)
# TODO(GhostScreaming): remove it in future.
tmp = paddle.ones([512])

if (
str(
os.getenv("Flags_fused_linear_param_grad_add")
FeixLiu marked this conversation as resolved.
Show resolved Hide resolved
).lower()
== "true"
):
if not hasattr(
paddle._C_ops, 'fused_linear_param_grad_add'
FeixLiu marked this conversation as resolved.
Show resolved Hide resolved
):
raise NotImplementedError(
FeixLiu marked this conversation as resolved.
Show resolved Hide resolved
"You set environment variable Flags_fused_linear_param_grad_add=True, "
"however, the paddle you are using not support this operation. "
"Please unset Flags_fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)

if bias is None:
if hasattr(weight, "main_grad"):
(
weight.main_grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.main_grad, None, True, False
)
task.wait()
return dx, None
else:
if weight.grad is not None:
(
weight.grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, None, False, False
)
task.wait()
return dx, None
else:
(
dw,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, False
)
task.wait()
return dx, dw

if hasattr(weight, "main_grad") and hasattr(
bias, "main_grad"
):
(
weight.main_grad,
bias.main_grad,
) = paddle._C_ops.fused_linear_param_grad_add(
input,
dy,
weight.main_grad,
bias.main_grad,
True,
True,
)
task.wait()
return dx, None, None
else:
if weight.grad is not None:
assert bias.grad is not None
(
weight.grad,
bias.grad,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, weight.grad, bias.grad, False, True
)
task.wait()
return dx, None, None
else:
(
dw,
dbias,
) = paddle._C_ops.fused_linear_param_grad_add(
x, dy, None, None, False, True
)
task.wait()
return dx, dw, dbias
else:
dw = paddle.matmul(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个分支引入了一些reshape,会不会导致一些模型变慢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的reshape只是改变了数据的逻辑shape,没有进行数据搬移。实测对性能是没啥影响的,timeline的kernel执行时间也是一致的。

x.reshape([-1, x.shape[-1]]),
dy.reshape([-1, dy.shape[-1]]),
transpose_x=True,
)
if bias is not None:
task.wait()
return dx, dw
else:
dbias = paddle.sum(dy, axis=0)
task.wait()
return dx, dw, dbias

return InnerOverlapLinear.apply(x, self.weight, self.bias)

if str(os.getenv("Flags_mp_aysnc_allreduce")).lower() == "true":
output_parallel = _overlap_linear()
else:
input_parallel = x
if self.is_mp:
input_parallel = mp_ops._c_identity(
x, group=self.model_parallel_group
)
else:
input_parallel = x

output_parallel = self.linear(
input_parallel, self.weight, self.bias, name=self._name
)
output_parallel = self.linear(
input_parallel, self.weight, self.bias, name=self._name
)

if self.gather_output and self.is_mp:
output = mp_ops._c_concat(
Expand Down
44 changes: 26 additions & 18 deletions python/paddle/distributed/fleet/layers/mpu/mp_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddle import _legacy_C_ops
from paddle.distributed import collective
Expand Down Expand Up @@ -45,15 +47,18 @@ def _c_identity(tensor, group=None):
class c_identity_eager(PyLayer):
@staticmethod
def forward(ctx, tensor):
return _legacy_C_ops.c_identity(
tensor,
'use_calc_stream',
True,
'ring_id',
group.id,
'use_model_parallel',
True,
)
if str(os.getenv("Flags_skip_mp_c_identity")).lower() == "true":
return tensor
else:
return _legacy_C_ops.c_identity(
tensor,
'use_calc_stream',
True,
'ring_id',
group.id,
'use_model_parallel',
True,
)

@staticmethod
def backward(ctx, dy):
Expand Down Expand Up @@ -256,15 +261,18 @@ def forward(

@staticmethod
def backward(ctx, dy):
return _legacy_C_ops.c_identity(
dy,
'use_calc_stream',
True,
'ring_id',
ctx.ring_id,
'use_model_parallel',
True,
)
if str(os.getenv("Flags_skip_mp_c_identity")).lower() == "true":
return dy
else:
return _legacy_C_ops.c_identity(
dy,
'use_calc_stream',
True,
'ring_id',
ctx.ring_id,
'use_model_parallel',
True,
)

return mp_allreduce_eager.apply(
tensor, group, use_calc_stream, use_model_parallel
Expand Down