From 2ded0bf9e8b03a6c4d34da01d07cdd690a2d5508 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 30 Mar 2023 07:36:04 +0000 Subject: [PATCH 1/4] fix skiplayernorm, add trt_version check --- .../fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index 18ea8850dc5bf..a5c073d591bb2 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -105,6 +105,16 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("skip_layernorm_fuse", graph); + + auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion(); + if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 + + std::get<2>(trt_version) * 10 < + 7200) { + VLOG(3) << "skip_layernorm oss plugin only available for trt version >= " + "7.2 Stop this pass"; + return; + } + int found_subgraph_count = 0; GraphPatternDetector gpd; From 56c48a304f44c19b54680d8e91a8819e0d39f2a1 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 30 Mar 2023 08:24:52 +0000 Subject: [PATCH 2/4] fix --- paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index a5c073d591bb2..472b098c9d00a 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/inference/tensorrt/helper.h" namespace paddle { namespace framework { From 3137d7e0c040e51acb7d76b909a03b81130f99b6 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 30 Mar 2023 09:37:28 +0000 Subject: [PATCH 3/4] fix --- paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index 472b098c9d00a..a249d04d28abb 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -18,7 +18,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/op_version_registry.h" +#ifdef PADDLE_WITH_TENSORRT #include "paddle/fluid/inference/tensorrt/helper.h" +#endif namespace paddle { namespace framework { From 3628f91f1e12d2eefaf82b48e820d679f8305cba Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 30 Mar 2023 12:08:23 +0000 Subject: [PATCH 4/4] fix --- paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index a249d04d28abb..6186dd72688e7 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -109,6 +109,7 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("skip_layernorm_fuse", graph); +#ifdef PADDLE_WITH_TENSORRT auto trt_version = paddle::inference::tensorrt::GetTrtRuntimeVersion(); if (std::get<0>(trt_version) * 1000 + std::get<1>(trt_version) * 100 + std::get<2>(trt_version) * 10 < @@ -117,7 +118,10 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { "7.2 Stop this pass"; return; } - +#else + // if no tensorrt, early stop + return; +#endif int found_subgraph_count = 0; GraphPatternDetector gpd;