
Commit: change the code format
FeixLiu committed Nov 29, 2023
1 parent a9d1e03 commit a0388dd
Showing 3 changed files with 47 additions and 44 deletions.
4 changes: 2 additions & 2 deletions python/paddle/distributed/__init__.py
@@ -84,7 +84,7 @@
dtensor_from_fn,
reshard,
shard_layer,
ShardOptimizer,
shard_optimizer,
)

from .fleet import BoxPSDataset # noqa: F401
@@ -158,5 +158,5 @@
"Shard",
"Replicate",
"Partial",
"ShardOptimizer",
"shard_optimizer",
]
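
With this change, the public entry point exposed by `paddle.distributed` is the `shard_optimizer` function rather than the `ShardOptimizer` class. A quick sanity check, as a sketch assuming a Paddle build that includes this commit:

import paddle.distributed as dist

# The optimizer wrapper is now exposed as a lowercase factory function.
assert callable(dist.shard_optimizer)
assert not hasattr(dist, "ShardOptimizer")
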
77 changes: 40 additions & 37 deletions python/paddle/distributed/auto_parallel/api.py
@@ -408,43 +408,7 @@ def replicate_layer_params_and_buffers(
)


class ShardOptimizer:
"""
Warp the global view optimizer to distributed view.
The `shard_fn` should have the following signature:
def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator
Args:
optimizer (paddle.optimizer.Optimizer): The optimizer to be sharded.
shard_fn (optional, Callable): The function to shard accumulators. If not specified,
we simply pass down the dist attr of the params.
Method:
step(): same with optimzier.step()
Examples:
.. code-block:: python
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.ShardOptimizer(opt)
>>> for _ in range(5):
>>> loss = layer(batch)
>>> loss.backward()
>>> opt.step()
>>> opt.clear_grad()
>>> # This case need to be executed in multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""

class _ShardOptimizer:
def __init__(self, optimizer, shard_fn=None):
assert (
paddle.in_dynamic_mode()
@@ -536,3 +500,42 @@ def step(self):

def __getattr__(self, item):
return getattr(self._inner_opt, item)


def shard_optimizer(optimizer, shard_fn=None):
"""
Wrap the global-view optimizer into a distributed view.
The `shard_fn` should have the following signature:
def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator
Args:
optimizer (paddle.optimizer.Optimizer): The optimizer to be sharded.
shard_fn (Callable, optional): The function used to shard accumulators. If not specified,
the dist attr of the parameters is simply passed down to the accumulators.
Method:
step(): same as optimizer.step()
Examples:
.. code-block:: python
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
... def __init__(self):
... super().__init__()
... self.fc1 = paddle.nn.Linear(8, 8)
... self.fc2 = paddle.nn.Linear(8, 8)
...
... def forward(self, input):
... return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt)
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This case needs to be executed in a multi-card environment
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
"""
return _ShardOptimizer(optimizer, shard_fn)
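
For reference, a minimal sketch of a custom `shard_fn` that matches the signature documented above; the mesh, the placement policy, and the match on the accumulator name below are illustrative assumptions, not part of this commit:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

def my_shard_fn(accumulator_name, param, accumulator):
    # Illustrative policy: keep the scalar-like beta pow accumulators
    # replicated and shard the moment tensors along their first axis.
    if "pow_acc" in accumulator_name:
        placements = [dist.Replicate()]
    else:
        placements = [dist.Shard(0)]
    return dist.shard_tensor(accumulator, mesh, placements)

layer = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())
opt = dist.shard_optimizer(opt, my_shard_fn)

As with the docstring example, this would need to run under a multi-card launch, e.g. python -m paddle.distributed.launch --gpus=0,1.
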
10 changes: 5 additions & 5 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py
@@ -69,7 +69,7 @@ def test_shard_optimizer_mp(self):
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.ShardOptimizer(opt)
opt = dist.shard_optimizer(opt)
for _ in range(5):
loss = linear(batch)
loss.backward()
@@ -84,7 +84,7 @@ def test_shard_optimizer_from_non_shard_layer(self):
linear = paddle.nn.Linear(10, 10)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.ShardOptimizer(opt)
opt = dist.shard_optimizer(opt)
for _ in range(5):
loss = linear(batch)
loss.backward()
@@ -113,7 +113,7 @@ def test_shard_optimizer_shard_fn(self):
dist.shard_layer(linear, self._mesh, self.shard_layer_fn)
batch = paddle.rand(shape=[10, 10])
opt = paddle.optimizer.AdamW(parameters=linear.parameters())
opt = dist.ShardOptimizer(opt, self.shard_opt_fn)
opt = dist.shard_optimizer(opt, self.shard_opt_fn)
loss = linear(batch)
loss.backward()
opt.step()
@@ -129,7 +129,7 @@ def test_shard_optimizer_master_params(self):
opt = paddle.optimizer.AdamW(
parameters=linear.parameters(), multi_precision=True
)
opt = dist.ShardOptimizer(opt)
opt = dist.shard_optimizer(opt)
loss = linear(batch)
loss.backward()
opt.step()
@@ -148,7 +148,7 @@ def test_shard_optimizer_params_group(self):
linear.bias.optimize_attr = {'lr': 1}
params_group = [{'params': linear.weight}, {'params': linear.bias}]
opt = paddle.optimizer.AdamW(parameters=params_group)
opt = dist.ShardOptimizer(opt)
opt = dist.shard_optimizer(opt)
loss = linear(batch)
loss.backward()
opt.step()
