[Auto Parallel] Use the new completion algorithm (#39086)
* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and the compatibility computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem

* [Auto Parallel] Update the completion algorithm

* Fix a bug in the auto_searcher unittest
aoyulong authored Jan 21, 2022
1 parent f68ef9d commit e5cda6f
Showing 15 changed files with 686 additions and 940 deletions.
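The heart of the change, visible across the diffs below, is that the module-level completion entry points (complete_annotation, complete_backward_annotation, complete_update_annotation) are replaced by a Completer class constructed around a DistributedContext. A minimal sketch of the new call pattern, pieced together from those hunks; the program setup around it is an assumption for illustration, not part of this commit:

    # Minimal sketch of the Completer-based flow (setup code is assumed, not from the diff).
    import paddle
    import paddle.static as static
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    paddle.enable_static()
    serial_main_program = static.default_main_program()  # assumed: an annotated serial program

    dist_context = DistributedContext()
    completer = Completer(dist_context)

    # Forward completion returns the annotated main program.
    complete_train_program = completer.complete_forward_annotation(serial_main_program)

    # Backward and update completion annotate in place; AutoParallelizer calls them
    # after append_backward and after applying the optimizer, respectively.
    completer.complete_backward_annotation(serial_main_program)
    completer.complete_update_annotation(serial_main_program)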
6 changes: 0 additions & 6 deletions python/paddle/distributed/auto_parallel/__init__.py
@@ -15,12 +15,6 @@
 from .interface import shard_tensor # noqa: F401
 from .interface import shard_op # noqa: F401
 from .process_mesh import ProcessMesh
-# from .interface import set_shard_mask # noqa: F401
-# from .interface import set_offload_device # noqa: F401
-# from .interface import set_pipeline_stage # noqa: F401
-# from .interface import ProcessMesh # noqa: F401
-from .completion import complete_annotation # noqa: F401
-from .completion import complete_backward_annotation # noqa: F401
 from .reshard import reshard # noqa: F401
 from .cost_model import estimate_cost

1,414 changes: 575 additions & 839 deletions python/paddle/distributed/auto_parallel/completion.py

Large diffs are not rendered by default.
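The rewritten completion pass itself is the bulk of the commit but is not rendered above. Without claiming to reproduce the new file, completion in an auto-parallel stack of this kind generally propagates distributed attributes between tensors and the operators that produce or consume them until a fixed point is reached; a rough, hypothetical skeleton of such a loop:

    # Hypothetical fixed-point propagation skeleton; the actual completion.py in
    # this commit is not rendered here and will differ in structure and detail.
    def propagate_until_converged(nodes, try_update_dist_attr):
        changed = True
        while changed:
            changed = False
            for node in nodes:                    # forward sweep over the node list
                changed |= try_update_dist_attr(node, fwd=True)
            for node in reversed(nodes):          # backward sweep
                changed |= try_update_dist_attr(node, fwd=False)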

34 changes: 17 additions & 17 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -247,23 +247,23 @@ def get_op_dist_attr_for_graph(self, serial_op_node):
         # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
         # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

-    # def get_dist_attr_for_graph(self, serial_node):
-    #     if serial_node.is_var() and serial_node.var() is not None:
-    #         serial_tensor_node_id = serial_node.id()
-    #         dist_tensor = self._dist_tensors_for_graph.get(
-    #             serial_tensor_node_id, None)
-    #         if dist_tensor:
-    #             return dist_tensor.dist_attr
-    #         else:
-    #             return None
-    #     if serial_node.is_op() and serial_node.op() is not None:
-    #         serial_op_node_id = serial_node.id()
-    #         dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
-    #         if dist_op:
-    #             return dist_op.dist_attr
-    #         else:
-    #             return None
-    #     return None
+    def get_dist_attr_for_graph(self, serial_node):
+        if serial_node.is_var() and serial_node.var() is not None:
+            serial_tensor_node_id = serial_node.id()
+            dist_tensor = self._dist_tensors_for_graph.get(
+                serial_tensor_node_id, None)
+            if dist_tensor:
+                return dist_tensor.dist_attr
+            else:
+                return None
+        if serial_node.is_op() and serial_node.op() is not None:
+            serial_op_node_id = serial_node.id()
+            dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
+            if dist_op:
+                return dist_op.dist_attr
+            else:
+                return None
+        return None

     def init_dist_attr_for_program(self):
         assert self._serial_program, \
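With this hunk, get_dist_attr_for_graph is no longer commented out: given a graph node it returns the tensor's dist_attr for a variable node, the operator's dist_attr for an op node, and None otherwise. A hypothetical read-back loop, assuming the context has already built and annotated its serial graph (the serial_graph attribute and the node API are assumptions based on surrounding code, not shown in this hunk):

    # Hypothetical sketch: read per-node distributed attributes back after
    # graph-based completion. `dist_context.serial_graph` is assumed to be the
    # IrGraph that the completion pass attached to this context.
    for node in dist_context.serial_graph.all_nodes():
        dist_attr = dist_context.get_dist_attr_for_graph(node)
        if dist_attr is not None:
            # Variable nodes yield the tensor's dist_attr, op nodes the operator's.
            print(node.name(), dist_attr)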
15 changes: 8 additions & 7 deletions python/paddle/distributed/auto_parallel/parallelizer.py
@@ -32,7 +32,7 @@
 from .dist_context import DistributedContext
 from .dist_context import get_default_distributed_context
 from .dist_context import set_default_distributed_context
-from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
+from .completion import Completer
 from .partitioner import Partitioner
 from .process_group import get_all_process_groups
 from .process_group import get_process_group
@@ -130,8 +130,8 @@ def _generate_backward(self, main_program, startup_program, loss,
             no_grad_set,
             callbacks,
             distop_context=self._dist_context.dist_op_context)
-        complete_backward_annotation(
-            main_program, dist_context=self._dist_context)
+        self._completer = Completer(self._dist_context)
+        self._completer.complete_backward_annotation(main_program)

         return params_grads

@@ -142,8 +142,8 @@ def _apply_optimize(self, main_program, startup_program, params_grads):
             params_grads)

         # update completion
-        complete_update_annotation(
-            main_program, dist_context=self._dist_context)
+        self._completer = Completer(self._dist_context)
+        self._completer.complete_update_annotation(main_program)

         return optimize_ops

@@ -179,8 +179,9 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
             # Annotation completion
             self._dist_context = DistributedContext()
             _logger.info("Start annotation dist attr.")
-            completed_main_program = complete_annotation(serial_main_program,
-                                                         self._dist_context)
+            self._completer = Completer(self._dist_context)
+            completed_main_program = self._completer.complete_forward_annotation(
+                serial_main_program)
         else:
             completed_main_program = serial_main_program
             self._dist_context = copy.deepcopy(dist_context)
(another changed file; file name not rendered)
@@ -27,6 +27,7 @@
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -154,10 +155,9 @@ def test_mlp_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_mlp_mp(self):
@@ -171,10 +171,9 @@ def test_mlp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_mlp_dp_mp(self):
@@ -189,10 +188,9 @@ def test_mlp_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = mlp_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     # def test_mlp_misc(self):
@@ -212,8 +210,8 @@ def test_mlp_dp_mp(self):
     #     train_program, start_program = mlp_pretrain_forward(train_program,
     #                                                         start_program)
     #     # pdb.set_trace()
-    #     complete_train_program = auto.complete_annotation(train_program,
-    #                                                       dist_context)
+    #     completer = Completer(dist_context)
+    #     complete_train_program = auto.completer.complete_forward_annotation(train_program)
     #     # print_program_with_dist_attr(complete_train_program,
     #     #                              dist_context)
     #     dist_context.finalize_distributed_attr_for_program(
@@ -423,8 +421,9 @@ def test_attn_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         # print_program_with_dist_attr(complete_train_program,
         #                              dist_context)
         self.assertTrue(dist_context.validate_dist_attr_for_program())
@@ -440,10 +439,9 @@ def test_attn_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_attn_dp_mp(self):
@@ -458,10 +456,9 @@ def test_attn_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = attn_pretrain_forward(train_program,
                                                              start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

@@ -747,10 +744,9 @@ def test_decoder_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_decoder_mp(self):
@@ -764,10 +760,9 @@ def test_decoder_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_decoder_dp_mp(self):
@@ -782,10 +777,9 @@ def test_decoder_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = decoder_pretrain_forward(train_program,
                                                                 start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

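Every test hunk in this file applies the same mechanical migration, so the pattern is worth stating once. A condensed version of the test bodies above (the enclosing TestCase, the static import, and the mlp_pretrain_forward builder are assumed from the rest of the file):

    # Condensed form of the test methods above (the program builder name varies per test).
    def test_completion_pattern(self):  # illustrative only, not a new test in the commit
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = mlp_pretrain_forward(train_program, start_program)
        completer = Completer(dist_context)                 # new API
        complete_train_program = completer.complete_forward_annotation(train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())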
(another changed file; file name not rendered)
@@ -31,6 +31,7 @@
 from paddle.distributed.fleet import fleet
 import paddle.static as static
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -817,10 +818,9 @@ def test_gpt_dp(self):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_gpt_mp(self):
@@ -834,10 +834,9 @@ def test_gpt_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

     def test_gpt_dp_mp(self):
@@ -852,10 +851,9 @@ def test_gpt_dp_mp(self):
         dist_context = DistributedContext()
         train_program, start_program = gpt_pretrain_forward(train_program,
                                                             start_program)
-        complete_train_program = auto.complete_annotation(train_program,
-                                                          dist_context)
-        # print_program_with_dist_attr(complete_train_program,
-        #                              dist_context)
+        completer = Completer(dist_context)
+        complete_train_program = completer.complete_forward_annotation(
+            train_program)
         self.assertTrue(dist_context.validate_dist_attr_for_program())

(another changed file; file name not rendered)
@@ -23,6 +23,7 @@
 import paddle.nn.functional as F
 import paddle.utils as utils
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context

     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)

     params_grads = parallelizer._generate_backward(
         complete_train_program,
(another changed file; file name not rendered)
@@ -18,6 +18,7 @@
 import paddle
 from paddle.fluid import core
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed import fleet
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
     parallelizer._dist_context = dist_context

     # serial forward & backward completion
-    complete_train_program = auto.complete_annotation(
-        train_program, dist_context
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program
     ) if complete_train_program is None else complete_train_program

     # parallelizer._apply_serial_forward_pass(complete_train_program,
12 changes: 10 additions & 2 deletions python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
@@ -36,6 +36,7 @@
 from paddle.distributed import fleet

 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
 from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -433,6 +434,12 @@ def forward(self, input):
         out = F.gelu(out, approximate=True)
         out = self.linear1(out)

+        auto.shard_tensor(
+            out,
+            dist_attr={
+                "process_mesh": _global_process_mesh[1],
+                "dims_mapping": [0, -1]
+            })
         out = self.linear2(out)
         out = F.gelu(out, approximate=True)
         out = self.linear3(out)
@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     parallelizer._dist_context = dist_context

     # auto completion
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)

     params_grads = parallelizer._generate_backward(
         complete_train_program,
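Besides the API migration, the mapper test gains an explicit auto.shard_tensor annotation between linear1 and linear2. In this annotation scheme, process_mesh selects the mesh the tensor is distributed over and dims_mapping assigns each tensor dimension to a mesh axis, with -1 meaning the dimension is not sharded; [0, -1] therefore shards the first dimension across mesh axis 0 and replicates the second. A standalone, hedged illustration of the same call (the mesh shape and tensor here are examples, not taken from the commit):

    # Standalone illustration of auto.shard_tensor (example mesh and tensor only).
    import paddle
    import paddle.static as static
    import paddle.distributed.auto_parallel as auto

    paddle.enable_static()
    with static.program_guard(static.Program(), static.Program()):
        x = static.data(name="x", shape=[8, 1024], dtype="float32")
        auto.shard_tensor(
            x,
            dist_attr={
                "process_mesh": auto.ProcessMesh([[0, 1], [2, 3]]),  # example 2x2 mesh
                "dims_mapping": [0, -1]  # dim 0 sharded over mesh axis 0; dim 1 replicated
            })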
(another changed file; file name not rendered)
@@ -28,6 +28,7 @@
 from paddle.fluid import layers
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.completion import Completer
 from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
 from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
 from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -49,8 +50,9 @@ def get_programs(annotated_func):
     global _global_process_mesh
     dist_context.process_mesh = _global_process_mesh
     train_program, start_program = annotated_func(train_program, start_program)
-    complete_train_program = auto.complete_annotation(train_program,
-                                                      dist_context)
+    completer = Completer(dist_context)
+    complete_train_program = completer.complete_forward_annotation(
+        train_program)

     rank_id = 3
     dist_strategy = fleet.DistributedStrategy()
(diffs for the remaining changed files are not rendered)