[Auto Parallel] Use the new completion algorithm #39086

Merged: 14 commits, Jan 21, 2022
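
At a glance, this PR retires the module-level `complete_annotation`, `complete_backward_annotation`, and `complete_update_annotation` functions in favor of a `Completer` class constructed from a `DistributedContext`. A minimal before/after sketch of the call-site migration, based on the diffs below (`train_program` is assumed to be an already annotated serial `paddle.static.Program`):

```python
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

# Old API (removed by this PR): free functions taking the dist_context as an
# argument, e.g.
#   complete_train_program = auto.complete_annotation(train_program, dist_context)

# New API (as used at the updated call sites below): a Completer bound to a
# DistributedContext, with one method per completion phase.
dist_context = DistributedContext()
completer = Completer(dist_context)

# Forward completion returns the completed main program.
complete_train_program = completer.complete_forward_annotation(train_program)

# Backward and optimizer-update completion are invoked the same way once the
# corresponding ops exist; their return values are not used in the parallelizer.
completer.complete_backward_annotation(complete_train_program)
completer.complete_update_annotation(complete_train_program)
```
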
6 changes: 0 additions & 6 deletions python/paddle/distributed/auto_parallel/__init__.py
@@ -15,12 +15,6 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .process_mesh import ProcessMesh
# from .interface import set_shard_mask # noqa: F401
# from .interface import set_offload_device # noqa: F401
# from .interface import set_pipeline_stage # noqa: F401
# from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
from .cost_model import estimate_cost

1,414 changes: 575 additions & 839 deletions python/paddle/distributed/auto_parallel/completion.py

Large diffs are not rendered by default.

34 changes: 17 additions & 17 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -247,23 +247,23 @@ def get_op_dist_attr_for_graph(self, serial_op_node):
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None
def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = serial_node.id()
dist_tensor = self._dist_tensors_for_graph.get(
serial_tensor_node_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
if serial_node.is_op() and serial_node.op() is not None:
serial_op_node_id = serial_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
return None

def init_dist_attr_for_program(self):
assert self._serial_program, \
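
The previously commented-out `get_dist_attr_for_graph` is re-enabled above: it returns the distributed attribute recorded for either a variable node or an op node of the serial graph, or `None` when no entry exists. A small usage sketch under that reading (the helper below is hypothetical, not part of this PR):

```python
def collect_graph_dist_attrs(dist_context, graph_nodes):
    """Hypothetical helper: map node id -> dist_attr for every serial graph
    node that has a distributed attribute registered in dist_context."""
    attrs = {}
    for node in graph_nodes:
        # Handles both var nodes and op nodes; returns None when missing.
        dist_attr = dist_context.get_dist_attr_for_graph(node)
        if dist_attr is not None:
            attrs[node.id()] = dist_attr
    return attrs
```
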
15 changes: 8 additions & 7 deletions python/paddle/distributed/auto_parallel/parallelizer.py
@@ -32,7 +32,7 @@
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation, complete_update_annotation
from .completion import Completer
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
@@ -130,8 +130,8 @@ def _generate_backward(self, main_program, startup_program, loss,
no_grad_set,
callbacks,
distop_context=self._dist_context.dist_op_context)
complete_backward_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_backward_annotation(main_program)

return params_grads

@@ -142,8 +142,8 @@ def _apply_optimize(self, main_program, startup_program, params_grads):
params_grads)

# update completion
complete_update_annotation(
main_program, dist_context=self._dist_context)
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)

return optimize_ops

@@ -179,8 +179,9 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
# Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(serial_main_program,
self._dist_context)
self._completer = Completer(self._dist_context)
completed_main_program = self._completer.complete_forward_annotation(
serial_main_program)
else:
completed_main_program = serial_main_program
self._dist_context = copy.deepcopy(dist_context)
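
The remaining diffs migrate the auto-parallel unit tests to the new API. Each test follows the same pattern: build the annotated serial program, run forward completion through a `Completer`, and validate the distributed attributes recorded in the `DistributedContext`. A hedged template of that pattern (`build_annotated_program` is a stand-in for the model-specific `*_pretrain_forward` builders, not a real helper):

```python
import unittest

import paddle.static as static
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext


class CompletionTestTemplate(unittest.TestCase):
    def test_forward_completion(self):
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        # build_annotated_program is a hypothetical stand-in for the
        # mlp/attn/decoder/gpt_pretrain_forward builders used in these tests.
        train_program, start_program = build_annotated_program(train_program,
                                                               start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())
```
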
@@ -27,6 +27,7 @@
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -154,10 +155,9 @@ def test_mlp_dp(self):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

def test_mlp_mp(self):
@@ -171,10 +171,9 @@ def test_mlp_mp(self):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

def test_mlp_dp_mp(self):
@@ -189,10 +188,9 @@ def test_mlp_dp_mp(self):
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

# def test_mlp_misc(self):
@@ -212,8 +210,8 @@ def test_mlp_dp_mp(self):
# train_program, start_program = mlp_pretrain_forward(train_program,
# start_program)
# # pdb.set_trace()
# complete_train_program = auto.complete_annotation(train_program,
# dist_context)
# completer = Completer(dist_context)
# complete_train_program = completer.complete_forward_annotation(train_program)
# # print_program_with_dist_attr(complete_train_program,
# # dist_context)
# dist_context.finalize_distributed_attr_for_program(
@@ -423,8 +421,9 @@ def test_attn_dp(self):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program())
@@ -440,10 +439,9 @@ def test_attn_mp(self):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

def test_attn_dp_mp(self):
@@ -458,10 +456,9 @@ def test_attn_dp_mp(self):
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())


@@ -747,10 +744,9 @@ def test_decoder_dp(self):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

def test_decoder_mp(self):
@@ -764,10 +760,9 @@ def test_decoder_mp(self):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

def test_decoder_dp_mp(self):
@@ -782,10 +777,9 @@ def test_decoder_dp_mp(self):
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())


@@ -31,6 +31,7 @@
from paddle.distributed.fleet import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
@@ -817,10 +818,9 @@ def test_gpt_dp(self):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

def test_gpt_mp(self):
@@ -834,10 +834,9 @@ def test_gpt_mp(self):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())

def test_gpt_dp_mp(self):
@@ -852,10 +851,9 @@ def test_gpt_dp_mp(self):
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
self.assertTrue(dist_context.validate_dist_attr_for_program())


@@ -23,6 +23,7 @@
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -154,8 +155,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context

# serial forward & backward completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)

params_grads = parallelizer._generate_backward(
complete_train_program,
@@ -18,6 +18,7 @@
import paddle
from paddle.fluid import core
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -42,8 +43,9 @@ def get_dist_prog(train_program,
parallelizer._dist_context = dist_context

# serial forward & backward completion
complete_train_program = auto.complete_annotation(
train_program, dist_context
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program
) if complete_train_program is None else complete_train_program

# parallelizer._apply_serial_forward_pass(complete_train_program,
12 changes: 10 additions & 2 deletions python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
@@ -36,6 +36,7 @@
from paddle.distributed import fleet

import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
@@ -433,6 +434,12 @@ def forward(self, input):
out = F.gelu(out, approximate=True)
out = self.linear1(out)

auto.shard_tensor(
out,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [0, -1]
})
out = self.linear2(out)
out = F.gelu(out, approximate=True)
out = self.linear3(out)
@@ -476,8 +483,9 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
parallelizer._dist_context = dist_context

# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)

params_grads = parallelizer._generate_backward(
complete_train_program,
@@ -28,6 +28,7 @@
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
@@ -49,8 +50,9 @@ def get_programs(annotated_func):
global _global_process_mesh
dist_context.process_mesh = _global_process_mesh
train_program, start_program = annotated_func(train_program, start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)

rank_id = 3
dist_strategy = fleet.DistributedStrategy()