tensor fusion for param group (#57690)
FeixLiu authored Sep 26, 2023
1 parent b38d643 commit f359de5
Showing 3 changed files with 180 additions and 23 deletions.
118 changes: 95 additions & 23 deletions python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -349,38 +349,17 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
    return params, dtype


def fused_parameters(
def _fused_parameters_impl(
    parameters,
    use_main_grad=False,
    fuse_param=True,
    comm_overlap=False,
    comm_group=None,
    act=None,
    dst=-1,
    acc_step=1,
    scale_after_comm=False,
):
    """
    Fuse gradients, fuse parameters if enabled, and prepare for comm overlap if enabled.
    :param parameters: all parameters to be fused.
    :param use_main_grad: whether gradients use main grad
    :param comm_overlap: whether to enable comm overlap
    :param comm_group: the comm group for comm overlap
    :param dst: the dst rank for comm overlap
    :param acc_step: accumulation steps, used for comm overlap
    :param fuse_param: whether to fuse parameters
    :param scale_after_comm: if comm overlap is enabled, specify the location of grad scale
    :return: param storage if fused, comm buffers if comm overlap
    """
    g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
    act = (
        HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE
    )
    if comm_overlap:
        assert comm_group is not None
    if act == HOOK_ACTION.REDUCE:
        assert dst != -1
    elif act == HOOK_ACTION.ALL_REDUCE:
        dst = -1
    param_groups = []
    attrs = []

@@ -449,3 +428,96 @@ def fused_parameters(
    all_buffers += other_buffers

    return decay_fused, all_fused, all_buffers


def fused_parameters(
    parameters,
    use_main_grad=False,
    fuse_param=True,
    comm_overlap=False,
    comm_group=None,
    act=None,
    dst=-1,
    acc_step=1,
    scale_after_comm=False,
    group_params=False,
):
"""
Fuse gradients. Fuse parameters if be enabled. Prepare for comm overlap if be enabled.
:param parameters: all parameters to be fused.
:param use_main_grad: does the gradient use main grad or not
:param comm_overlap: enable comm overlap or not
:param comm_group: the comm group for comm overlap
:param act: the comm operation, could be chosen from reduce and allreduce
:param dst: the dst for comm overlap
:param acc_step: acc steps, using for comm overlap
:param fuse_param: fuse param or not
:param scale_after_comm: if enable comm overlap, specify the location of grad scale
:param group_params: the format of the input parameters is param group
:return: param storage if fused, comm buffers if comm overlap, param groups if use group params
"""
    if act is None:
        g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
        act = (
            HOOK_ACTION.ALL_REDUCE
            if not g_shard_use_reduce
            else HOOK_ACTION.REDUCE
        )
    if comm_overlap:
        if comm_group is None:
            assert (
                act == HOOK_ACTION.ALL_REDUCE
            ), "Only allreduce action can use default comm group"
            comm_group = paddle.distributed.collective._get_default_group()
    if act == HOOK_ACTION.REDUCE:
        assert dst != -1
    elif act == HOOK_ACTION.ALL_REDUCE:
        dst = -1

    if group_params:
        updated_parameters = []
        comm_buffers = []
        for idx, group_param in enumerate(parameters):
            assert isinstance(
                group_param, dict
            ), "For group params, each group should be a dictionary."
            assert (
                'params' in group_param.keys()
            ), "For group params, each group should have parameters."
            real_param = group_param['params']
            (
                group_decay_fused,
                group_all_fused,
                group_all_buffers,
            ) = _fused_parameters_impl(
                real_param,
                use_main_grad=use_main_grad,
                fuse_param=fuse_param,
                comm_overlap=comm_overlap,
                comm_group=comm_group,
                act=act,
                dst=dst,
                acc_step=acc_step,
                scale_after_comm=scale_after_comm,
            )
            if comm_overlap:
                comm_buffers.extend(group_all_buffers)
            for fused_tensor in group_all_fused:
                fused_tensor.optimize_attr = real_param[0].optimize_attr
            group_param['params'] = group_all_fused
            updated_parameters.append(group_param)
        return updated_parameters, comm_buffers
    else:
        decay_fused, all_fused, all_buffers = _fused_parameters_impl(
            parameters,
            use_main_grad=use_main_grad,
            fuse_param=fuse_param,
            comm_overlap=comm_overlap,
            comm_group=comm_group,
            act=act,
            dst=dst,
            acc_step=acc_step,
            scale_after_comm=scale_after_comm,
        )

        return decay_fused, all_fused, all_buffers
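
For orientation, a minimal usage sketch of the new group_params path follows. It mirrors the test added below; the placeholder linear layer and its sizes are illustrative assumptions, and a collective environment (e.g. fleet.init(is_collective=True, ...) across two ranks, as in that test) is assumed before calling with comm_overlap=True.

# Sketch only: assumes a collective environment has been initialized, as in
# the 2-GPU test below; the layer and its sizes here are placeholders.
import paddle
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
    HOOK_ACTION,
    fused_parameters,
)

model = paddle.nn.Linear(16, 16)
params = model.parameters()
params[0].optimize_attr = {'lr': 1.0}  # copied onto each group's fused tensors

# Each group must be a dict containing a 'params' key; other keys are kept as-is.
param_groups = [{'params': params}]

fused_groups, comm_buffers = fused_parameters(
    param_groups,
    act=HOOK_ACTION.ALL_REDUCE,  # the default comm group only supports allreduce
    comm_overlap=True,           # comm buffers are only returned when overlap is on
    group_params=True,
)
# Each returned group keeps its dict layout with 'params' replaced by the fused
# storages; comm_buffers collects the overlap buffers from every group.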
82 changes: 82 additions & 0 deletions test/collective/fleet/hybrid_parallel_tensor_fusion_with_group.py
@@ -0,0 +1,82 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
    HOOK_ACTION,
    fused_parameters,
)


class SimpleDPNet(paddle.nn.Layer):
    def __init__(self, vocab_size, hidden_size, inner_size, output_size):
        super().__init__()
        self.linear1 = paddle.nn.Linear(
            hidden_size,
            inner_size,
        )

        self.linear2 = paddle.nn.Linear(
            inner_size,
            hidden_size,
        )

        self.linear3 = paddle.nn.Linear(
            hidden_size,
            output_size,
        )

        self.embedding = paddle.nn.Embedding(
            vocab_size,
            hidden_size,
        )

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
        return x


class TestDistSharding(unittest.TestCase):
    def setUp(self):
        self.strategy = fleet.DistributedStrategy()
        self.strategy.hybrid_configs = {
            "sharding_degree": 1,
            "dp_degree": 2,
            "mp_degree": 1,
            "pp_degree": 1,
        }
        fleet.init(is_collective=True, strategy=self.strategy)

    def test_fusion(self):
        model = SimpleDPNet(20, 10, 8, 10)
        parameters = model.parameters()
        parameters[0].optimize_attr = {'lr': 1}
        param_group = [{'params': parameters}, {'params': parameters}]
        fused_parameters(
            param_group,
            act=HOOK_ACTION.ALL_REDUCE,
            comm_overlap=True,
            group_params=True,
        )


if __name__ == "__main__":
    unittest.main()
@@ -39,6 +39,9 @@ def test_hybrid_parallel_sharding_tensor_fusion_amp(self):
    def test_hybrid_parallel_sharding_state_dict(self):
        self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')

    def test_group_param_tensor_fusion(self):
        self.run_mnist_2gpu('hybrid_parallel_tensor_fusion_with_group.py')


if __name__ == "__main__":
    unittest.main()
