diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index 3cb993790d2b2..548eae655cce5 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -349,38 +349,17 @@ def filter_params(params, is_fp32, is_distributed, need_clip):
     return params, dtype
 
 
-def fused_parameters(
+def _fused_parameters_impl(
     parameters,
     use_main_grad=False,
     fuse_param=True,
     comm_overlap=False,
     comm_group=None,
+    act=None,
     dst=-1,
     acc_step=1,
     scale_after_comm=False,
 ):
-    """
-    Fuse gradients. Fuse parameters if be enabled. Prepare for comm overlap if be enabled.
-    :param parameters: all parameters to be fused.
-    :param use_main_grad: does the gradient use main grad or not
-    :param comm_overlap: enable comm overlap or not
-    :param comm_group: the comm group for comm overlap
-    :param dst: the dst for comm overlap
-    :param acc_step: acc steps, using for comm overlap
-    :param fuse_param: fuse param or not
-    :param scale_after_comm: if enable comm overlap, specify the location of grad scale
-    :return: param storage if fused, comm buffers is comm overlap
-    """
-    g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
-    act = (
-        HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE
-    )
-    if comm_overlap:
-        assert comm_group is not None
-    if act == HOOK_ACTION.REDUCE:
-        assert dst != -1
-    elif act == HOOK_ACTION.ALL_REDUCE:
-        dst = -1
 
     param_groups = []
     attrs = []
@@ -449,3 +428,96 @@ def fused_parameters(
         all_buffers += other_buffers
 
     return decay_fused, all_fused, all_buffers
+
+
+def fused_parameters(
+    parameters,
+    use_main_grad=False,
+    fuse_param=True,
+    comm_overlap=False,
+    comm_group=None,
+    act=None,
+    dst=-1,
+    acc_step=1,
+    scale_after_comm=False,
+    group_params=False,
+):
+    """
+    Fuse gradients, and optionally fuse parameters and prepare comm-overlap buffers.
+    :param parameters: all parameters to be fused
+    :param use_main_grad: whether the gradients use main grad
+    :param comm_overlap: whether to enable comm overlap
+    :param comm_group: the comm group used for comm overlap
+    :param act: the comm action, either HOOK_ACTION.REDUCE or HOOK_ACTION.ALL_REDUCE
+    :param dst: the dst rank for comm overlap
+    :param acc_step: gradient accumulation steps, used for comm overlap
+    :param fuse_param: whether to fuse parameters
+    :param scale_after_comm: when comm overlap is enabled, whether grads are scaled after the comm
+    :param group_params: whether the input parameters are given as param groups
+    :return: (updated param groups, comm buffers) if group_params, else (decay fused params, all fused params, comm buffers)
+    """
+    if act is None:
+        g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
+        act = (
+            HOOK_ACTION.ALL_REDUCE
+            if not g_shard_use_reduce
+            else HOOK_ACTION.REDUCE
+        )
+    if comm_overlap:
+        if comm_group is None:
+            assert (
+                act == HOOK_ACTION.ALL_REDUCE
+            ), "Only allreduce action can use default comm group"
+            comm_group = paddle.distributed.collective._get_default_group()
+    if act == HOOK_ACTION.REDUCE:
+        assert dst != -1
+    elif act == HOOK_ACTION.ALL_REDUCE:
+        dst = -1
+
+    if group_params:
+        updated_parameters = []
+        comm_buffers = []
+        for idx, group_param in enumerate(parameters):
+            assert isinstance(
+                group_param, dict
+            ), "For group params, each group should be a dictionary."
+            assert (
+                'params' in group_param.keys()
+            ), "For group params, each group should contain a 'params' key."
+            real_param = group_param['params']
+            (
+                group_decay_fused,
+                group_all_fused,
+                group_all_buffers,
+            ) = _fused_parameters_impl(
+                real_param,
+                use_main_grad=use_main_grad,
+                fuse_param=fuse_param,
+                comm_overlap=comm_overlap,
+                comm_group=comm_group,
+                act=act,
+                dst=dst,
+                acc_step=acc_step,
+                scale_after_comm=scale_after_comm,
+            )
+            if comm_overlap:
+                comm_buffers.extend(group_all_buffers)
+            for fused_tensor in group_all_fused:
+                fused_tensor.optimize_attr = real_param[0].optimize_attr
+            group_param['params'] = group_all_fused
+            updated_parameters.append(group_param)
+        return updated_parameters, comm_buffers
+    else:
+        decay_fused, all_fused, all_buffers = _fused_parameters_impl(
+            parameters,
+            use_main_grad=use_main_grad,
+            fuse_param=fuse_param,
+            comm_overlap=comm_overlap,
+            comm_group=comm_group,
+            act=act,
+            dst=dst,
+            acc_step=acc_step,
+            scale_after_comm=scale_after_comm,
+        )
+
+        return decay_fused, all_fused, all_buffers
diff --git a/test/collective/fleet/hybrid_parallel_tensor_fusion_with_group.py b/test/collective/fleet/hybrid_parallel_tensor_fusion_with_group.py
new file mode 100644
index 0000000000000..62df6613e51b7
--- /dev/null
+++ b/test/collective/fleet/hybrid_parallel_tensor_fusion_with_group.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.utils.tensor_fusion_helper import (
+    HOOK_ACTION,
+    fused_parameters,
+)
+
+
+class SimpleDPNet(paddle.nn.Layer):
+    def __init__(self, vocab_size, hidden_size, inner_size, output_size):
+        super().__init__()
+        self.linear1 = paddle.nn.Linear(
+            hidden_size,
+            inner_size,
+        )
+
+        self.linear2 = paddle.nn.Linear(
+            inner_size,
+            hidden_size,
+        )
+
+        self.linear3 = paddle.nn.Linear(
+            hidden_size,
+            output_size,
+        )
+
+        self.embedding = paddle.nn.Embedding(
+            vocab_size,
+            hidden_size,
+        )
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.linear3(x)
+        x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
+        return x
+
+
+class TestDistSharding(unittest.TestCase):
+    def setUp(self):
+        self.strategy = fleet.DistributedStrategy()
+        self.strategy.hybrid_configs = {
+            "sharding_degree": 1,
+            "dp_degree": 2,
+            "mp_degree": 1,
+            "pp_degree": 1,
+        }
+        fleet.init(is_collective=True, strategy=self.strategy)
+
+    def test_fusion(self):
+        model = SimpleDPNet(20, 10, 8, 10)
+        parameters = model.parameters()
+        parameters[0].optimize_attr = {'lr': 1}
+        param_group = [{'params': parameters}, {'params': parameters}]
+        fused_parameters(
+            param_group,
+            act=HOOK_ACTION.ALL_REDUCE,
+            comm_overlap=True,
+            group_params=True,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
index 91cffab861f9e..f6152a13e10c1 100644
--- a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
+++ b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py
@@ -39,6 +39,9 @@ def test_hybrid_parallel_sharding_tensor_fusion_amp(self):
     def test_hybrid_parallel_sharding_state_dict(self):
         self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
 
+    def test_group_param_tensor_fusion(self):
+        self.run_mnist_2gpu('hybrid_parallel_tensor_fusion_with_group.py')
+
 
 if __name__ == "__main__":
     unittest.main()
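
Usage note (not part of the patch): a minimal sketch of the new group_params path, closely mirroring the hybrid_parallel_tensor_fusion_with_group.py test above. It assumes a 2-rank collective launch (e.g. via python -m paddle.distributed.launch) so the default communication group used by the allreduce overlap path exists; the single-Linear model and the commented-out optimizer hookup are illustrative assumptions only.

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.fleet.utils.tensor_fusion_helper import (
        HOOK_ACTION,
        fused_parameters,
    )

    # Initialize a 2-way data-parallel hybrid setup, as in the new unit test.
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "sharding_degree": 1,
        "dp_degree": 2,
        "mp_degree": 1,
        "pp_degree": 1,
    }
    fleet.init(is_collective=True, strategy=strategy)

    # Any Layer works here; a single Linear keeps the sketch short (assumption).
    model = paddle.nn.Linear(16, 16)
    parameters = model.parameters()
    # The first parameter's optimize_attr is copied onto each fused storage of a group.
    parameters[0].optimize_attr = {'lr': 1}

    # Parameters supplied as param groups, the format group_params=True expects.
    param_groups = [{'params': parameters}]

    # With group_params=True the helper returns the updated param groups (their
    # 'params' entries replaced by fused storages) and the comm buffers
    # registered for communication overlap.
    fused_groups, comm_buffers = fused_parameters(
        param_groups,
        act=HOOK_ACTION.ALL_REDUCE,
        comm_overlap=True,
        group_params=True,
    )

    # The updated groups can then be handed to an optimizer as usual, e.g.:
    # opt = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=fused_groups)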