[Paddle-Inference] rebuild matmul pass: trt and gpu_cpu #39369

Merged: 4 commits, Feb 9, 2022
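This PR splits the matmul-mapping IR passes into two families: the gpu_cpu_* passes, which are always built for GPU/CPU inference, and trt_map_matmul_to_mul_pass, which is built only when Paddle is configured with WITH_TENSORRT. The pass logic itself is unchanged; classes, name_scope strings, log messages, and registrations are renamed to carry the gpu_cpu_ prefix. Anything that referred to the old pass names must switch to the new ones, for example when editing the pass list through the inference API. A minimal sketch, assuming the paddle_infer C++ API of this Paddle generation ("./model_dir" is a placeholder path):

#include "paddle_inference_api.h"  // shipped with the Paddle inference library

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder, uncombined model dir
  // After this PR the GPU/CPU variant is registered as
  // "gpu_cpu_map_matmul_to_mul_pass"; the old name
  // "map_matmul_to_mul_pass" no longer exists.
  config.pass_builder()->DeletePass("gpu_cpu_map_matmul_to_mul_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}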
7 changes: 6 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -62,7 +62,6 @@ pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference)
-pass_library(map_matmul_to_mul_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
@@ -98,8 +97,14 @@ pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
+pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)

+if(WITH_TENSORRT)
+pass_library(trt_map_matmul_to_mul_pass inference)
+endif()
+
if(WITH_GPU OR WITH_ROCM)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
pass_library(embedding_eltwise_layernorm_fuse_pass inference)
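The CMake change above is what realizes the split: the unconditional pass_library(map_matmul_to_mul_pass inference) entry is removed, the gpu_cpu_ variant is registered unconditionally, and the trt_ variant exists only in builds configured with WITH_TENSORRT. Asking for a trt_ pass in a non-TRT build would therefore hit an unregistered pass name. A minimal sketch of guarding against that at runtime, assuming the framework::ir::PassRegistry interface behind the REGISTER_PASS macros used later in this diff (treat the exact signatures as approximate):

#include <string>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Apply a named IR pass only if this build registered it; for example,
// "trt_map_matmul_to_mul_pass" is registered only when WITH_TENSORRT=ON.
void ApplyIfRegistered(const std::string& pass_name,
                       paddle::framework::ir::Graph* graph) {
  auto& registry = paddle::framework::ir::PassRegistry::Instance();
  if (registry.Has(pass_name)) {
    registry.Get(pass_name)->Apply(graph);
  }
}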
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc → paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc (file renamed)
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h"
#include "paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h"

#include <cmath>
#include <string>
@@ -28,7 +28,7 @@ namespace ir {

class Node;

-MapMatmul2MulPass::MapMatmul2MulPass() {
+GpuCpuMapMatmul2MulPass::GpuCpuMapMatmul2MulPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
@@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
.End();
}

-MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
+GpuCpuMapMatmulV2ToMulPass::GpuCpuMapMatmulV2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
@@ -104,7 +104,7 @@ MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
.End();
}

-MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
+GpuCpuMapMatmulV2ToMatmulPass::GpuCpuMapMatmulV2ToMatmulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
@@ -143,7 +143,7 @@ MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
.End();
}

-Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
+GpuCpuFlatten2MatmulFusePass::GpuCpuFlatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
@@ -197,7 +197,7 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
.End();
}

-Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
+GpuCpuSqueeze2MatmulFusePass::GpuCpuSqueeze2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
@@ -251,10 +251,10 @@ Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
.End();
}

-void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_to_mul_pass";
std::string name_scope = "gpu_cpu_map_matmul_to_mul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -264,7 +264,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul to mul";
VLOG(4) << "gpu_cpu map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
@@ -286,7 +286,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmul2MulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmul2MulPass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
@@ -311,7 +311,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmul2MulPass in out mul op compat failed.";
return;
}
}
@@ -321,10 +321,10 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
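The handler above rewrites a matched matmul into mul only when the shape checks collapsed in this view (the flag computation) pass and both OpCompat gates agree. Roughly, the requirements are a 2-D weight, no transposes, and alpha equal to 1, in which case matmul and mul compute the same product. An illustrative restatement of those conditions (not the pass's literal code, which sits in the collapsed lines):

#include <cstdint>
#include <vector>

// matmul(X, Y) can be mapped to mul(X, Y) when Y is a 2-D weight, neither
// input is transposed, and alpha == 1: both ops then compute X * Y, with
// mul flattening X's leading dims via x_num_col_dims.
bool RoughlyMappable(const std::vector<int64_t>& x_dims,
                     const std::vector<int64_t>& y_dims,
                     bool transpose_x, bool transpose_y, float alpha) {
  return x_dims.size() >= 2 && y_dims.size() == 2 &&
         !transpose_x && !transpose_y && alpha == 1.0f;
}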

-void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_mul_pass";
std::string name_scope = "gpu_cpu_map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -335,7 +335,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "map matmul_v2 to mul";
VLOG(3) << "gpu_cpu map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
@@ -360,7 +360,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed.";
return;
}
OpDesc desc(matmul_v2_op->Op()->Block());
@@ -386,7 +386,8 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuMapMatmulV2ToMulPass in out mul op compat failed.";
return;
}
}
@@ -396,10 +397,10 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

-void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "map_matmul_v2_to_matmul_pass";
std::string name_scope = "gpu_cpu_map_matmul_v2_to_matmul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -409,15 +410,15 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "map matmul_v2 to matmul";
VLOG(4) << "gpu_cpu map matmul_v2 to matmul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern);
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed.";
LOG(WARNING) << "GpuCpuMapMatmulV2ToMatmulPass in op compat failed.";
return;
}

@@ -463,7 +464,8 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed.";
LOG(WARNING)
<< "GpuCpuMapMatmulV2ToMatmulPass in out matmul op compat failed.";
return;
}
};
@@ -472,10 +474,10 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

-void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "squeeze2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_squeeze2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -485,7 +487,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse squeeze2+matmul to mul";
VLOG(4) << "gpu_cpu fuse squeeze2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -518,7 +520,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuSqueeze2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
@@ -542,7 +544,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuSqueeze2MatmulFusePass in out mul op compat failed.";
return;
}
}
@@ -552,7 +555,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
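Dropping squeeze2 is sound because of mul's flattening semantics: mul collapses X to 2-D around x_num_col_dims, so a 4-D input with trailing singleton dims multiplies identically whether or not it was squeezed first. A self-contained illustration of that flattening (the shapes are made up; the flattening rule is mul's documented behavior):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Computes the 2-D shape mul sees: dims [0, x_num_col_dims) become rows,
// the rest become columns. FlattenForMul({64, 256, 1, 1}, 1) == {64, 256},
// exactly what squeeze2(axes={2, 3}) would have produced.
std::vector<int64_t> FlattenForMul(const std::vector<int64_t>& x_dims,
                                   int x_num_col_dims) {
  auto prod = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1},
                           std::multiplies<int64_t>());
  };
  return {prod(x_dims.begin(), x_dims.begin() + x_num_col_dims),
          prod(x_dims.begin() + x_num_col_dims, x_dims.end())};
}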

-Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
+GpuCpuReshape2MatmulFusePass::GpuCpuReshape2MatmulFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
@@ -614,10 +617,10 @@ Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
.End();
}

-void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "reshape2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_reshape2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -627,7 +630,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse reshape2+matmul to mul";
VLOG(4) << "gpu_cpu fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -662,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuReshape2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
Expand All @@ -680,7 +683,8 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
matmul_op->Op()->GetAttr("out_threshold"));
}
if (!IsCompat(desc)) {
LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuReshape2MatmulFusePass in out mul op compat failed.";
return;
}
auto mul_node = g->CreateOpNode(&desc);
@@ -696,10 +700,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

-void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "flatten2_matmul_fuse_pass";
std::string name_scope = "gpu_cpu_flatten2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -709,7 +713,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse flatten2+matmul to mul";
VLOG(4) << "gpu_cpu fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -749,7 +753,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {

if (pattern_found) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed.";
LOG(WARNING) << "GpuCpuFlatten2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
Expand All @@ -774,7 +778,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
LOG(WARNING)
<< "GpuCpuFlatten2MatmulFusePass in out mul op compat failed.";
return;
}
}
@@ -788,50 +793,51 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
} // namespace framework
} // namespace paddle

-REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_to_mul_pass,
+paddle::framework::ir::GpuCpuMapMatmul2MulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("mul", 0));

-REGISTER_PASS(map_matmul_v2_to_mul_pass,
-paddle::framework::ir::MapMatmulV2ToMulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_v2_to_mul_pass,
+paddle::framework::ir::GpuCpuMapMatmulV2ToMulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("mul", 0));

-REGISTER_PASS(map_matmul_v2_to_matmul_pass,
-paddle::framework::ir::MapMatmulV2ToMatmulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_v2_to_matmul_pass,
+paddle::framework::ir::GpuCpuMapMatmulV2ToMatmulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_matmul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.LE("matmul", 1));

-REGISTER_PASS(squeeze2_matmul_fuse_pass,
-paddle::framework::ir::Squeeze2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_squeeze2_matmul_fuse_pass,
+paddle::framework::ir::GpuCpuSqueeze2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_squeeze2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("squeeze2", 0)
.EQ("mul", 0));

-REGISTER_PASS(reshape2_matmul_fuse_pass,
-paddle::framework::ir::Reshape2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_reshape2_matmul_fuse_pass,
+paddle::framework::ir::GpuCpuReshape2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_reshape2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("reshape2", 0)
.EQ("mul", 0));

-REGISTER_PASS(flatten2_matmul_fuse_pass,
-paddle::framework::ir::Flatten2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_flatten2_matmul_fuse_pass,
+paddle::framework::ir::GpuCpuFlatten2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_flatten2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
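For contributors adding a pass of their own, the tail of this file shows the registration idiom: REGISTER_PASS binds the pass's string name to its class, and REGISTER_PASS_CAPABILITY records the op versions the pass was validated against. A hypothetical example modeled on the registrations above (my_fuse_pass and MyFusePass are illustrative names, not part of this PR):

REGISTER_PASS(my_fuse_pass, paddle::framework::ir::MyFusePass);
REGISTER_PASS_CAPABILITY(my_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("matmul", 1)
            .EQ("mul", 0));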