Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Auto Parallel] Add the support for the auto completion of while_op #39939

Merged
merged 13 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
std::unordered_map<std::string, std::pair<VarDesc *, int>>
name_to_desc_block_id;

block_id_ = block.ID();
const BlockDesc *block_var_visible = &block;
while (block_var_visible != nullptr) {
for (auto *var : block_var_visible->AllVars()) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class Graph {
auto *x =
AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}

Expand All @@ -245,6 +246,7 @@ class Graph {
"The OpDesc used to create operator node is null."));
auto *x = AddNode(new ir::Node(op_desc));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}

Expand All @@ -263,6 +265,7 @@ class Graph {
num_node_created_);
auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}

Expand All @@ -276,6 +279,7 @@ class Graph {
}
auto *x = AddNode(new ir::Node(name, type, block_id_));
x->SetId(num_node_created_++);
x->SetGraphId(block_id_);
return x;
}

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class Node {
// Only use this for auto parallel.
// A node does not have original desc if the return is zero.
uint64_t OriginalDescId() const { return original_desc_id_; }
int GraphId() const { return graph_id_; }

bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; }
Expand Down Expand Up @@ -246,10 +247,12 @@ class Node {
// Store the original id of var desc or op desc.
// Only use this for auto parallel.
uint64_t original_desc_id_{0};
int graph_id_{-1};

private:
// ID can only set by a Graph.
void SetId(int id) { id_ = id; }
void SetGraphId(int graph_id) { graph_id_ = graph_id; }

// desc_order can only set by a Graph when constructing a Graph from a
// BlockDesc.
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ void BindNode(py::module *m) {
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("graph_id", &Node::GraphId)
.def("original_desc_id", &Node::OriginalDescId)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
Expand Down
437 changes: 367 additions & 70 deletions python/paddle/distributed/auto_parallel/completion.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions python/paddle/distributed/auto_parallel/dist_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def __str__(self):
class OperatorDistributedAttribute:
def __init__(self):
self._process_mesh = None
self._op_type = None
self._impl_type = None
self._impl_idx = None
self._inputs_dist_attrs = {}
Expand All @@ -194,11 +195,23 @@ def process_mesh(self, process_mesh):
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
# In while op, the proess mesh is not shared by all inputs and outputs
if self._op_type == "while":
return None
for dist_attr in self._inputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
for dist_attr in self._outputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh

@property
def op_type(self):
return self._op_type

@op_type.setter
def op_type(self, op_type):
if op_type is not None:
self._op_type = op_type

@property
def impl_type(self):
return self._impl_type
Expand Down Expand Up @@ -326,6 +339,8 @@ def init(self, dist_attr):
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
# Make sure proscess_meshes in dist op be same
if self.op_type == "while":
return None
process_meshes = []
process_meshes.append(self.process_mesh)
for tensor_dist_attr in self.inputs_dist_attrs.values():
Expand Down
114 changes: 84 additions & 30 deletions python/paddle/distributed/auto_parallel/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid.framework import get_flags, set_flags
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
Expand All @@ -39,6 +40,10 @@ def set_default_distributed_context(dist_context):
_g_default_distributed_context = dist_context


def _node_id(node):
return (node.node.graph_id(), node.node.id())


class DistributedContext:
"""
DistributedContext is used to collect related distributed information for program and graph.
Expand Down Expand Up @@ -146,7 +151,7 @@ def get_dist_tensor_for_program(self, serial_tensor):
return None

def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
serial_tensor_node_id = _node_id(serial_tensor_node)
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)

def get_dist_op_for_program(self, serial_op):
Expand All @@ -168,7 +173,7 @@ def del_dist_op_for_program(self, serial_tensor):
del self._dist_ops_for_program[serial_tensor_id]

def get_dist_op_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
serial_op_node_id = _node_id(serial_op_node)
return self._dist_ops_for_graph.get(serial_op_node_id, None)

def get_tensor_dist_attr_for_program(self, serial_tensor):
Expand Down Expand Up @@ -197,7 +202,7 @@ def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
self.add_dist_tensor_for_program(dist_tensor)

def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
serial_tensor_node_id = _node_id(serial_tensor_node)
dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
None)
if dist_tensor:
Expand Down Expand Up @@ -242,7 +247,7 @@ def set_op_dist_attr_for_program(self, serial_op, dist_attr):
self.add_dist_op_for_program(dist_op)

def get_op_dist_attr_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
serial_op_node_id = _node_id(serial_op_node)
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
Expand All @@ -262,15 +267,15 @@ def get_op_dist_attr_for_graph(self, serial_op_node):

def get_dist_attr_for_graph(self, serial_node):
if serial_node.is_var() and serial_node.var() is not None:
serial_tensor_node_id = serial_node.id()
serial_tensor_node_id = _node_id(serial_node)
dist_tensor = self._dist_tensors_for_graph.get(
serial_tensor_node_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
if serial_node.is_op() and serial_node.op() is not None:
serial_op_node_id = serial_node.id()
serial_op_node_id = _node_id(serial_node)
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
Expand Down Expand Up @@ -311,50 +316,79 @@ def init_dist_attr_for_program(self):
def order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if node.id() == target_node.id():
if _node_id(node) == _node_id(target_node):
return True
return False

ordered_tensor_nodes = []
ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes()
serial_ordered_tensor_nodes = []
serial_ordered_op_nodes = []
all_nodes = []
# for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
for node in graph.all_nodes():
all_nodes.append(node)
for node in all_nodes:
if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node)
serial_ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes:
serial_ordered_op_nodes.append(node)
serial_ordered_tensor_nodes.sort(
key=lambda node: node.node.original_desc_id())
serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
num_nodes_before = len(serial_ordered_tensor_nodes) + len(
serial_ordered_op_nodes)

new_serial_ordered_tensor_nodes = []
new_serial_ordered_op_nodes = []
for op_node in serial_ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
new_serial_ordered_op_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
new_serial_ordered_tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes)
assert len(self._serial_ordered_nodes) == num_nodes_before, \
"The number of nodes before ordering is not the same after ordering."
new_serial_ordered_tensor_nodes.sort(
key=lambda node: node.node.original_desc_id())
new_serial_ordered_op_nodes.sort(
key=lambda node: node.node.original_desc_id())
self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
assert len(self._serial_ordered_nodes) == len(
self._serial_ordered_tensor_nodes) + len(
self._serial_ordered_op_nodes)
self._serial_orphan_tensor_nodes = []
for tensor_node in serial_ordered_tensor_nodes:
if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
self._serial_orphan_tensor_nodes.append(tensor_node)
if len(self._serial_ordered_nodes) != num_nodes_before:
print(
"WARNING: there are some orphan tensors or ops which are not used in the execution."
)

def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
if self._is_initialized_for_graph:
return
# Convert program to graph
set_flags({"FLAGS_convert_all_blocks": True})
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
Expand All @@ -365,10 +399,11 @@ def init_dist_attr_for_graph(self):
if tensor_id == cur_tensor_id \
or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
dist_tensor = cur_dist_tensor
self._node_id_to_tensor_id[node.id()] = cur_tensor_id
self._node_id_to_tensor_id[_node_id(
node)] = cur_tensor_id
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
serial_tensor_node_id = node.id()
serial_tensor_node_id = _node_id(node)
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_tensor.dist_attr)
self._dist_tensors_for_graph[
Expand All @@ -381,10 +416,10 @@ def init_dist_attr_for_graph(self):
if op_id == cur_op_id \
or op_id == cur_dist_op.serial_op.desc.original_id():
dist_op = cur_dist_op
self._node_id_to_op_id[node.id()] = cur_op_id
self._node_id_to_op_id[_node_id(node)] = cur_op_id
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
serial_op_node_id = node.id()
serial_op_node_id = _node_id(node)
new_dist_op = DistributedOperator(dist_op.serial_op,
dist_op.dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
Expand All @@ -402,10 +437,11 @@ def copy_dist_attr_from_graph_to_program(self):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"Both program and graph must be initialized."
updated_tensors = {}
all_nodes = self._serial_graph.all_nodes()
# all_nodes = self._serial_graph.all_nodes()
all_nodes = self._serial_ordered_nodes
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_id = self._node_id_to_tensor_id[node.id()]
tensor_id = self._node_id_to_tensor_id[_node_id(node)]
updated = updated_tensors.get(tensor_id, False)
# If a var has multiples var nodes in graph, only use the first one for now
if not updated:
Expand All @@ -416,16 +452,31 @@ def copy_dist_attr_from_graph_to_program(self):
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_id] = True
if node.is_op() and node.op() is not None:
op_id = self._node_id_to_op_id[node.id()]
op_id = self._node_id_to_op_id[_node_id(node)]
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
# TODO: the completion algorithm will skip orphan tensors,
# here we just set there process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes:
serial_tensor_id = orphan_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id,
None)
if dist_tensor:
dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
else:
serial_tensor_id = orphan_node.var().original_id()
dist_tensor = self._dist_tensors_for_program.get(
serial_tensor_id, None)
dist_tensor.dist_attr.process_mesh = self._process_meshes[0]

def amend_dist_attr_for_program(self):
for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr
if serial_tensor.type == core.VarDesc.VarType.READER:
if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
Expand All @@ -446,6 +497,7 @@ def amend_dist_attr_for_program(self):
tensor_shape = []
else:
if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or dist_op.serial_op.type == "create_py_reader":
tensor_shape = []
else:
Expand All @@ -459,8 +511,9 @@ def amend_dist_attr_for_program(self):
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(
arg_name).type == core.VarDesc.VarType.READER:
if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_output(arg_name).shape
Expand Down Expand Up @@ -498,7 +551,8 @@ def __deepcopy__(self, memo):
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes":
or k == "_serial_ordered_nodes" or k == "_serial_ordered_tensor_nodes" \
or k == "_serial_ordered_op_nodes":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
Expand Down
Loading