From bedbe599ff6042dc15dd16fd32e41505fdb2a59d Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 10 Oct 2023 16:29:30 +0800 Subject: [PATCH] pnnx support torch-2.1 (#5074) --- .ci/pnnx.yml | 4 +++ tools/pnnx/CMakeLists.txt | 5 ++++ tools/pnnx/src/main.cpp | 5 ++++ .../F_scaled_dot_product_attention.cpp | 27 +++++++++++++++++++ 4 files changed, 41 insertions(+) diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 267a0afa289b..3f116a4fa2e4 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -48,6 +48,10 @@ jobs: torchvision-version: 0.15.1 torchvision-cache-key: '0_15_1' + - torch-version: 2.1.0 + torchvision-version: 0.16.0 + torchvision-cache-key: '0_16_0' + runs-on: pool-name: docker container: diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index 3a08cbc249e8..0c8326fc942f 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -70,6 +70,11 @@ if(Torch_VERSION VERSION_LESS "1.8") message(FATAL_ERROR "pnnx only supports PyTorch >= 1.8") endif() +if(Torch_VERSION VERSION_GREATER_EQUAL "2.1") + # c++17 is required for using torch 2.1+ headers + set(CMAKE_CXX_STANDARD 17) +endif() + if(TorchVision_FOUND) message(STATUS "Building with TorchVision") add_definitions(-DPNNX_TORCHVISION) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index dc8ca72dc7e8..f745ef03473e 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -300,6 +300,11 @@ int main(int argc, char** argv) fprintf(stderr, "\n"); } +#ifdef PNNX_TORCHVISION + // call some vision api to register vision ops :P + (void)vision::cuda_version(); +#endif + for (auto m : customop_modules) { fprintf(stderr, "load custom module %s\n", m.c_str()); diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp index e7ca7bbf8243..8dcfafaf12b4 100644 --- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp +++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp @@ -42,4 +42,31 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10) +class F_scaled_dot_product_attention_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +pnnx.Input input_3 0 1 attn_mask +prim::Constant op_0 0 1 dropout_p value=%dropout_p +prim::Constant op_1 0 1 is_causal value=%is_causal +prim::Constant op_2 0 1 scale value=%scale +aten::scaled_dot_product_attention op_3 7 1 query key value attn_mask dropout_p is_causal scale out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.scaled_dot_product_attention"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) + } // namespace pnnx