Merge pull request PaddlePaddle#5 from Caozhou1995/new_api_1
update shard annotation
Caozhou1995 authored Sep 13, 2022
2 parents 3fa80fe + 1665367 commit 91ce66b
Showing 22 changed files with 365 additions and 952 deletions.
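The change is the same in every file: auto.shard_tensor calls that passed a dist_attr dict with a process_mesh and an integer dims_mapping are rewritten to the new positional form auto.shard_tensor(tensor, process_mesh, placements), where each placement is a mesh dimension name or None, and auto.ProcessMesh now takes dim_names. A minimal before/after sketch follows, using only the calls that appear in the diff; the import path is an assumption.

import paddle
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')

# Old form (removed by this commit): integer dims_mapping, -1 means
# "not sharded along any mesh dimension".
# auto.shard_tensor(x, dist_attr={"process_mesh": mesh,
#                                 "dims_mapping": [0, -1]})

# New form (added by this commit): each entry names the mesh dimension that
# tensor dimension is sharded along; None keeps the dimension replicated.
auto.shard_tensor(x, mesh, ["x", None])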
@@ -103,11 +103,7 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, 1],
dtype='float32')

auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mappig": [-1, -1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None, None])

mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
@@ -127,7 +123,7 @@ def mlp_pretrain_forward(train_program, start_program):

def train():
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])

dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
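For this test only the mesh and the input annotation change: the one-dimensional mesh gets a dimension name, and the placement list [None, None, None] maps no tensor dimension to a mesh dimension, so the input stays replicated on both processes. A standalone sketch of that annotation; the import path and the concrete input shape are assumptions.

import paddle
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()

_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])

input = paddle.static.data(name='input',
                           shape=[4, 512, 1024],  # assumed shape
                           dtype='float32')

# No tensor dimension is mapped to mesh dimension "x": the tensor is
# replicated across processes 0 and 1.
auto.shard_tensor(input, _global_process_mesh, [None, None, None])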
@@ -45,9 +45,10 @@

paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
dim_names=["x", "y", "z"])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])


class MLPLayer(nn.Layer):
@@ -74,16 +75,8 @@ def __init__(self,
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

def forward(self, input):
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])

out = self.norm(input)
out = self.linear0(out)
@@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program):
embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
embedding_out = embedding(fill_constant_out)

auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, PP_MESH_0, ["x", None])
auto.shard_tensor(label, PP_MESH_1, ["x", None])

mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
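This file combines data, model, and pipeline parallelism: the global mesh gets three named dimensions, and each pipeline-stage mesh keeps "x" (data parallel) and "y" (model parallel). With those names, [None, "y"] shards the second weight dimension (column parallel) on stage 0 and ["y", None] shards the first weight dimension (row parallel) on stage 1, while ["x", None] splits the batch dimension of input and label across the two stages. A compact sketch of just the weight annotations; the import path and the hidden size are assumptions.

import paddle
import paddle.nn as nn
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()

PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])

hidden_size = 1024  # assumed
linear0 = nn.Linear(hidden_size, 4 * hidden_size)
linear1 = nn.Linear(4 * hidden_size, hidden_size)

# Stage 0: shard linear0.weight along its output-feature dimension ("y").
auto.shard_tensor(linear0.weight, PP_MESH_0, [None, "y"])
# Stage 1: shard linear1.weight along its input-feature dimension ("y").
auto.shard_tensor(linear1.weight, PP_MESH_1, ["y", None])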
@@ -34,7 +34,10 @@
batch_size = 4
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]]
_g_process_mesh = [
auto.ProcessMesh([0, 1], dim_names=["x"]),
auto.ProcessMesh([2, 3], dim_names=["x"])
]


def get_random_inputs_and_labels(input_shape, label_shape):
@@ -82,18 +85,10 @@ def __init__(self,

def forward(self, input):
out = self.norm(input)
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear0.weight, _g_process_mesh[0], [None, "x"])
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear1.weight, _g_process_mesh[1], ["x", None])
out = self.linear1(out)

return out
@@ -123,16 +118,8 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(input,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None])
auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None])

mlp_start = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
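Here the global mesh becomes a list with one ProcessMesh per pipeline stage, each with a single named dimension "x"; the data tensors are split along their batch dimension on the first stage and replicated otherwise. A sketch of the data annotations; the import path and the data shapes are assumptions.

import paddle
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()

_g_process_mesh = [
    auto.ProcessMesh([0, 1], dim_names=["x"]),
    auto.ProcessMesh([2, 3], dim_names=["x"])
]

input = paddle.static.data(name='input',
                           shape=[4, 512, 1024],  # assumed shape
                           dtype='float32')
label = paddle.static.data(name='label',
                           shape=[4, 512, 1],  # assumed shape
                           dtype='float32')

# Split the batch dimension across processes 0 and 1 of the first stage;
# the other dimensions stay replicated.
auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None])
auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None])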
@@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp():
is_sparse=False)
loss = paddle.fluid.layers.reduce_mean(emb_out)

auto.shard_tensor(src_ids,
dist_attr={
"process_mesh": auto.ProcessMesh([[0, 1], [2,
3]]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(
src_ids, auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
["x", None, None])
emb_weight = block.vars["emb_weight"]
auto.shard_tensor(emb_weight,
dist_attr={
"process_mesh": auto.ProcessMesh([[0, 1], [2,
3]]),
"dims_mapping": [1, -1]
})
auto.shard_tensor(
emb_weight, auto.ProcessMesh([[0, 1], [2, 3]],
dim_names=["x", "y"]), ["y", None])

return main_program, start_program, loss

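In the lookup-table test the mesh is built inline where it is used: the ids are split along their batch dimension ("x", data parallel) and the embedding table is split along its vocabulary rows ("y", model parallel). A sketch with a plain paddle.nn.Embedding standing in for the fluid lookup-table layer used by the test; the import path, shapes, and sizes are assumptions.

import paddle
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

src_ids = paddle.static.data(name='src_ids',
                             shape=[4, 128, 1],  # assumed shape
                             dtype='int64')
# Data parallel: split the batch dimension of the ids along "x".
auto.shard_tensor(src_ids, mesh, ["x", None, None])

embedding = paddle.nn.Embedding(10, 1024)  # assumed sizes
# Model parallel: split the embedding table's vocabulary rows along "y".
auto.shard_tensor(embedding.weight, mesh, ["y", None])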
@@ -22,82 +22,58 @@

paddle.enable_static()

mesh = [[0, 1], [2, 3]]
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])


def init_x_row(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, 1, -1]
})
auto.shard_tensor(x, mesh, ["x", "y", None])

return x
else:
x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1, 1]
})
auto.shard_tensor(x, mesh, ["x", None, "y"])

return x


def init_x_col(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(x, mesh, [None, "x"])

return x
else:
x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(x, mesh, ["x", None])

return x


def init_y_row(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(y, mesh, [None, "y"])

return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(y, mesh, ["y", None])

return y


def init_y_col(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(y, mesh, ["y", None])

return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(y, mesh, [None, "y"])

return y


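The matmul helpers above annotate the two operands consistently with the transpose flags: in the column-parallel pair the output columns of y go along "y" while x keeps its rows along "x", in the row-parallel pair the contracted dimension goes along "y", and the placement list is simply swapped when an operand is transposed. A usage sketch of the non-transposed column-parallel pair; the import path is an assumption, shapes and placements mirror init_x_col and init_y_col.

import paddle
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')

# x: rows split along "x", contracted dimension replicated.
auto.shard_tensor(x, mesh, ["x", None])
# y: contracted dimension replicated, output columns split along "y".
auto.shard_tensor(y, mesh, [None, "y"])

# Result has shape [8, 4], partitioned along "x" rows and "y" columns.
out = paddle.matmul(x, y)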