[auto parallel] Shard optimizer API #59342

Merged
3 commits merged on Nov 30, 2023
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -84,6 +84,7 @@
dtensor_from_fn,
reshard,
shard_layer,
shard_optimizer,
)

from .fleet import BoxPSDataset # noqa: F401
@@ -157,4 +158,5 @@
"Shard",
"Replicate",
"Partial",
"shard_optimizer",
]
140 changes: 139 additions & 1 deletion python/paddle/distributed/auto_parallel/api.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import Callable

import paddle
@@ -406,3 +406,141 @@ def replicate_layer_params_and_buffers(
"`paddle.distributed.shard_layer` only supports dynamic graph mode "
"now. It will be supported for static graph mode later."
)


class _ShardOptimizer:
def __init__(self, optimizer, shard_fn=None):
assert (
optimizer is not None
), "The argument `optimizer` cannot be empty."
assert isinstance(
optimizer, paddle.optimizer.AdamW
), "`paddle.distributed.ShardOptimizer` only supports AdamW optimizer for now."

self.target_block = (
paddle.base.framework.default_main_program().global_block()
)
optimizer.helper = paddle.base.layer_helper.LayerHelper(
optimizer.__class__.__name__
)
self._inner_opt = optimizer
self._shard_fn = shard_fn

def _shard_accumulator(self, param):
# create the accumulators
self._inner_opt._create_accumulators(self.target_block, [param])

target_name = param.name
if param.name in self._inner_opt._master_weights.keys():
target_name = self._inner_opt._master_weights[param.name].name

# shard the accumulators
for key in self._inner_opt._accumulators.keys():
accumulator = self._inner_opt._accumulators[key][target_name]
if accumulator.is_dist():
continue
if self._shard_fn is not None:
self._inner_opt._accumulators[key][
target_name
] = self._shard_fn(key, param, accumulator)
else:
if param.is_dist():
if 'beta' not in key:
# If param is a dist tensor should keep the shard info
# for accumulators except beta.
placements = param.placements
else:
# The beta should be replicated cross param's mesh
placements = [
dist.Replicate()
for _ in range(len(param.process_mesh.shape))
]
self._inner_opt._accumulators[key][
target_name
] = shard_tensor(
accumulator,
mesh=param.process_mesh,
placements=placements,
)

def step(self):
if not isinstance(self._inner_opt._parameter_list[0], dict):
params_grads = []
for param in self._inner_opt._parameter_list:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))
for p, g in params_grads:
self._shard_accumulator(p)
self._inner_opt._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
else:
for param_group in self._inner_opt._param_groups:
params_grads = defaultdict(lambda: [])
for param in param_group['params']:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads['params'].append((param, grad_var))
params_grads.update(
{k: v for k, v in param_group.items() if k != 'params'}
)
for p, g in params_grads['params']:
self._shard_accumulator(p)
self._inner_opt._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)

def __getattr__(self, item):
return getattr(self._inner_opt, item)


def shard_optimizer(optimizer, shard_fn=None):
"""

Wrap the global-view optimizer into a distributed-view optimizer.

Note:
The `shard_fn` should have the following signature:
def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator

Args:
optimizer (paddle.optimizer.Optimizer): The optimizer to be sharded.
shard_fn (Callable, optional): The function used to shard the optimizer's accumulators. If not specified,
the accumulators simply inherit the distributed attributes of their corresponding parameters.

Returns:
An optimizer with a distributed view.

Examples:
.. code-block:: python

>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt)
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This case needs to be executed in a multi-card environment, e.g.
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py

"""
return _ShardOptimizer(optimizer, shard_fn)
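
For context, a minimal sketch of how the `shard_fn` hook added above can be used, modeled on the `shard_layer_fn` and `shard_opt_fn` helpers in the new test file below. The two-card mesh, the `Linear(10, 10)` layer, and the helper names `shard_params_fn` and `shard_accumulator_fn` are illustrative assumptions, not part of this PR: the moment accumulators keep the parameter's placements, while the scalar beta accumulators are replicated across the parameter's mesh.

import paddle
import paddle.distributed as dist

# Illustrative two-card mesh; run with:
#   python -m paddle.distributed.launch --gpus=0,1 <script>.py
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])


def shard_params_fn(layer_name, layer, process_mesh):
    # Shard the weight along its second axis and the bias along its first.
    layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(1)])
    layer.bias = dist.shard_tensor(layer.bias, process_mesh, [dist.Shard(0)])


def shard_accumulator_fn(accumulator_name, param, accumulator):
    # Moment accumulators keep the parameter's placements; the scalar
    # beta accumulators are replicated across the parameter's mesh.
    if param.is_dist():
        if 'beta' not in accumulator_name:
            placements = param.placements
        else:
            placements = [
                dist.Replicate()
                for _ in range(len(param.process_mesh.shape))
            ]
        return dist.shard_tensor(accumulator, param.process_mesh, placements)
    return accumulator


layer = paddle.nn.Linear(10, 10)
dist.shard_layer(layer, mesh, shard_params_fn)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())
opt = dist.shard_optimizer(opt, shard_accumulator_fn)

loss = layer(paddle.rand(shape=[10, 10]))
loss.backward()
opt.step()
opt.clear_grad()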
12 changes: 5 additions & 7 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer.py
@@ -79,14 +79,12 @@ def test_adamw_mp(self):
opt.clear_grad()
for key in opt._accumulators.keys():
for k, v in opt._accumulators[key].items():
if 'momentum' in key:
if 'moment' in key:
assert opt._accumulators[key][k].is_dist()
if 'w' in k:
assert opt._accumulators[key][k].shape == [10, 10]
assert opt._accumulators[key][k]._local_shape == [10, 5]
else:
assert opt._accumulators[key][k].shape == [10]
assert opt._accumulators[key][k]._local_shape == [5]
assert (
opt._accumulators[key][k].shape[-1]
== opt._accumulators[key][k]._local_shape[-1] * 2
)
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())

176 changes: 176 additions & 0 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py
@@ -0,0 +1,176 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist


class TestSemiAutoParallelShardOptimizerAPI:
def __init__(self):
self._backend = os.getenv("backend")
self._seed = eval(os.getenv("seed"))
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

def check_tensor_eq(self, a, b, rtol=1e-05, atol=0, verbose=True):
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol, verbose=verbose)

def get_single_card_rst(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.weight = linear.weight.numpy()
self.bias = linear.bias.numpy()

def shard_layer_fn(self, layer_name, layer, process_mesh):
layer.weight = dist.shard_tensor(
layer.weight, process_mesh, [dist.Shard(1)]
)
layer.bias = dist.shard_tensor(
layer.bias, process_mesh, [dist.Shard(0)]
)

def test_opt(self, opt):
for key in opt._accumulators.keys():
for k, v in opt._accumulators[key].items():
assert opt._accumulators[key][k].is_dist()
if 'moment' in key:
assert (
opt._accumulators[key][k].shape[-1]
== opt._accumulators[key][k]._local_shape[-1] * 2
)
else:
assert opt._accumulators[key][k].shape == [1]
assert opt._accumulators[key][k]._local_shape == [1]

def test_shard_optimizer_mp(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.shard_optimizer(opt)
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.test_opt(opt)
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())

def test_shard_optimizer_from_non_shard_layer(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.shard_optimizer(opt)
for _ in range(5):
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.check_tensor_eq(self.weight, linear.weight.numpy())
self.check_tensor_eq(self.bias, linear.bias.numpy())

def shard_opt_fn(self, accumulator_name, param, accumulator):
if param.is_dist():
if 'beta' not in accumulator_name:
placements = param.placements
else:
placements = [
dist.Replicate()
for _ in range(len(param.process_mesh.shape))
]
return dist.shard_tensor(
accumulator, param.process_mesh, placements
)
return accumulator

def test_shard_optimizer_shard_fn(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.shard_optimizer(opt, self.shard_opt_fn)
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.test_opt(opt)

def test_shard_optimizer_master_params(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10], dtype="float16")
linear = paddle.amp.decorate(linear, level="O2", dtype="float16")
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
opt = paddle.optimizer.AdamW(
parameters=linear.parameters(), multi_precision=True
)
opt = dist.shard_optimizer(opt)
loss = linear(batch)
loss.backward()
opt.step()
self.test_opt(opt)
for k, v in opt._master_weights.items():
assert v.dtype == paddle.float32
assert v.is_dist()
assert v.shape[-1] == v._local_shape[-1] * 2

def test_shard_optimizer_params_group(self):
paddle.seed(self._seed)
linear = paddle.nn.Linear(10, 10)
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
linear.weight.optimize_attr = {'lr': 1}
linear.bias.optimize_attr = {'lr': 1}
params_group = [{'params': linear.weight}, {'params': linear.bias}]
opt = paddle.optimizer.AdamW(parameters=params_group)
opt = dist.shard_optimizer(opt)
loss = linear(batch)
loss.backward()
opt.step()
opt.clear_grad()
self.test_opt(opt)

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
elif self._backend == "gpu":
paddle.set_device("gpu:" + str(dist.get_rank()))
else:
raise ValueError("Only support cpu or gpu backend.")

self.get_single_card_rst()
self.test_shard_optimizer_params_group()
self.test_shard_optimizer_shard_fn()
if self._backend == "gpu":
self.test_shard_optimizer_master_params()
self.test_shard_optimizer_mp()
self.test_shard_optimizer_from_non_shard_layer()


if __name__ == '__main__':
TestSemiAutoParallelShardOptimizerAPI().run_test_case()
10 changes: 10 additions & 0 deletions test/auto_parallel/test_semi_auto_parallel_single_strategy.py
@@ -84,6 +84,16 @@ def test_shard_optimizer(self):
user_defined_envs=envs,
)

def test_shard_optimizer_api(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_shard_optimizer_api.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()