From 0e70f000f61b3f641c77a33d548cdec6520f254a Mon Sep 17 00:00:00 2001 From: liuyuang Date: Wed, 12 Jul 2023 16:30:02 +0800 Subject: [PATCH 1/5] Tensor fusion for sharding stage 1: 1. parameters fusion for optimizer and broadcast 2. gradients fusion for allreduce --- .../framework/distributed_strategy.proto | 5 + .../dygraph_sharding_optimizer.py | 83 +++++++- .../sharding/group_sharded_storage.py | 14 ++ .../fleet/utils/tensor_fusion_helper.py | 201 ++++++++++++++++++ ...rid_parallel_sharding_model_with_fusion.py | 186 ++++++++++++++++ ...test_parallel_dygraph_sharding_parallel.py | 3 + 6 files changed, 481 insertions(+), 11 deletions(-) create mode 100644 python/paddle/distributed/fleet/utils/tensor_fusion_helper.py create mode 100644 test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index b18d7533779d5..74f377fd875de 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -66,6 +66,10 @@ message PpConfig { optional bool profiling = 5 [ default = false ]; } +message DygraphShardingConfig { + optional bool tensor_fusion = 1 [ default = false ]; +} + message HybridConfig { optional int32 dp_degree = 1 [ default = -1 ]; optional int32 mp_degree = 2 [ default = 1 ]; @@ -73,6 +77,7 @@ message HybridConfig { optional int32 sharding_degree = 4 [ default = 1 ]; optional MpConfig mp_configs = 5; optional PpConfig pp_configs = 6; + optional DygraphShardingConfig sharding_configs = 7; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index a26d56b2dc193..432005048cbdd 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ 
b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -18,8 +18,10 @@ import paddle from paddle import framework +from paddle.distributed import fleet from ...utils.log_util import logger +from ...utils.tensor_fusion_helper import fused_parameters def _is_trainable(param): @@ -62,15 +64,48 @@ def __init__(self, optimizer, hcg): self._sharding_world_size = self._hcg.get_sharding_parallel_world_size() self._sharding_rank = self._hcg.get_sharding_parallel_rank() + strategy = fleet.fleet._user_defined_strategy + self.tensor_fusion = strategy.hybrid_configs[ + 'sharding_configs' + ].tensor_fusion + self._rank2params = self._partition_parameters() self._param2rank = self._map_param_to_rank() - self._set_inner_opt_attr( - '_parameter_list', self._rank2params[self._sharding_rank] - ) - self._set_inner_opt_attr( - '_param_groups', self._rank2params[self._sharding_rank] - ) + if not self.tensor_fusion: + self._set_inner_opt_attr( + '_parameter_list', self._rank2params[self._sharding_rank] + ) + self._set_inner_opt_attr( + '_param_groups', self._rank2params[self._sharding_rank] + ) + else: + self._use_main_grad = hasattr(self._parameter_list[0], "main_grad") + self._rank2decay = {} + self._rank2fused = {} + self._tensor_fusion() + + decay_params = [ + p.name for p in self._rank2decay[self._sharding_rank] + ] + all_params = self._rank2fused[self._sharding_rank] + apply_decay_param_fun = lambda x: x in decay_params + + params = [] + for v in self._rank2fused.values(): + params += v + self._parameter_list = params + self._param_groups = params + + self._set_inner_opt_attr('_parameter_list', all_params) + self._set_inner_opt_attr('_param_groups', all_params) + origin_decay_param_fun = getattr( + self._inner_opt, '_apply_decay_param_fun', None + ) + if origin_decay_param_fun is not None: + self._set_inner_opt_attr( + '_apply_decay_param_fun', apply_decay_param_fun + ) def clear_grad(self, set_to_zero=True): """ @@ -85,7 +120,25 @@ def 
clear_grad(self, set_to_zero=True): p.main_grad._clear() p.main_grad = None elif not hasattr(p, "main_grad"): - p.clear_gradient(set_to_zero) + if self.tensor_fusion: + if set_to_zero: + p.grad.zero_() + else: + p.grad._clear() + p.grad = None + else: + p.clear_gradient(set_to_zero) + + def _tensor_fusion(self): + for i in range(self._sharding_world_size): + params = self._rank2params[i] + decay_fused, all_fused = fused_parameters( + params, self._use_main_grad + ) + self._rank2decay[i] = decay_fused + self._rank2fused[i] = all_fused + for p in all_fused: + self._param2rank[p.name] = i def _partition_parameters(self): """ @@ -167,7 +220,12 @@ def _sharding_sync_parameters(self): logger.debug("sharding start sync parameters") with framework.no_grad(): # TODO detach not need (?) - for rank, params in self._rank2params.items(): + valid_rank_to_params = ( + self._rank2params + if not self.tensor_fusion + else self._rank2fused + ) + for rank, params in valid_rank_to_params.items(): for param in params: paddle.distributed.broadcast( param, @@ -236,9 +294,12 @@ def step(self): params_grads = self._inner_opt._grad_clip(params_grads) # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize self._set_inner_opt_attr('_grad_clip', None) - update_param_names = [ - p.name for p in self._rank2params[self._sharding_rank] - ] + rank_params = ( + self._rank2params[self._sharding_rank] + if not self.tensor_fusion + else self._rank2fused[self._sharding_rank] + ) + update_param_names = [p.name for p in rank_params] update_params_grads = [ (p, g) for p, g in params_grads if p.name in update_param_names ] diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py index ccd84693975d5..c63b5d69b6fe0 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py +++ 
b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py @@ -30,6 +30,14 @@ from .group_sharded_utils import Type, cvt_to_device, device_guard +class BufferWarper(core.eager.Tensor): + def __init__(self): + super().__init__() + self.need_clip = True + self.is_distributed = False + self.trainable = True + + class InternalStorage: """ This is a basic class, which is responsible for consolidating the basic storage tensor. @@ -97,6 +105,12 @@ def to(self, device, dtype=None, keep_alignment=True): self.buffer = self.buffer.cast(dtype=dtype) self._dtype = dtype + def warp_buffer(self): + tmp_buffer = BufferWarper() + self._buffer = self.buffer + tmp_buffer.get_tensor()._share_data_with(self.buffer.get_tensor()) + self.buffer = tmp_buffer + class ParamStorage(InternalStorage): """ diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py new file mode 100644 index 0000000000000..690fe37ad5286 --- /dev/null +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -0,0 +1,201 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from collections import OrderedDict + +import numpy as np + +import paddle +from paddle.framework import core + +alignment = { + "gpu": 256, +} + +align = { + paddle.float16.value: 2, + paddle.bfloat16.value: 2, + paddle.float32.value: 4, +} + + +def assign_group_by_size(parameters, group_size=256 * 1024 * 1024): + is_sparse_gradient = [False] * len(parameters) + + group_indices = core.eager_assign_group_by_size( + parameters, is_sparse_gradient, [group_size, group_size] + ) + + var_groups = OrderedDict() + for group_idx, indices in enumerate(group_indices): + for index in indices: + var_groups.setdefault(group_idx, []).append(parameters[index]) + return var_groups + + +def flatten_dense_tensors(parameters, use_main_grad): + from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import ( + GradStorage, + ParamStorage, + ) + + _buffer_size = 0 + _param2align = {} + dtype = parameters[0].dtype + + for param in parameters: + assert param.trainable, "param must be trainable..." 
+ size = np.prod(param.shape) * align[dtype] + remaining = size % alignment["gpu"] + ali = 0 if remaining == 0 else alignment["gpu"] - remaining + align_ = ali // align[dtype] + _buffer_size += np.prod(param.shape) + align_ + _param2align[param.name] = align_ + + param_storage = ParamStorage(size=_buffer_size, dtype=dtype, device="gpu") + + param_storage.add_rank_params(parameters, _param2align) + + # process gradient + grad_dtype = paddle.float32 if use_main_grad else dtype + grad_storage = GradStorage( + size=_buffer_size, + dtype=grad_dtype, + device="gpu", + destination="0", + parm2align=_param2align, + ) + + for param in parameters: + grad_storage.add_grad(param, _param2align[param.name]) + + param_storage.warp_buffer() + grad_storage.warp_buffer() + + if not use_main_grad: + # param_storage --> grad_storage + param_storage.buffer._copy_gradient_from(grad_storage.buffer) + else: + param_storage.buffer.main_grad = grad_storage.buffer + param_storage.buffer.stop_gradient = False + return param_storage, grad_storage + + +def obtain_storage(parameters, use_main_grad, clip, dist): + if len(parameters) < 1: + return [] + + var_groups = assign_group_by_size(parameters) + storage = [] + for group_idx, parameters in var_groups.items(): + param_storage, grad_storage = flatten_dense_tensors( + parameters, use_main_grad + ) + param_storage.buffer.need_clip = clip + param_storage.buffer.is_distributed = dist + storage.append(param_storage.buffer) + return storage + + +def fused_parameters(parameters, use_main_grad): + logging.log( + '[Tensor Fusion] Fusing tensors into tensor chunks, it may take a while.' 
+ ) + # filter for mp's distributed params + dist = list(filter(lambda x: x.is_distributed, parameters)) + no_dist = list(filter(lambda x: not x.is_distributed, parameters)) + + # filter for different dtype + fp32_dist = list(filter(lambda x: x.dtype == paddle.float32, dist)) + fp32_no_dist = list(filter(lambda x: x.dtype == paddle.float32, no_dist)) + no_fp32_dist = list(filter(lambda x: x.dtype != paddle.float32, dist)) + no_fp32_no_dist = list(filter(lambda x: x.dtype != paddle.float32, no_dist)) + + # no fp32 param's dtype should be the same + all_no_fp32_param = no_fp32_dist + no_fp32_no_dist + no_fp32_dtype = None + for p in all_no_fp32_param: + if no_fp32_dtype is None: + no_fp32_dtype = p.dtype + else: + assert ( + p.dtype == no_fp32_dtype + ), "Tensor fusion only support two different param dtypes." + + # filter for need clip + fp32_dist_clip = list( + filter(lambda x: getattr(x, 'need_clip', True), fp32_dist) + ) + fp32_no_dist_clip = list( + filter(lambda x: getattr(x, 'need_clip', True), fp32_no_dist) + ) + no_fp32_dist_clip = list( + filter(lambda x: getattr(x, 'need_clip', True), no_fp32_dist) + ) + no_fp32_no_dist_clip = list( + filter(lambda x: getattr(x, 'need_clip', True), no_fp32_no_dist) + ) + fp32_dist_no_clip = list( + filter(lambda x: not getattr(x, 'need_clip', True), fp32_dist) + ) + fp32_no_dist_no_clip = list( + filter(lambda x: not getattr(x, 'need_clip', True), fp32_no_dist) + ) + no_fp32_dist_no_clip = list( + filter(lambda x: not getattr(x, 'need_clip', True), no_fp32_dist) + ) + no_fp32_no_dist_no_clip = list( + filter(lambda x: not getattr(x, 'need_clip', True), no_fp32_no_dist) + ) + + param_groups = [ + fp32_dist_clip, + fp32_no_dist_clip, + no_fp32_dist_clip, + no_fp32_no_dist_clip, + fp32_dist_no_clip, + fp32_no_dist_no_clip, + no_fp32_dist_no_clip, + no_fp32_no_dist_no_clip, + ] + attrs = [ + [paddle.float32, True, True], + [paddle.float32, False, True], + [no_fp32_dtype, True, True], + [no_fp32_dtype, False, True], + 
[paddle.float32, True, False], + [paddle.float32, False, False], + [no_fp32_dtype, True, False], + [no_fp32_dtype, False, False], + ] + + decay_fused = [] + all_fused = [] + for params, attr in zip(param_groups, attrs): + decay_params = [] + other_params = [] + + for param in params: + if not any(nd in param.name for nd in ["bias", "norm", "b_0"]): + decay_params.append(param) + else: + other_params.append(param) + + decay = obtain_storage(decay_params, use_main_grad, attr[2], attr[1]) + other = obtain_storage(other_params, use_main_grad, attr[2], attr[1]) + decay_fused += decay + all_fused += decay + all_fused += other + + return decay_fused, all_fused diff --git a/test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py b/test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py new file mode 100644 index 0000000000000..310313119b4c3 --- /dev/null +++ b/test/collective/fleet/hybrid_parallel_sharding_model_with_fusion.py @@ -0,0 +1,186 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import unittest + +import numpy as np + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, +) + +vocab_size = 20 +hidden_size = 10 +inner_size = 8 +output_size = 10 +seq_length = 2 +batch_size = 4 +STEPS = 10 + + +class SimpleDPNet(paddle.nn.Layer): + def __init__( + self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ): + super().__init__() + self.linear1 = paddle.nn.Linear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc1) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear2 = paddle.nn.Linear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc2) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=0.5), + ) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = paddle.matmul(x, self.embedding.weight, transpose_y=True) + return x + + +class TestDistSharding(unittest.TestCase): + def setUp(self): + random.seed(2021) + np.random.seed(2021) + paddle.seed(2021) + + self.strategy = fleet.DistributedStrategy() + self.strategy.hybrid_configs = { + "sharding_degree": 2, + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + } + self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True + 
fleet.init(is_collective=True, strategy=self.strategy) + self.data = np.random.randint( + 0, + vocab_size, + ( + batch_size, + seq_length, + ), + ) + + if paddle.distributed.get_rank() == 0: + self.batch_sharding = paddle.to_tensor(self.data[:2]) + else: + self.batch_sharding = paddle.to_tensor(self.data[2:]) + + self.batch_single = paddle.to_tensor(self.data) + + def train_batch(self, batch, model, optimizer): + output = model(batch) + loss = output.mean() + loss.backward() + optimizer.step() + optimizer.clear_grad() + return loss + + def build_optimizer(self, model): + clip = paddle.nn.ClipGradByGlobalNorm(0.5) + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.001, + weight_decay=0.001, + grad_clip=clip, + ) + return optimizer + + def build_model_optimizer(self): + np_fc1 = np.random.random_sample((hidden_size, inner_size)) + np_fc2 = np.random.random_sample((inner_size, hidden_size)) + + model_a = SimpleDPNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + optimizer_a = self.build_optimizer(model_a) + + model_b = SimpleDPNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + optimizer_b = self.build_optimizer(model_b) + + model_a = fleet.distributed_model(model_a) + optimizer_a = fleet.distributed_optimizer(optimizer_a) + + return model_a, optimizer_a, model_b, optimizer_b + + def sharding_model(self): + ( + model_a, + optimizer_a, + model_b, + optimizer_b, + ) = self.build_model_optimizer() + + self.assertTrue( + isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer) + ) + + for idx in range(STEPS): + loss_a = self.train_batch(self.batch_sharding, model_a, optimizer_a) + loss_b = self.train_batch(self.batch_single, model_b, optimizer_b) + np.testing.assert_allclose(loss_a, loss_b, rtol=1e-6, atol=1e-6) + + for j in range(len(model_a.parameters())): + np.testing.assert_allclose( + model_a.parameters()[j].numpy(), + model_b.parameters()[j].numpy(), + rtol=1e-6, + 
atol=1e-7, + ) + + def test_sharding_adam(self): + self.sharding_model() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py index 857093ee7b44c..f3dbfb63bce36 100644 --- a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py +++ b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py @@ -22,6 +22,9 @@ class TestHybridParallel(TestMultipleGpus): def test_hybrid_parallel_sharding_logic(self): self.run_mnist_2gpu('hybrid_parallel_sharding_model.py') + def test_hybrid_parallel_sharding_tensor_fusion(self): + self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion.py') + def test_hybrid_parallel_sharding_state_dict(self): self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py') From cc1474ab46ded039ff89bf0a797bdaaced0da2ba Mon Sep 17 00:00:00 2001 From: liuyuang Date: Mon, 17 Jul 2023 09:03:42 +0800 Subject: [PATCH 2/5] bug fix --- python/paddle/distributed/fleet/utils/tensor_fusion_helper.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index 690fe37ad5286..f60b09f8498d1 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging from collections import OrderedDict import numpy as np @@ -109,9 +108,6 @@ def obtain_storage(parameters, use_main_grad, clip, dist): def fused_parameters(parameters, use_main_grad): - logging.log( - '[Tensor Fusion] Fusing tensors into tensor chunks, it may take a while.' 
- ) # filter for mp's distributed params dist = list(filter(lambda x: x.is_distributed, parameters)) no_dist = list(filter(lambda x: not x.is_distributed, parameters)) From 30abde4d0eb742e9c4272794c55e200e91f83083 Mon Sep 17 00:00:00 2001 From: liuyuang Date: Tue, 18 Jul 2023 14:36:55 +0800 Subject: [PATCH 3/5] add assertion --- .../dygraph_optimizer/dygraph_sharding_optimizer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 432005048cbdd..02c8ca6092aa8 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -68,6 +68,11 @@ def __init__(self, optimizer, hcg): self.tensor_fusion = strategy.hybrid_configs[ 'sharding_configs' ].tensor_fusion + pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap + if self.tensor_fusion: + assert ( + not pp_overlap + ), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time." 
self._rank2params = self._partition_parameters() self._param2rank = self._map_param_to_rank() From 89a47af0dfc335a210f45d0a57c375f6e4fb4560 Mon Sep 17 00:00:00 2001 From: liuyuang Date: Tue, 18 Jul 2023 17:22:42 +0800 Subject: [PATCH 4/5] update filter logic --- .../fleet/utils/tensor_fusion_helper.py | 127 +++++++++--------- 1 file changed, 61 insertions(+), 66 deletions(-) diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index f60b09f8498d1..4ff999f33260f 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools from collections import OrderedDict import numpy as np @@ -30,6 +31,7 @@ def assign_group_by_size(parameters, group_size=256 * 1024 * 1024): + # TODO(Yuang Liu): make pp_utils/utils use this tensor fusion helper is_sparse_gradient = [False] * len(parameters) group_indices = core.eager_assign_group_by_size( @@ -107,74 +109,61 @@ def obtain_storage(parameters, use_main_grad, clip, dist): return storage -def fused_parameters(parameters, use_main_grad): - # filter for mp's distributed params - dist = list(filter(lambda x: x.is_distributed, parameters)) - no_dist = list(filter(lambda x: not x.is_distributed, parameters)) - - # filter for different dtype - fp32_dist = list(filter(lambda x: x.dtype == paddle.float32, dist)) - fp32_no_dist = list(filter(lambda x: x.dtype == paddle.float32, no_dist)) - no_fp32_dist = list(filter(lambda x: x.dtype != paddle.float32, dist)) - no_fp32_no_dist = list(filter(lambda x: x.dtype != paddle.float32, no_dist)) - - # no fp32 param's dtype should be the same - all_no_fp32_param = no_fp32_dist + no_fp32_no_dist - no_fp32_dtype = None - for p in 
all_no_fp32_param: - if no_fp32_dtype is None: - no_fp32_dtype = p.dtype - else: - assert ( - p.dtype == no_fp32_dtype - ), "Tensor fusion only support two different param dtypes." - - # filter for need clip - fp32_dist_clip = list( - filter(lambda x: getattr(x, 'need_clip', True), fp32_dist) - ) - fp32_no_dist_clip = list( - filter(lambda x: getattr(x, 'need_clip', True), fp32_no_dist) - ) - no_fp32_dist_clip = list( - filter(lambda x: getattr(x, 'need_clip', True), no_fp32_dist) - ) - no_fp32_no_dist_clip = list( - filter(lambda x: getattr(x, 'need_clip', True), no_fp32_no_dist) - ) - fp32_dist_no_clip = list( - filter(lambda x: not getattr(x, 'need_clip', True), fp32_dist) - ) - fp32_no_dist_no_clip = list( - filter(lambda x: not getattr(x, 'need_clip', True), fp32_no_dist) +def filter_params(params, is_fp32, is_distributed, need_clip): + params = list( + filter( + lambda x: x.is_distributed + if is_distributed + else (not x.is_distributed), + params, + ) ) - no_fp32_dist_no_clip = list( - filter(lambda x: not getattr(x, 'need_clip', True), no_fp32_dist) + params = list( + filter( + lambda x: getattr(x, 'need_clip', True) + if need_clip + else (not getattr(x, 'need_clip', True)), + params, + ) ) - no_fp32_no_dist_no_clip = list( - filter(lambda x: not getattr(x, 'need_clip', True), no_fp32_no_dist) + params = list( + filter( + lambda x: x.dtype == paddle.float32 + if is_fp32 + else x.dtype != paddle.float32, + params, + ) ) + dtype = None + for p in params: + if dtype is None: + dtype = p.dtype + else: + assert dtype == p.dtype + + return params, dtype + + +def fused_parameters(parameters, use_main_grad): + param_groups = [] + attrs = [] - param_groups = [ - fp32_dist_clip, - fp32_no_dist_clip, - no_fp32_dist_clip, - no_fp32_no_dist_clip, - fp32_dist_no_clip, - fp32_no_dist_no_clip, - no_fp32_dist_no_clip, - no_fp32_no_dist_no_clip, - ] - attrs = [ - [paddle.float32, True, True], - [paddle.float32, False, True], - [no_fp32_dtype, True, True], - [no_fp32_dtype, 
False, True], - [paddle.float32, True, False], - [paddle.float32, False, False], - [no_fp32_dtype, True, False], - [no_fp32_dtype, False, False], - ] + is_fp32 = [True, False] + is_distributed = [True, False] + need_clip = [True, False] + + no_fp32_dtype = None + for fp32, dist, clip in itertools.product( + is_fp32, is_distributed, need_clip + ): + params, dtype = filter_params(parameters, fp32, dist, clip) + if not fp32: + if no_fp32_dtype is None: + no_fp32_dtype = dtype + else: + assert no_fp32_dtype == dtype + attrs.append([dtype, dist, clip]) + param_groups.append(params) decay_fused = [] all_fused = [] @@ -188,8 +177,14 @@ def fused_parameters(parameters, use_main_grad): else: other_params.append(param) - decay = obtain_storage(decay_params, use_main_grad, attr[2], attr[1]) - other = obtain_storage(other_params, use_main_grad, attr[2], attr[1]) + is_distributed = attr[1] + need_clip = attr[2] + decay = obtain_storage( + decay_params, use_main_grad, need_clip, is_distributed + ) + other = obtain_storage( + other_params, use_main_grad, need_clip, is_distributed + ) decay_fused += decay all_fused += decay all_fused += other From f7432462d8fd75066673474f03bb27caccc41be9 Mon Sep 17 00:00:00 2001 From: liuyuang Date: Wed, 19 Jul 2023 07:36:38 +0800 Subject: [PATCH 5/5] bug fixer --- python/paddle/distributed/fleet/utils/tensor_fusion_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py index 4ff999f33260f..e097b4686fbc8 100644 --- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py +++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py @@ -160,7 +160,7 @@ def fused_parameters(parameters, use_main_grad): if not fp32: if no_fp32_dtype is None: no_fp32_dtype = dtype - else: + elif dtype is not None: assert no_fp32_dtype == dtype attrs.append([dtype, dist, clip]) param_groups.append(params)