[Auto Parallel]Update comp cost and completion for gpt auto search (PaddlePaddle#46387)

* update comp cost and completion for gpt auto search

* add unittest
Caozhou1995 authored and zhaoyingli committed Oct 19, 2022
1 parent def66b1 commit c8b7203
Showing 4 changed files with 154 additions and 0 deletions.
72 changes: 72 additions & 0 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -142,6 +142,7 @@ class Completer:
    def __init__(self, dist_context):
        assert dist_context is not None
        self._dist_context = dist_context
        self._has_prepared = False

    def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
        changed = False
@@ -719,6 +720,8 @@ def _update_process_mesh(self):
        self._update_process_mesh_between_graphs()

    def _prepare(self):
        if self._has_prepared:
            return
        self._while_op_nodes = {}
        self._array_nodes = {}
        self._node_pairs_between_graphs = []
@@ -732,6 +735,8 @@ def _prepare(self):
                    if self._array_nodes.get(array_var_name, None) is None:
                        self._array_nodes[array_var_name] = []
                    self._array_nodes[array_var_name].append(node)
                    # Add the array input node
                    self._array_nodes[array_var_name].append(node.inputs[0])
                if node.op().type() == "write_to_array":
                    array_var_name = node.op().output("Out")[0]
                    if self._array_nodes.get(array_var_name, None) is None:
@@ -752,6 +757,7 @@ def _prepare(self):
                            and after_node.var().name() == node.var().name():
                            self._node_pairs_between_graphs.append(
                                (after_node, node))
        self._has_prepared = True

    def complete_forward_annotation(self, serial_main_program=None):
        """ Complete annotation for the partial annotated serial_main_program.
@@ -899,6 +905,72 @@ def _update_dist_attr_for_dp(self):
            else:
                dist_op.dist_attr = original_op_dist_attr

    def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
        if serial_main_program is None:
            serial_main_program = self._dist_context.serial_main_program
        else:
            self._dist_context._serial_main_program = serial_main_program

        self._dist_context.initialize()

        self._prepare()

        has_set_dist_attr = set()

        all_nodes = self._dist_context.serial_ordered_nodes
        for node in all_nodes:
            if node.is_op():
                if node.op().type() in ["while"]:
                    continue
                dist_op = self._dist_context.get_dist_op_for_graph(node)
                op_dist_attr = dist_op.dist_attr
                for tensor_node in node.inputs:
                    if tensor_node.is_var() and tensor_node.var() is not None:
                        # Skip the non-leaf var node
                        if len(tensor_node.inputs) != 0:
                            continue
                        tensor_desc = tensor_node.var()
                        tensor_name = tensor_desc.name()
                        tensor = dist_op.get_serial_input(tensor_name)
                        # Use the first op to set the tensor dist attr
                        if tensor_name in has_set_dist_attr:
                            continue
                        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                        tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
                            tensor_name) if tensor.is_parameter else [
                                -1 for i in tensor_desc.shape()
                            ]
                        has_set_dist_attr.add(tensor_name)
                for tensor_node in node.outputs:
                    if tensor_node.is_var() and tensor_node.var() is not None:
                        tensor_name = tensor_node.var().name()
                        if tensor_name in has_set_dist_attr:
                            continue
                        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                            tensor_node)
                        tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                        tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
                            tensor_name)
                        has_set_dist_attr.add(tensor_name)

        self._update_process_mesh_for_specials()

        self._update_process_mesh_between_graphs()

        self._update_dims_mapping_for_special()

        self._update_dims_mapping_between_graphs()

        # Copy the corresponding distributed attribute from graph to serial_main_program
        self._dist_context.copy_dist_attr_from_graph_to_program()

        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()

        self._dist_context.validate_dist_attr_for_program()

    def _complete_high_order_grad_annotation(self, serial_main_program=None):
        """
        NOTE:
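The key rule the new _complete_tensor_dist_attr_by_op pass applies to leaf input tensors: a parameter inherits the dims_mapping its consuming op already annotated for that input, while any other leaf input defaults to fully replicated (-1 for every dimension). The sketch below restates that rule in plain Python; the function name and the dict-based lookup are illustrative only, since the real method works on IR graph nodes and DistributedContext attributes.

# Illustrative restatement of the dims_mapping rule in
# _complete_tensor_dist_attr_by_op; names here are hypothetical.
def default_dims_mapping(tensor_name, tensor_shape, is_parameter,
                         op_input_dims_mapping):
    """Return the dims_mapping a leaf input tensor receives.

    Parameters inherit the mapping the op annotated for that input;
    every other leaf input is treated as replicated (-1 per dim).
    """
    if is_parameter:
        return op_input_dims_mapping[tensor_name]
    return [-1 for _ in tensor_shape]


# Example: a column-parallel weight keeps the op's mapping, a data tensor stays replicated.
op_mapping = {"linear_0.w_0": [-1, 0]}
print(default_dims_mapping("linear_0.w_0", [1024, 4096], True, op_mapping))  # [-1, 0]
print(default_dims_mapping("input_ids", [8, 512], False, op_mapping))        # [-1, -1]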
55 changes: 55 additions & 0 deletions python/paddle/distributed/auto_parallel/cost/comp_op_cost.py
@@ -167,6 +167,25 @@ def calc_time(self):
        return 0


@register_op_cost
class DropoutGradOpCost(CompOpCost):
    OP_TYPE = "dropout_grad"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(DropoutGradOpCost, self).__init__(op=op,
                                                op_desc=op_desc,
                                                cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future
        return 0


@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
    OP_TYPE = "elementwise_add"
@@ -395,6 +414,42 @@ def calc_time(self):
        return 0


@register_op_cost
class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
    OP_TYPE = "fused_softmax_mask_upper_triangle"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(FusedSoftmaxMaskUpperTriangleOpCost,
              self).__init__(op=op, op_desc=op_desc, cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future
        return 0


@register_op_cost
class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
    OP_TYPE = "fused_softmax_mask_upper_triangle_grad"

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(FusedSoftmaxMaskUpperTriangleGradOpCost,
              self).__init__(op=op, op_desc=op_desc, cluster=cluster)

    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
    def calc_flops(self):
        # NOTE: The actual formula will be filled in the future
        return 0

    def calc_time(self):
        # NOTE: The actual formula will be filled in the future
        return 0


@register_op_cost
class GatherOpCost(CompOpCost):
    OP_TYPE = "gather"
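All three new cost classes follow the placeholder pattern already used throughout comp_op_cost.py: register the class, bind it to an op type through OP_TYPE, and return 0 from calc_flops/calc_time until real formulas are filled in. Adding a cost for another op would look like the sketch below; the class and op type are hypothetical, and the snippet assumes it sits inside comp_op_cost.py, where register_op_cost and CompOpCost are already in scope.

# Hypothetical example of the same pattern; not part of this commit.
# Assumes placement in comp_op_cost.py, where register_op_cost and
# CompOpCost are already imported.
@register_op_cost
class MyCustomOpCost(CompOpCost):
    OP_TYPE = "my_custom_op"  # hypothetical op type

    def __init__(self, op=None, op_desc=None, cluster=None):
        super(MyCustomOpCost, self).__init__(op=op,
                                             op_desc=op_desc,
                                             cluster=cluster)

    def calc_flops(self):
        # Placeholder until a real formula is provided, as in the classes above
        return 0

    def calc_time(self):
        # Placeholder until a real formula is provided, as in the classes above
        return 0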
@@ -82,6 +82,9 @@
from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2GradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Unsqueeze2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import WriteToArrayOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import DropoutGradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleGradOpCost

from test_cluster import cluster_json

@@ -417,6 +420,22 @@ def test_comp_cost(self):
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        op_cost = DropoutGradOpCost(cluster=cluster)
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        op_cost = FusedSoftmaxMaskUpperTriangleOpCost(cluster=cluster)
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        op_cost = FusedSoftmaxMaskUpperTriangleGradOpCost(cluster=cluster)
        self.assertTrue(op_cost.flops >= 0)
        self.assertTrue(op_cost.time >= 0)
        self.assertTrue(op_cost.memory >= 0)

        # Remove unnecessary files
        if os.path.exists(cluster_json_path):
            os.remove(cluster_json_path)
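The added assertions repeat the same three checks for each new cost class. If every class keeps taking only cluster, the checks could be collapsed into a loop such as the sketch below (a possible simplification inside test_comp_cost, not part of the commit):

# Possible simplification of the repeated checks above; assumes each cost
# class accepts cluster as its only required constructor argument.
for op_cost_class in (DropoutGradOpCost,
                      FusedSoftmaxMaskUpperTriangleOpCost,
                      FusedSoftmaxMaskUpperTriangleGradOpCost):
    op_cost = op_cost_class(cluster=cluster)
    self.assertTrue(op_cost.flops >= 0)
    self.assertTrue(op_cost.time >= 0)
    self.assertTrue(op_cost.memory >= 0)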
@@ -187,6 +187,14 @@ def test_completer(self):
            train_program)
        # print_program_with_dist_attr(complete_train_program, dist_context)

    def test_completer_by_dist_op(self):
        train_program, start_program, dataloader, i, loss = get_program()
        dist_context = DistributedContext()
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        complete_train_program = completer._complete_tensor_dist_attr_by_op()


if __name__ == "__main__":
    unittest.main()

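For exercising the new entry point outside the unit test, a minimal driver could mirror test_completer_by_dist_op. The sketch below assumes the get_program() helper defined in that test file; the Completer import path follows the completion.py file changed in this commit, and the DistributedContext path is the usual location in the same package (an assumption, not shown in this diff).

# Minimal driver mirroring test_completer_by_dist_op; get_program() is an
# assumed helper from the test module, not defined here.
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext

train_program, start_program, dataloader, i, loss = get_program()
dist_context = DistributedContext()
completer = Completer(dist_context)

# Run the standard forward completion first, then the new per-op tensor pass.
completer.complete_forward_annotation(train_program)
completer._complete_tensor_dist_attr_by_op()

# The new method amends and validates the program's dist attrs itself
# (see the completion.py hunk above), so no extra validation call is needed.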