Sharding stage 1 tensor fusion #55427

Merged · 5 commits · Jul 19, 2023
5 changes: 5 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -66,13 +66,18 @@ message PpConfig {
optional bool profiling = 5 [ default = false ];
}

message DygraphShardingConfig {
optional bool tensor_fusion = 1 [ default = false ];
}

message HybridConfig {
optional int32 dp_degree = 1 [ default = -1 ];
optional int32 mp_degree = 2 [ default = 1 ];
optional int32 pp_degree = 3 [ default = 1 ];
optional int32 sharding_degree = 4 [ default = 1 ];
optional MpConfig mp_configs = 5;
optional PpConfig pp_configs = 6;
optional DygraphShardingConfig sharding_configs = 7;
}

message AMPConfig {
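The new sharding_configs.tensor_fusion flag is consumed on the Python side through fleet.DistributedStrategy().hybrid_configs (see the optimizer change below). A minimal usage sketch, not part of this PR's diff: it assumes the hybrid_configs setter accepts the nested dict the same way it already does for pp_configs, and the parallel degrees shown are purely illustrative.

import paddle.distributed.fleet as fleet

# Illustrative degrees; only the "sharding_configs" entry is the knob added in this PR.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 4,
    "sharding_configs": {"tensor_fusion": True},  # enable sharding stage-1 tensor fusion
}
fleet.init(is_collective=True, strategy=strategy)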
@@ -18,8 +18,10 @@

import paddle
from paddle import framework
from paddle.distributed import fleet

from ...utils.log_util import logger
from ...utils.tensor_fusion_helper import fused_parameters


def _is_trainable(param):
@@ -62,15 +64,53 @@ def __init__(self, optimizer, hcg):
self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
self._sharding_rank = self._hcg.get_sharding_parallel_rank()

strategy = fleet.fleet._user_defined_strategy
self.tensor_fusion = strategy.hybrid_configs[
'sharding_configs'
].tensor_fusion
pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
if self.tensor_fusion:
assert (
not pp_overlap
), "Can not enable pp's sharding_comm_overlap and sharding's tensor_fusion at the same time."

self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()

- self._set_inner_opt_attr(
-     '_parameter_list', self._rank2params[self._sharding_rank]
- )
- self._set_inner_opt_attr(
-     '_param_groups', self._rank2params[self._sharding_rank]
- )
if not self.tensor_fusion:
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
else:
self._use_main_grad = hasattr(self._parameter_list[0], "main_grad")
self._rank2decay = {}
self._rank2fused = {}
self._tensor_fusion()

decay_params = [
p.name for p in self._rank2decay[self._sharding_rank]
]
all_params = self._rank2fused[self._sharding_rank]
apply_decay_param_fun = lambda x: x in decay_params

params = []
for v in self._rank2fused.values():
params += v
self._parameter_list = params
self._param_groups = params

self._set_inner_opt_attr('_parameter_list', all_params)
self._set_inner_opt_attr('_param_groups', all_params)
origin_decay_param_fun = getattr(
self._inner_opt, '_apply_decay_param_fun', None
)
if origin_decay_param_fun is not None:
self._set_inner_opt_attr(
'_apply_decay_param_fun', apply_decay_param_fun
)

def clear_grad(self, set_to_zero=True):
"""
@@ -85,7 +125,25 @@ def clear_grad(self, set_to_zero=True):
p.main_grad._clear()
p.main_grad = None
elif not hasattr(p, "main_grad"):
- p.clear_gradient(set_to_zero)
if self.tensor_fusion:
if set_to_zero:
p.grad.zero_()
else:
p.grad._clear()
p.grad = None
else:
p.clear_gradient(set_to_zero)

def _tensor_fusion(self):
for i in range(self._sharding_world_size):
params = self._rank2params[i]
decay_fused, all_fused = fused_parameters(
params, self._use_main_grad
)
self._rank2decay[i] = decay_fused
self._rank2fused[i] = all_fused
for p in all_fused:
self._param2rank[p.name] = i

def _partition_parameters(self):
"""
@@ -167,7 +225,12 @@ def _sharding_sync_parameters(self):
logger.debug("sharding start sync parameters")
with framework.no_grad():
# TODO detach not need (?)
- for rank, params in self._rank2params.items():
valid_rank_to_params = (
self._rank2params
if not self.tensor_fusion
else self._rank2fused
)
for rank, params in valid_rank_to_params.items():
for param in params:
paddle.distributed.broadcast(
param,
@@ -236,9 +299,12 @@ def step(self):
params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None)
- update_param_names = [
-     p.name for p in self._rank2params[self._sharding_rank]
- ]
rank_params = (
self._rank2params[self._sharding_rank]
if not self.tensor_fusion
else self._rank2fused[self._sharding_rank]
)
update_param_names = [p.name for p in rank_params]
update_params_grads = [
(p, g) for p, g in params_grads if p.name in update_param_names
]
@@ -30,6 +30,14 @@
from .group_sharded_utils import Type, cvt_to_device, device_guard


class BufferWarper(core.eager.Tensor):
def __init__(self):
super().__init__()
self.need_clip = True
self.is_distributed = False
self.trainable = True


class InternalStorage:
"""
This is a basic class, which is responsible for consolidating the basic storage tensor.
@@ -97,6 +105,12 @@ def to(self, device, dtype=None, keep_alignment=True):
self.buffer = self.buffer.cast(dtype=dtype)
self._dtype = dtype

def warp_buffer(self):
tmp_buffer = BufferWarper()
self._buffer = self.buffer
tmp_buffer.get_tensor()._share_data_with(self.buffer.get_tensor())
self.buffer = tmp_buffer


class ParamStorage(InternalStorage):
"""
192 changes: 192 additions & 0 deletions python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -0,0 +1,192 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from collections import OrderedDict

import numpy as np

import paddle
from paddle.framework import core

alignment = {
"gpu": 256,
}

align = {
paddle.float16.value: 2,
paddle.bfloat16.value: 2,
paddle.float32.value: 4,
}


def assign_group_by_size(parameters, group_size=256 * 1024 * 1024):
# TODO(Yuang Liu): make pp_utils/utils use this tensor fusion helper
is_sparse_gradient = [False] * len(parameters)

group_indices = core.eager_assign_group_by_size(
parameters, is_sparse_gradient, [group_size, group_size]
)

var_groups = OrderedDict()
for group_idx, indices in enumerate(group_indices):
for index in indices:
var_groups.setdefault(group_idx, []).append(parameters[index])
return var_groups


def flatten_dense_tensors(parameters, use_main_grad):
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import (
GradStorage,
ParamStorage,
)

_buffer_size = 0
_param2align = {}
dtype = parameters[0].dtype

for param in parameters:
assert param.trainable, "param must be trainable..."
size = np.prod(param.shape) * align[dtype]
remaining = size % alignment["gpu"]
ali = 0 if remaining == 0 else alignment["gpu"] - remaining
align_ = ali // align[dtype]
_buffer_size += np.prod(param.shape) + align_
_param2align[param.name] = align_

param_storage = ParamStorage(size=_buffer_size, dtype=dtype, device="gpu")

param_storage.add_rank_params(parameters, _param2align)

# process gradient
grad_dtype = paddle.float32 if use_main_grad else dtype
grad_storage = GradStorage(
size=_buffer_size,
dtype=grad_dtype,
device="gpu",
destination="0",
parm2align=_param2align,
)

for param in parameters:
grad_storage.add_grad(param, _param2align[param.name])

param_storage.warp_buffer()
grad_storage.warp_buffer()

if not use_main_grad:
# param_storage --> grad_storage
param_storage.buffer._copy_gradient_from(grad_storage.buffer)
else:
param_storage.buffer.main_grad = grad_storage.buffer
param_storage.buffer.stop_gradient = False
return param_storage, grad_storage
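
As a concrete illustration of the padding arithmetic in flatten_dense_tensors above (not part of the PR), consider a single float32 parameter of shape [3, 5] with the 256-byte GPU alignment from the alignment table:

import numpy as np

# 15 float32 elements occupy 60 bytes; padding to the next 256-byte
# boundary adds 196 bytes, i.e. 49 extra float32 slots.
shape, elem_bytes, gpu_alignment = [3, 5], 4, 256
size_bytes = np.prod(shape) * elem_bytes                         # 60
remaining = size_bytes % gpu_alignment                           # 60
pad_bytes = 0 if remaining == 0 else gpu_alignment - remaining   # 196
pad_elems = pad_bytes // elem_bytes                              # 49
buffer_elems = np.prod(shape) + pad_elems                        # 64 elements == 256 bytes
print(buffer_elems)  # 64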


def obtain_storage(parameters, use_main_grad, clip, dist):
if len(parameters) < 1:
return []

var_groups = assign_group_by_size(parameters)
storage = []
for group_idx, parameters in var_groups.items():
param_storage, grad_storage = flatten_dense_tensors(
parameters, use_main_grad
)
param_storage.buffer.need_clip = clip
param_storage.buffer.is_distributed = dist
storage.append(param_storage.buffer)
return storage


def filter_params(params, is_fp32, is_distributed, need_clip):
params = list(
filter(
lambda x: x.is_distributed
if is_distributed
else (not x.is_distributed),
params,
)
)
params = list(
filter(
lambda x: getattr(x, 'need_clip', True)
if need_clip
else (not getattr(x, 'need_clip', True)),
params,
)
)
params = list(
filter(
lambda x: x.dtype == paddle.float32
if is_fp32
else x.dtype != paddle.float32,
params,
)
)
dtype = None
for p in params:
if dtype is None:
dtype = p.dtype
else:
assert dtype == p.dtype

return params, dtype


def fused_parameters(parameters, use_main_grad):
param_groups = []
attrs = []

is_fp32 = [True, False]
is_distributed = [True, False]
need_clip = [True, False]

no_fp32_dtype = None
for fp32, dist, clip in itertools.product(
is_fp32, is_distributed, need_clip
):
params, dtype = filter_params(parameters, fp32, dist, clip)
if not fp32:
if no_fp32_dtype is None:
no_fp32_dtype = dtype
elif dtype is not None:
assert no_fp32_dtype == dtype
attrs.append([dtype, dist, clip])
param_groups.append(params)

decay_fused = []
all_fused = []
for params, attr in zip(param_groups, attrs):
decay_params = []
other_params = []

for param in params:
if not any(nd in param.name for nd in ["bias", "norm", "b_0"]):
decay_params.append(param)
else:
other_params.append(param)

is_distributed = attr[1]
need_clip = attr[2]
decay = obtain_storage(
decay_params, use_main_grad, need_clip, is_distributed
)
other = obtain_storage(
other_params, use_main_grad, need_clip, is_distributed
)
decay_fused += decay
all_fused += decay
all_fused += other

return decay_fused, all_fused
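
Tying the helper back to the optimizer change above, the per-rank loop in _tensor_fusion amounts to roughly the following sketch. The toy model and the two-rank split are illustrative stand-ins for what _partition_parameters produces, and a GPU build is assumed since the storages are created on "gpu".

import paddle
from paddle.distributed.fleet.utils.tensor_fusion_helper import fused_parameters

# Toy two-rank partition standing in for _partition_parameters().
model = paddle.nn.Sequential(paddle.nn.Linear(8, 8), paddle.nn.Linear(8, 8))
params = list(model.parameters())
rank2params = {0: params[:2], 1: params[2:]}

rank2decay, rank2fused, param2rank = {}, {}, {}
for rank, rank_params in rank2params.items():
    decay_fused, all_fused = fused_parameters(rank_params, use_main_grad=False)
    rank2decay[rank] = decay_fused   # fused buffers that should receive weight decay
    rank2fused[rank] = all_fused     # every fused buffer owned by this rank
    for buf in all_fused:
        param2rank[buf.name] = rank  # later used to route per-rank broadcasts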