From 1c40615b2dbeac41cffa9738d4888b822e1509ba Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 12 Jul 2024 14:36:10 +0800 Subject: [PATCH] pnnx convert onnx sdap reduce min/max/mean/sum/prod (#5579) * pnnx convert onnx sdap * test reduce --- .../F_scaled_dot_product_attention.cpp | 91 +++++++++++++++++++ tools/pnnx/src/pass_level2/torch_max.cpp | 80 +++++++++++++--- tools/pnnx/src/pass_level2/torch_min.cpp | 80 +++++++++++++--- tools/pnnx/src/pass_level2/torch_prod.cpp | 32 ++++--- tools/pnnx/tests/onnx/CMakeLists.txt | 8 +- .../test_F_scaled_dot_product_attention.py | 64 +++++++++++++ tools/pnnx/tests/onnx/test_torch_max.py | 62 +++++++++++++ tools/pnnx/tests/onnx/test_torch_mean.py | 60 ++++++++++++ tools/pnnx/tests/onnx/test_torch_min.py | 62 +++++++++++++ tools/pnnx/tests/onnx/test_torch_prod.py | 60 ++++++++++++ tools/pnnx/tests/onnx/test_torch_sum.py | 60 ++++++++++++ 11 files changed, 617 insertions(+), 42 deletions(-) create mode 100644 tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py create mode 100644 tools/pnnx/tests/onnx/test_torch_max.py create mode 100644 tools/pnnx/tests/onnx/test_torch_mean.py create mode 100644 tools/pnnx/tests/onnx/test_torch_min.py create mode 100644 tools/pnnx/tests/onnx/test_torch_prod.py create mode 100644 tools/pnnx/tests/onnx/test_torch_sum.py diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp index 36ca3c334f2..9fba1e770cc 100644 --- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp +++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp @@ -80,4 +80,95 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +class F_scaled_dot_product_attention_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +12 11 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +Transpose op_0 1 1 key kt perm=(0,1,3,2) +prim::Constant op_1 0 1 scale value=%sqrt_scale +aten::mul op_2 2 1 query scale q +prim::Constant op_3 0 1 scale2 value=%sqrt_scale +aten::mul op_4 2 1 kt scale2 k +MatMul op_5 2 1 q k qk +Softmax op_6 1 1 qk 4 axis=-1 +MatMul op_7 2 1 4 value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.scaled_dot_product_attention"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["dropout_p"] = 0.f; + op->params["is_causal"] = false; + + const float sqrt_scale = captured_params.at("sqrt_scale").f; + const float scale = sqrt_scale * sqrt_scale; + + op->params["scale"] = scale; + + if (!op->inputs[0]->shape.empty()) + { + const int embed_dim = op->inputs[0]->shape[op->inputs[0]->shape.size() - 1]; + if (NearlyEqual(scale, 1.f / sqrt(embed_dim), 0.001)) + { + // drop scale=None for compatibility with old torch + op->params.erase("scale"); + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_onnx, 10) + +class F_scaled_dot_product_attention_onnx_1 : public F_scaled_dot_product_attention_onnx +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +14 13 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +pnnx.Input input_3 0 1 attn_mask +Transpose op_0 1 1 key kt perm=(0,1,3,2) +prim::Constant op_1 0 1 scale value=%sqrt_scale +aten::mul op_2 2 1 query scale q +prim::Constant op_3 0 1 scale2 value=%sqrt_scale +aten::mul op_4 2 1 kt scale2 k +MatMul op_5 2 1 q k qk +aten::add op_6 2 1 qk attn_mask qkm +Softmax op_7 1 1 qkm 4 axis=-1 +MatMul op_8 2 1 4 value out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_onnx_1, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp index 68479b85d5b..b606fed066b 100644 --- a/tools/pnnx/src/pass_level2/torch_max.cpp +++ b/tools/pnnx/src/pass_level2/torch_max.cpp @@ -83,30 +83,80 @@ pnnx.Output output 1 0 out if (captured_params.find("op_0.axes") != captured_params.end()) { op->params["dim"] = captured_params.at("op_0.axes"); - } - else - { - // reduce all - const int input_rank = (int)op->inputs[0]->shape.size(); - std::vector dim(input_rank); - for (int i = 0; i < input_rank; i++) + + if (captured_params.find("op_0.keepdims") != captured_params.end()) { - dim[i] = i; + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; + } + else + { + op->params["keepdim"] = true; } - op->params["dim"] = dim; - } - - if (captured_params.find("op_0.keepdims") != captured_params.end()) - { - op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; } else { - op->params["keepdim"] = true; + // reduce all } } }; REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx, 20) +class torch_max_onnx_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ReduceMax op_0 1 1 input out %*=%* +ArgMax op_1 1 1 input indices %*=%* +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.max"; + } + + bool match(const std::map& captured_params) const + { + if (captured_params.find("op_0.axes") == captured_params.end()) + return false; + + if (captured_params.find("op_0.keepdims") == captured_params.end()) + return false; + + if (captured_params.find("op_1.axis") == captured_params.end()) + return false; + + if (captured_params.find("op_1.keepdims") == captured_params.end()) + return false; + + if (captured_params.at("op_0.axes").type != 5 || captured_params.at("op_0.axes").ai.size() != 1) + return false; + + if (captured_params.at("op_1.axis").type != 2) + return false; + + if (captured_params.at("op_0.axes").ai[0] != captured_params.at("op_1.axis").i) + return false; + + if (captured_params.at("op_0.keepdims").i != captured_params.at("op_1.keepdims").i) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["dim"] = captured_params.at("op_0.axes").ai[0]; + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx_1, 19) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp index c5e48bbc64b..35cc4988a19 100644 --- a/tools/pnnx/src/pass_level2/torch_min.cpp +++ b/tools/pnnx/src/pass_level2/torch_min.cpp @@ -83,30 +83,80 @@ pnnx.Output output 1 0 out if (captured_params.find("op_0.axes") != captured_params.end()) { op->params["dim"] = captured_params.at("op_0.axes"); - } - else - { - // reduce all - const int input_rank = (int)op->inputs[0]->shape.size(); - std::vector dim(input_rank); - for (int i = 0; i < input_rank; i++) + + if (captured_params.find("op_0.keepdims") != captured_params.end()) { - dim[i] = i; + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; + } + else + { + op->params["keepdim"] = true; } - op->params["dim"] = dim; - } - - if (captured_params.find("op_0.keepdims") != captured_params.end()) - { - op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; } else { - op->params["keepdim"] = true; + // reduce all } } }; REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx, 20) +class torch_min_onnx_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +ReduceMin op_0 1 1 input out %*=%* +ArgMin op_1 1 1 input indices %*=%* +pnnx.Output output 2 0 out indices +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.min"; + } + + bool match(const std::map& captured_params) const + { + if (captured_params.find("op_0.axes") == captured_params.end()) + return false; + + if (captured_params.find("op_0.keepdims") == captured_params.end()) + return false; + + if (captured_params.find("op_1.axis") == captured_params.end()) + return false; + + if (captured_params.find("op_1.keepdims") == captured_params.end()) + return false; + + if (captured_params.at("op_0.axes").type != 5 || captured_params.at("op_0.axes").ai.size() != 1) + return false; + + if (captured_params.at("op_1.axis").type != 2) + return false; + + if (captured_params.at("op_0.axes").ai[0] != captured_params.at("op_1.axis").i) + return false; + + if (captured_params.at("op_0.keepdims").i != captured_params.at("op_1.keepdims").i) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["dim"] = captured_params.at("op_0.axes").ai[0]; + op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx_1, 19) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_prod.cpp b/tools/pnnx/src/pass_level2/torch_prod.cpp index 7f15c2ba88a..51b614ec0dc 100644 --- a/tools/pnnx/src/pass_level2/torch_prod.cpp +++ b/tools/pnnx/src/pass_level2/torch_prod.cpp @@ -58,24 +58,34 @@ pnnx.Output output 1 0 out return "torch.prod"; } + bool match(const std::map& captured_params) const + { + if (captured_params.find("op_0.axes") == captured_params.end()) + return false; + + if (captured_params.at("op_0.axes").type != 2 && captured_params.at("op_0.axes").type != 5) + return false; + + if (captured_params.at("op_0.axes").type == 5 && captured_params.at("op_0.axes").ai.size() > 1) + return false; + + return true; + } + void write(Operator* op, const std::map& captured_params) const { - if (captured_params.find("op_0.axes") != captured_params.end()) + int dim; + if (captured_params.at("op_0.axes").type == 2) { - op->params["dim"] = captured_params.at("op_0.axes"); + dim = captured_params.at("op_0.axes").i; } - else + else // if (captured_params.at("op_0.axes").type == 5) { - // reduce all - const int input_rank = (int)op->inputs[0]->shape.size(); - std::vector dim(input_rank); - for (int i = 0; i < input_rank; i++) - { - dim[i] = i; - } - op->params["dim"] = dim; + dim = captured_params.at("op_0.axes").ai[0]; } + op->params["dim"] = dim; + if (captured_params.find("op_0.keepdims") != captured_params.end()) { op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false; diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 12d816cd8e2..0c0a136fbaf 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -36,7 +36,7 @@ pnnx_onnx_add_test(F_pad) pnnx_onnx_add_test(F_prelu) pnnx_onnx_add_test(F_relu) pnnx_onnx_add_test(F_relu6) -# pnnx_onnx_add_test(F_scaled_dot_product_attention) +pnnx_onnx_add_test(F_scaled_dot_product_attention) pnnx_onnx_add_test(F_sigmoid) pnnx_onnx_add_test(F_softmax) pnnx_onnx_add_test(F_upsample_bilinear) @@ -103,3 +103,9 @@ pnnx_onnx_add_test(shufflenet_v2_x1_0) pnnx_onnx_add_test(squeezenet1_1) pnnx_onnx_add_test(swin_t) pnnx_onnx_add_test(vit_b_32) + +pnnx_onnx_add_test(torch_max) +pnnx_onnx_add_test(torch_mean) +pnnx_onnx_add_test(torch_min) +pnnx_onnx_add_test(torch_prod) +pnnx_onnx_add_test(torch_sum) diff --git a/tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py b/tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py new file mode 100644 index 00000000000..2802b3794a7 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_scaled_dot_product_attention.py @@ -0,0 +1,64 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, q, k, v, m): + x = F.scaled_dot_product_attention(q, k, v) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=m) + return x, y + +def test(): + if version.parse(torch.__version__) < version.parse('2.1'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + q = torch.rand(3, 8, 128, 64) + k = torch.rand(3, 8, 48, 64) + v = torch.rand(3, 8, 48, 77) + m = torch.rand(3, 8, 128, 48) + + a = net(q, k, v, m) + + # export onnx + torch.onnx.export(net, (q, k, v, m), "test_F_scaled_dot_product_attention.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_scaled_dot_product_attention.onnx inputshape=[3,8,128,64],[3,8,48,64],[3,8,48,77],[3,8,128,48]") + + # pnnx inference + import test_F_scaled_dot_product_attention_pnnx + b = test_F_scaled_dot_product_attention_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_max.py b/tools/pnnx/tests/onnx/test_torch_max.py new file mode 100644 index 00000000000..0ab18bec47d --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_max.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x, x_indices = torch.max(x, dim=1, keepdim=False) + y = torch.max(y) + w = torch.max(z, w) + z, z_indices = torch.max(z, dim=0, keepdim=True) + return x, x_indices, y, z, z_indices, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + w = torch.rand(5, 9, 10) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_torch_max.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_max.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[5,9,10]") + + # pnnx inference + import test_torch_max_pnnx + b = test_torch_max_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_mean.py b/tools/pnnx/tests/onnx/test_torch_mean.py new file mode 100644 index 00000000000..cf599057579 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_mean.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.mean(x, dim=1, keepdim=False) + y = torch.mean(y, dim=(2,3), keepdim=False) + z = torch.mean(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export onnx + torch.onnx.export(net, (x, y, z), "test_torch_mean.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_mean.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_mean_pnnx + b = test_torch_mean_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_min.py b/tools/pnnx/tests/onnx/test_torch_min.py new file mode 100644 index 00000000000..e41584afe0c --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_min.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x, x_indices = torch.min(x, dim=1, keepdim=False) + y = torch.min(y) + w = torch.min(z, w) + z, z_indices = torch.min(z, dim=0, keepdim=True) + return x, x_indices, y, z, z_indices, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + w = torch.rand(5, 9, 10) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_torch_min.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_min.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10],[5,9,10]") + + # pnnx inference + import test_torch_min_pnnx + b = test_torch_min_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_prod.py b/tools/pnnx/tests/onnx/test_torch_prod.py new file mode 100644 index 00000000000..c36b97c2a31 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_prod.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.prod(x, dim=1, keepdim=False) + y = torch.prod(y, dim=2, keepdim=False) + z = torch.prod(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export onnx + torch.onnx.export(net, (x, y, z), "test_torch_prod.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_prod.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_prod_pnnx + b = test_torch_prod_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_torch_sum.py b/tools/pnnx/tests/onnx/test_torch_sum.py new file mode 100644 index 00000000000..3ae6412f09b --- /dev/null +++ b/tools/pnnx/tests/onnx/test_torch_sum.py @@ -0,0 +1,60 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.sum(x, dim=1, keepdim=False) + y = torch.sum(y, dim=(2,3), keepdim=False) + z = torch.sum(z, dim=0, keepdim=True) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export onnx + torch.onnx.export(net, (x, y, z), "test_torch_sum.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_torch_sum.onnx inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_sum_pnnx + b = test_torch_sum_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)