[AutoParallel] Add paddle.distributed.shard layer api (#57604)
* def dtensor_from_fn first edition

* dtensor_from_fn first edition

* shard_layer api and utest(temporarily unavailable)

* shard_layer API and unit test preliminary complete

* complete the sample code modification according to ZhongKai's suggestion

* modify according to the review

* modify according to LiangGe's review

* Not approved yet, temporarily stored

* waiting for tensor to param

* Complete the modifications according to Weihang's review

* polish shard_layer api impl and doc

* add shard layer test

* rewrite unittest

* revert needless change

* polish doc

* add unittest for coverage

* add static branch and test

* polish en doc

* polish test details

* verify doc test demo

* Update python/paddle/distributed/auto_parallel/api.py

Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

---------

Co-authored-by: yangxiaoyu14 <yangxiaoyu14@baidu.com>
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
3 people authored Sep 26, 2023
1 parent 6976947 commit 6c0f338
Showing 4 changed files with 341 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -68,6 +68,7 @@
from .auto_parallel.api import shard_tensor # noqa: F401
from .auto_parallel.api import dtensor_from_fn # noqa: F401
from .auto_parallel.api import reshard # noqa: F401
from .auto_parallel.api import shard_layer # noqa: F401

from .fleet import BoxPSDataset # noqa: F401

@@ -130,4 +131,5 @@
"shard_tensor",
"dtensor_from_fn",
"reshard",
"shard_layer",
]
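
With the new export in place, `shard_layer` becomes reachable directly from `paddle.distributed`. As a rough sketch of the default behavior (no `shard_fn`, so every parameter and buffer is replicated across the mesh), usage might look like the following; the two-card mesh and the Linear layer are illustrative assumptions, and the script is assumed to be started with `paddle.distributed.launch` on two devices.

import paddle
import paddle.distributed as dist

# Hypothetical two-card setup, e.g. launched via:
#   export CUDA_VISIBLE_DEVICES=0,1
#   python -m paddle.distributed.launch demo.py
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(8, 8)

# Without a shard_fn, shard_layer replicates all of the layer's
# parameters and buffers across the given mesh.
replicated_layer = dist.shard_layer(layer, mesh)

for param in replicated_layer.parameters():
    print(param.is_dist())  # True: parameters are now backed by DistTensor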
164 changes: 164 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.base.framework import EagerParamBase
from paddle.distributed.auto_parallel.interface import (
shard_tensor as shard_tensor_static,
@@ -24,6 +27,8 @@
# Some APIs have the same name with the previous APIs implementation, which are
# a temporary state, and the APIs here will eventually be used.

# Part1: Shard attributes related APIs


class DistAttr(core.TensorDistAttr):
"""
@@ -83,6 +88,9 @@ def sharding_specs(self):
return self._sharding_specs


# Part2: DistTensor construction related APIs


def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
):
@@ -184,6 +192,9 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
return shard_tensor(tensor, dist_attr=dist_attr)


# Part3: Data conversion related APIs


def reshard(dist_tensor, dist_attr):
"""
Reshard a distributed ``paddle.Tensor`` with given distributed attributes.
@@ -229,3 +240,156 @@ def reshard(dist_tensor, dist_attr):
raise RuntimeError(
"paddle.dist.reshard only support dynamic graph now. It will be supported for static graph later."
)


def shard_layer(
layer: nn.Layer,
process_mesh: dist.ProcessMesh,
shard_fn: Callable = None,
input_fn: Callable = None,
output_fn: Callable = None,
) -> nn.Layer:
"""
    Converts all of the layer's parameters to DistTensor parameters according to
    the specified `shard_fn`. It can also control the conversion of the layer's
    input or output by specifying `input_fn` and `output_fn`
    (i.e. convert the input to `paddle.Tensor` with DistTensor, and convert the
    output back to `paddle.Tensor` with DenseTensor).
The `shard_fn` should have the following signature:
def shard_fn(layer_name, layer, process_mesh) -> None
The `input_fn` should have the following signature:
def input_fn(inputs, process_mesh) -> list(paddle.Tensor)
    In general, the return value of `input_fn` is a `paddle.Tensor` with DistTensor.
The `output_fn` should have the following signature:
def output_fn(outputs, process_mesh) -> list(paddle.Tensor)
    In general, the return value of `output_fn` is a `paddle.Tensor` with DenseTensor.
Args:
        layer (paddle.nn.Layer): The Layer object to be sharded.
        process_mesh (paddle.distributed.ProcessMesh): The `ProcessMesh` information
            on which to place the input `layer`.
shard_fn (Callable): The function to shard layer parameters across
the `process_mesh`. If not specified, by default we replicate
all parameters of the layer across the `process_mesh`.
input_fn (Callable): Specify how the input of the layer is sharded.
The `input_fn` will be registered for the Layer as a `forward pre-hook`.
By default we do not shard the input.
        output_fn (Callable): Specify how the output of the layer is sharded or
            converted back to a `paddle.Tensor` with DenseTensor.
The `output_fn` will be registered for the Layer as `forward post-hook`.
By default we do not shard or convert the output.
Returns:
        Layer: A layer that contains parameters/buffers
            that are all `paddle.Tensor` with DistTensor.
Examples:
.. code-block:: python
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> def shard_fn(layer_name, layer, process_mesh):
... dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
... if layer_name == 'fc1':
... layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)
>>> layer = MLP()
>>> layer = dist.shard_layer(layer, mesh, shard_fn)
>>> print(layer)
            >>> # This case needs to be executed in a multi-card environment
>>> # export CUDA_VISIBLE_DEVICES=0,1
>>> # python -m paddle.distributed.launch {test_case}.py
"""
# Ensure that process_mesh is not an empty object
if process_mesh is None:
raise ValueError("The argument `process_mesh` cannot be empty.")

# Check the legality of process_mesh
if not isinstance(process_mesh, dist.ProcessMesh):
raise ValueError(
"The argument `process_mesh` is not `dist.ProcessMesh` type."
)

def replicate_layer_params_and_buffers(
layer: nn.Layer, mesh: dist.ProcessMesh
) -> None:
for key, param in layer._parameters.items():
if param is not None and not param.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(param.shape))],
)
layer.add_parameter(
key,
shard_tensor(param, dist_attr=replicated_dist_attr),
)
else:
                # do nothing, the dist parameters have already been sharded by shard_fn
pass
for key, buffer in layer._buffers.items():
if buffer is not None and not buffer.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(buffer.shape))],
)
layer.register_buffer(
key,
shard_tensor(buffer, dist_attr=replicated_dist_attr),
)
else:
                # do nothing, the dist buffers have already been sharded by shard_fn
pass

if paddle.in_dynamic_mode():
if shard_fn is None:
            # if shard_fn is not specified, by default replicate
            # all of the layer's parameters and buffers
for name, sublayers in layer.named_sublayers(include_self=True):
replicate_layer_params_and_buffers(sublayers, process_mesh)
else:
            # apply shard_fn to sublayers, including self
for name, sublayers in layer.named_sublayers(include_self=True):
shard_fn(name, sublayers, process_mesh)
                # shard_fn may not handle all parameters and buffers;
                # the parameters and buffers that are not sharded by shard_fn
                # still need to be converted to replicated DistTensors
replicate_layer_params_and_buffers(sublayers, process_mesh)

# register input_fn as layer's forward pre hook
if input_fn is not None:
layer.register_forward_pre_hook(
lambda _, inputs: input_fn(inputs, process_mesh)
)
# register output_fn as layer's forward post hook
if output_fn is not None:
layer.register_forward_post_hook(
lambda _, inputs, outputs: output_fn(outputs, process_mesh)
)

return layer
else:
# TODO(chenweihang): Support static mode branch later.
raise NotImplementedError(
"`paddle.distributed.shard_layer` only supports dynamic graph mode "
"now. It will be supported for static graph mode later."
)
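
Since the docstring example above only demonstrates `shard_fn`, a minimal sketch of the `input_fn`/`output_fn` hooks may help; it follows the pattern of the unit test added below, with the single Linear layer, tensor shapes, and two-card launch setup being illustrative assumptions.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
layer = paddle.nn.Linear(8, 8)

def input_fn(inputs, process_mesh):
    # runs as a forward pre-hook: convert the dense input
    # into a replicated DistTensor on the mesh
    return dist.shard_tensor(
        inputs[0], dist_attr=dist.DistAttr(process_mesh, [None, None])
    )

def output_fn(outputs, process_mesh):
    # runs as a forward post-hook: convert the DistTensor output back
    # to a dense paddle.Tensor (the unit test uses this round-trip
    # until dist.unshard_dtensor becomes available)
    return paddle.to_tensor(outputs.numpy())

sharded = dist.shard_layer(layer, mesh, input_fn=input_fn, output_fn=output_fn)
out = sharded(paddle.randn([4, 8]))
print(out.is_dense())  # True: output_fn converted the result back to dense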
1 change: 1 addition & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -203,6 +203,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_dist_tensor MODULES test_dist_tensor)
py_test_modules(test_api_dist_branch MODULES test_api_dist_branch)
py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api)
py_test_modules(test_shard_layer_api MODULES test_shard_layer_api)
py_test_modules(test_cost_interface MODULES test_cost_interface)
# End of unittests WITH single card WITHOUT timeout

174 changes: 174 additions & 0 deletions test/auto_parallel/test_shard_layer_api.py
@@ -0,0 +1,174 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.distributed as dist
from paddle import nn


# TODO(chenweihang): test for paddle nn Layer API
class DemoLayer(nn.Layer):
def __init__(self, num_features):
super().__init__()
self.w0 = self.create_parameter(shape=[num_features, num_features])
self.w1 = self.create_parameter(shape=[num_features, num_features])

def forward(self, x):
y = paddle.matmul(x, self.w0)
z = paddle.matmul(y, self.w1)
return z


class MyLayer(nn.Layer):
def __init__(self, num_features, num_layers):
super().__init__()
self.seq = nn.Sequential(
*[DemoLayer(num_features) for _ in range(num_layers)]
)

def forward(self, x):
return self.seq(x)


class TestShardLayer(unittest.TestCase):
def setUp(self):
self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
self.num_features = 10
self.num_layers = 10

def test_shard_layer_base(self):
layer = MyLayer(self.num_features, self.num_layers)

def shard_fn(layer_name, layer, process_mesh):
if isinstance(layer, nn.Linear):
for name, param in layer.named_parameters():
if 'weight' in name:
dist_param = dist.shard_tensor(
param,
dist_attr=dist.DistAttr(
mesh=process_mesh, sharding_specs=[None, None]
),
)
else:
dist_param = dist.shard_tensor(
param,
dist_attr=dist.DistAttr(
mesh=process_mesh, sharding_specs=[None]
),
)
layer.add_parameter(name, dist_param)

# test shard parameters
sharded_params_layer = dist.shard_layer(layer, self.mesh, shard_fn)

for param in sharded_params_layer.parameters():
self.assertTrue(param.is_dist())
for x in param.dist_attr.dims_mapping:
self.assertEqual(x, -1)

# test shard buffers
test_buffer = paddle.randn([10])
layer.register_buffer("test_buffer", test_buffer, persistable=True)
sharded_buffers_layer = dist.shard_layer(layer, self.mesh, shard_fn)
self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
self.assertEqual(
sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
)

def test_shard_layer_input_fn_and_output_fn(self):
layer = MyLayer(self.num_features, self.num_layers)

def input_fn(inputs, process_mesh):
return dist.shard_tensor(
inputs[0], dist_attr=dist.DistAttr(process_mesh, [None, None])
)

def output_fn(outputs, process_mesh):
assert outputs.is_dist()
# TODO(chenweihang): replace by dist.unshard_dtensor later
return paddle.to_tensor(outputs.numpy())

# test shard parameters
replicate_params_layer = dist.shard_layer(
layer, self.mesh, input_fn=input_fn, output_fn=output_fn
)

x = paddle.randn([5, self.num_features])
dense_out = replicate_params_layer(x)
self.assertTrue(dense_out.is_dense())

for param in replicate_params_layer.parameters():
self.assertTrue(param.is_dist())
for x in param.dist_attr.dims_mapping:
self.assertEqual(x, -1)

# test shard buffers
test_buffer = paddle.randn([10])
layer.register_buffer("test_buffer", test_buffer, persistable=True)
sharded_buffers_layer = dist.shard_layer(
layer, self.mesh, input_fn=input_fn, output_fn=output_fn
)
self.assertTrue(sharded_buffers_layer.test_buffer.is_dist())
self.assertEqual(
sharded_buffers_layer.test_buffer.dist_attr.dims_mapping, [-1]
)

def test_process_mesh_argument_error(self):
layer = MyLayer(self.num_features, self.num_layers)

exception = None
try:
dist.shard_layer(layer, None)
except ValueError as ex:
self.assertIn(
"The argument `process_mesh` cannot be empty",
str(ex),
)
exception = ex
self.assertIsNotNone(exception)

exception = None
try:
dist_attr = dist.DistAttr(
mesh=self.mesh, sharding_specs=[None, None]
)
dist.shard_layer(layer, dist_attr)
except ValueError as ex:
self.assertIn(
"The argument `process_mesh` is not `dist.ProcessMesh` type",
str(ex),
)
exception = ex
self.assertIsNotNone(exception)

def test_shard_layer_static_mode(self):
paddle.enable_static()
layer = MyLayer(self.num_features, self.num_layers)

exception = None
try:
dist.shard_layer(layer, self.mesh)
except NotImplementedError as ex:
self.assertIn(
"`paddle.distributed.shard_layer` only supports dynamic graph mode now",
str(ex),
)
exception = ex
self.assertIsNotNone(exception)


if __name__ == '__main__':
unittest.main()
