Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle Inference]Enhance the shape check of trt_embedding_eltwise_layernorm_fuse_pass,… #54861

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 92 additions & 55 deletions paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,44 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(

std::vector<std::string> ids;
std::vector<std::string> embs;

auto ids0_shape = start_pattern_in_nodes[i][0].first->Var()->GetShape();
bool flag = true;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
auto ids_shape = start_pattern_in_nodes[i][iter].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
auto ids_shape = inner_pattern_ins[js[iter]].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
Expand All @@ -322,66 +355,70 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
"inputs with lookup_table_v2";
return fusion_count;
}
if (flag) {
OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
new_op_desc.SetInput("PosId", {ids[1]});
if (ids.size() > 2) {
new_op_desc.SetInput("SentId", {ids[2]});
}

OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
new_op_desc.SetInput("PosId", {ids[1]});
if (ids.size() > 2) {
new_op_desc.SetInput("SentId", {ids[2]});
}

new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
if (embs.size() > 2) {
new_op_desc.SetInput("SentEmbedding", {embs[2]});
}
new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
if (embs.size() > 2) {
new_op_desc.SetInput("SentEmbedding", {embs[2]});
}

new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}

auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);

for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
} else {
VLOG(3) << "Shape check failed, stop "
"embedding_eltwise_layernorm_fuse_pass.";
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}

return fusion_count;
Expand Down
135 changes: 86 additions & 49 deletions paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,68 +311,105 @@ int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(

std::vector<std::string> ids;
std::vector<std::string> embs;

auto ids0_shape = start_pattern_in_nodes[i][0].first->Var()->GetShape();
bool flag = true;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
auto ids_shape = start_pattern_in_nodes[i][iter].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
auto ids_shape = inner_pattern_ins[js[iter]].first->Var()->GetShape();
if (ids_shape.size() != ids0_shape.size()) {
VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
} else {
for (size_t j = 0; j < ids_shape.size(); ++j) {
if (ids_shape[j] != ids0_shape[j]) {
VLOG(3)
<< "Shape check failed, ids.shape[i] are not all equal, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
flag = false;
}
}
}
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}

OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
if (use_varseqlen && pos_id != "" && mask_id != "") {
new_op_desc.SetInput("PosId", {pos_id});
new_op_desc.SetInput("MaskId", {mask_id});
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}
if (flag) {
OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
if (use_varseqlen && pos_id != "" && mask_id != "") {
new_op_desc.SetInput("PosId", {pos_id});
new_op_desc.SetInput("MaskId", {mask_id});
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));

if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}

auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);

for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);

// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
} else {
VLOG(3) << "Shape check failed, stop "
"trt_embedding_eltwise_layernorm_fuse_pass.";
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}

return fusion_count;
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2567,7 +2567,7 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
}
if (op_type == "lookup_table") {
if (op_type == "lookup_table" || op_type == "lookup_table_v2") {
if (!with_dynamic_shape) {
VLOG(3) << "the lookup_table does not support "
"static shape yet";
Expand Down
4 changes: 2 additions & 2 deletions test/ir/inference/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_merge_layernorm_fuse_pass PROPERTIES TIMEOUT 180)
set_tests_properties(test_skip_merge_layernorm_fuse_pass PROPERTIES TIMEOUT
180)
set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT
120)
set_tests_properties(test_trt_emb_eltwise_layernorm_fuse_pass
PROPERTIES TIMEOUT 180)

set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120)
Expand Down
Loading