Skip to content

Commit

Permalink
[Paddle Inference] Enhance the shape check of trt_embedding_eltwise_layernorm_fuse_pass and embedding_eltwise_layernorm_fuse_pass (#54861)
Browse files Browse the repository at this point in the history

* Enhance the shape check of trt_embedding_eltwise_layernorm_fuse_pass and embedding_eltwise_layernorm_fuse_pass
  • Loading branch information
Wangzheee authored Jun 27, 2023
1 parent f8d0214 commit e49c17d
Show file tree
Hide file tree
Showing 5 changed files with 974 additions and 232 deletions.
147 changes: 92 additions & 55 deletions paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,44 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(

std::vector<std::string> ids;
std::vector<std::string> embs;

auto ids0_shape = start_pattern_in_nodes[i][0].first->Var()->GetShape();
bool flag = true;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
auto ids_shape = start_pattern_in_nodes[i][iter].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
auto ids_shape = inner_pattern_ins[js[iter]].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
Expand All @@ -322,66 +355,70 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
"inputs with lookup_table_v2";
return fusion_count;
}
if (flag) {
OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
new_op_desc.SetInput("PosId", {ids[1]});
if (ids.size() > 2) {
new_op_desc.SetInput("SentId", {ids[2]});
}

OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
new_op_desc.SetInput("PosId", {ids[1]});
if (ids.size() > 2) {
new_op_desc.SetInput("SentId", {ids[2]});
}

new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
if (embs.size() > 2) {
new_op_desc.SetInput("SentEmbedding", {embs[2]});
}
new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
if (embs.size() > 2) {
new_op_desc.SetInput("SentEmbedding", {embs[2]});
}

new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}

auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);

for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
} else {
VLOG(3) << "Shape check failed, stop "
"embedding_eltwise_layernorm_fuse_pass.";
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}

return fusion_count;
Expand Down
135 changes: 86 additions & 49 deletions paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,68 +311,105 @@ int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(

std::vector<std::string> ids;
std::vector<std::string> embs;

auto ids0_shape = start_pattern_in_nodes[i][0].first->Var()->GetShape();
bool flag = true;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
auto ids_shape = start_pattern_in_nodes[i][iter].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
auto ids_shape = inner_pattern_ins[js[iter]].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}

OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
if (use_varseqlen && pos_id != "" && mask_id != "") {
new_op_desc.SetInput("PosId", {pos_id});
new_op_desc.SetInput("MaskId", {mask_id});
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}
if (flag) {
OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
if (use_varseqlen && pos_id != "" && mask_id != "") {
new_op_desc.SetInput("PosId", {pos_id});
new_op_desc.SetInput("MaskId", {mask_id});
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}

auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);

for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
} else {
VLOG(3) << "Shape check failed, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}

return fusion_count;
}

Expand Down
4 changes: 2 additions & 2 deletions test/ir/inference/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_merge_layernorm_fuse_pass PROPERTIES TIMEOUT 180)
set_tests_properties(test_skip_merge_layernorm_fuse_pass PROPERTIES TIMEOUT
180)
set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT
120)
set_tests_properties(test_trt_emb_eltwise_layernorm_fuse_pass
PROPERTIES TIMEOUT 180)

set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120)
Expand Down
Loading

0 comments on commit e49c17d

Please sign in to comment.