From e7cae68a227f7bab2e085a9e1f24437d6749ac23 Mon Sep 17 00:00:00 2001 From: nihui Date: Sat, 13 Jul 2024 23:56:29 +0800 Subject: [PATCH] pnnx convert onnx logsoftmax/logsigmoid/mish/selu/sigmoid/silu/softmin/softplus/softshrink/softsign/tanh/tanhshrink (#5581) --- tools/pnnx/src/pass_level2/F_log_softmax.cpp | 73 +++++++++++++++++++ tools/pnnx/src/pass_level2/F_logsigmoid.cpp | 22 ++++++ tools/pnnx/src/pass_level2/F_mish.cpp | 23 ++++++ tools/pnnx/src/pass_level2/F_selu.cpp | 21 ++++++ tools/pnnx/src/pass_level2/F_softmin.cpp | 22 ++++++ tools/pnnx/src/pass_level2/F_softplus.cpp | 58 +++++++++++++++ tools/pnnx/src/pass_level2/F_softshrink.cpp | 58 +++++++++++++++ tools/pnnx/src/pass_level2/F_softsign.cpp | 24 +++++++ tools/pnnx/src/pass_level2/F_tanhshrink.cpp | 22 ++++++ tools/pnnx/tests/onnx/CMakeLists.txt | 22 ++++++ tools/pnnx/tests/onnx/test_F_log_softmax.py | 66 +++++++++++++++++ tools/pnnx/tests/onnx/test_F_logsigmoid.py | 66 +++++++++++++++++ tools/pnnx/tests/onnx/test_F_mish.py | 76 ++++++++++++++++++++ tools/pnnx/tests/onnx/test_F_selu.py | 66 +++++++++++++++++ tools/pnnx/tests/onnx/test_F_sigmoid.py | 9 ++- tools/pnnx/tests/onnx/test_F_silu.py | 69 ++++++++++++++++++ tools/pnnx/tests/onnx/test_F_softmin.py | 66 +++++++++++++++++ tools/pnnx/tests/onnx/test_F_softplus.py | 70 ++++++++++++++++++ tools/pnnx/tests/onnx/test_F_softshrink.py | 70 ++++++++++++++++++ tools/pnnx/tests/onnx/test_F_softsign.py | 66 +++++++++++++++++ tools/pnnx/tests/onnx/test_F_tanh.py | 66 +++++++++++++++++ tools/pnnx/tests/onnx/test_F_tanhshrink.py | 66 +++++++++++++++++ tools/pnnx/tests/onnx/test_nn_LogSigmoid.py | 68 ++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_LogSoftmax.py | 71 ++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_Mish.py | 72 +++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_SELU.py | 68 ++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_SiLU.py | 68 ++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_Sigmoid.py | 9 ++- tools/pnnx/tests/onnx/test_nn_Softmin.py | 71 ++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_Softplus.py | 73 +++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_Softshrink.py | 73 +++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_Softsign.py | 68 ++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_Tanh.py | 68 ++++++++++++++++++ tools/pnnx/tests/onnx/test_nn_Tanhshrink.py | 68 ++++++++++++++++++ 34 files changed, 1872 insertions(+), 6 deletions(-) create mode 100644 tools/pnnx/tests/onnx/test_F_log_softmax.py create mode 100644 tools/pnnx/tests/onnx/test_F_logsigmoid.py create mode 100644 tools/pnnx/tests/onnx/test_F_mish.py create mode 100644 tools/pnnx/tests/onnx/test_F_selu.py create mode 100644 tools/pnnx/tests/onnx/test_F_silu.py create mode 100644 tools/pnnx/tests/onnx/test_F_softmin.py create mode 100644 tools/pnnx/tests/onnx/test_F_softplus.py create mode 100644 tools/pnnx/tests/onnx/test_F_softshrink.py create mode 100644 tools/pnnx/tests/onnx/test_F_softsign.py create mode 100644 tools/pnnx/tests/onnx/test_F_tanh.py create mode 100644 tools/pnnx/tests/onnx/test_F_tanhshrink.py create mode 100644 tools/pnnx/tests/onnx/test_nn_LogSigmoid.py create mode 100644 tools/pnnx/tests/onnx/test_nn_LogSoftmax.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Mish.py create mode 100644 tools/pnnx/tests/onnx/test_nn_SELU.py create mode 100644 tools/pnnx/tests/onnx/test_nn_SiLU.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Softmin.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Softplus.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Softshrink.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Softsign.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Tanh.py create mode 100644 tools/pnnx/tests/onnx/test_nn_Tanhshrink.py diff --git a/tools/pnnx/src/pass_level2/F_log_softmax.cpp b/tools/pnnx/src/pass_level2/F_log_softmax.cpp index 0264973783b..ad9eba30d1c 100644 --- a/tools/pnnx/src/pass_level2/F_log_softmax.cpp +++ b/tools/pnnx/src/pass_level2/F_log_softmax.cpp @@ -39,4 +39,77 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_log_softmax, 10) +class F_log_softmax_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +LogSoftmax op_0 1 1 input out axis=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.log_softmax"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_log_softmax_onnx, 10) + +class F_log_softmax_onnx_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +Transpose op_0 1 1 input a perm=%perm +LogSoftmax op_1 1 1 a b axis=%axis +Transpose op_2 1 1 b out perm=%perm +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.log_softmax"; + } + + bool match(const std::map& captured_params) const + { + const std::vector& perm = captured_params.at("perm").ai; + const int axis = captured_params.at("axis").i; + + if (axis >= (int)perm.size()) + return false; + + int excount = 0; + for (int i = 0; i < (int)perm.size(); i++) + { + if (perm[i] != i) + excount++; + } + + if (excount != 2) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::vector& perm = captured_params.at("perm").ai; + const int axis = captured_params.at("axis").i; + + op->params["dim"] = perm[axis]; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_log_softmax_onnx_1, 9) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_logsigmoid.cpp b/tools/pnnx/src/pass_level2/F_logsigmoid.cpp index e35670686a0..e0d4df607f2 100644 --- a/tools/pnnx/src/pass_level2/F_logsigmoid.cpp +++ b/tools/pnnx/src/pass_level2/F_logsigmoid.cpp @@ -37,4 +37,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_logsigmoid, 10) +class F_logsigmoid_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +aten::sigmoid op_0 1 1 input a +aten::log op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.logsigmoid"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_logsigmoid_onnx, 9) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_mish.cpp b/tools/pnnx/src/pass_level2/F_mish.cpp index 1a083ba85d9..485a7e3b0b5 100644 --- a/tools/pnnx/src/pass_level2/F_mish.cpp +++ b/tools/pnnx/src/pass_level2/F_mish.cpp @@ -62,4 +62,27 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_mish_1, 9) +class F_mish_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +Softplus op_0 1 1 input a +aten::tanh op_1 1 1 a b +aten::mul op_2 2 1 input b out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.mish"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_mish_onnx, 9) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_selu.cpp b/tools/pnnx/src/pass_level2/F_selu.cpp index 592c3dd8ed7..9df970b1bbc 100644 --- a/tools/pnnx/src/pass_level2/F_selu.cpp +++ b/tools/pnnx/src/pass_level2/F_selu.cpp @@ -37,4 +37,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_selu, 10) +class F_selu_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +Selu op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.selu"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_selu_onnx, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softmin.cpp b/tools/pnnx/src/pass_level2/F_softmin.cpp index bb0768663c5..89e5d9aeaf8 100644 --- a/tools/pnnx/src/pass_level2/F_softmin.cpp +++ b/tools/pnnx/src/pass_level2/F_softmin.cpp @@ -40,4 +40,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmin, 9) +class F_softmin_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +aten::neg op_0 1 1 input 6 +Softmax op_1 1 1 6 out axis=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softmin"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmin_onnx, 9) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softplus.cpp b/tools/pnnx/src/pass_level2/F_softplus.cpp index c6a5279b414..8d346eb76ed 100644 --- a/tools/pnnx/src/pass_level2/F_softplus.cpp +++ b/tools/pnnx/src/pass_level2/F_softplus.cpp @@ -39,4 +39,62 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softplus, 10) +class F_softplus_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input_0 0 1 input +Softplus op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softplus"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["beta"] = 1.f; + op->params["threshold"] = 20.f; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softplus_onnx, 10) + +class F_softplus_onnx_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 input +prim::Constant op_0 0 1 beta value=%beta +aten::mul op_1 2 1 input beta a +Softplus op_2 1 1 a b +prim::Constant op_3 0 1 beta2 value=%beta +aten::div op_4 2 1 b beta2 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softplus"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["beta"] = captured_params.at("beta"); + op->params["threshold"] = 20.f; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softplus_onnx_1, 9) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softshrink.cpp b/tools/pnnx/src/pass_level2/F_softshrink.cpp index 286990bf2c5..8d14a8a644b 100644 --- a/tools/pnnx/src/pass_level2/F_softshrink.cpp +++ b/tools/pnnx/src/pass_level2/F_softshrink.cpp @@ -38,4 +38,62 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softshrink, 10) +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +class F_softshrink_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +15 14 +pnnx.Input input 0 1 input +prim::Constant op_0 0 1 lambd value=%lambd +aten::gt op_1 2 1 input lambd 8 +prim::Constant op_2 0 1 lambd2 value=%lambd +aten::sub op_3 2 1 input lambd2 9 +prim::Constant op_4 0 1 zero value=0 +aten::where op_5 3 1 8 9 zero a +prim::Constant op_6 0 1 mlambd value=%lambd2 +aten::lt op_7 2 1 input mlambd 11 +prim::Constant op_8 0 1 lambd3 value=%lambd +aten::add op_9 2 1 input lambd3 12 +prim::Constant op_10 0 1 zero2 value=0 +aten::where op_11 3 1 11 12 zero2 b +aten::add op_12 2 1 a b out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softshrink"; + } + + bool match(const std::map& captured_params) const + { + float lambd = captured_params.at("lambd").f; + float lambd2 = captured_params.at("lambd2").f; + return NearlyEqual(lambd, -lambd2, 0.001); + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["lambd"] = captured_params.at("lambd"); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softshrink_onnx, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softsign.cpp b/tools/pnnx/src/pass_level2/F_softsign.cpp index 4ec8ae9e520..ae6005d6337 100644 --- a/tools/pnnx/src/pass_level2/F_softsign.cpp +++ b/tools/pnnx/src/pass_level2/F_softsign.cpp @@ -41,4 +41,28 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softsign, 10) +class F_softsign_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +aten::abs op_0 1 1 input 6 +prim::Constant op_1 0 1 8 value=1 +aten::add op_2 2 1 6 8 9 +aten::div op_3 2 1 input 9 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softsign"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softsign_onnx, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_tanhshrink.cpp b/tools/pnnx/src/pass_level2/F_tanhshrink.cpp index d8d6c311fcd..01e578bf8ad 100644 --- a/tools/pnnx/src/pass_level2/F_tanhshrink.cpp +++ b/tools/pnnx/src/pass_level2/F_tanhshrink.cpp @@ -39,4 +39,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_tanhshrink, 9) +class F_tanhshrink_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +aten::tanh op_0 1 1 input 7 +aten::sub op_1 2 1 input 7 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.tanhshrink"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_tanhshrink_onnx, 9) + } // namespace pnnx diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt index 0c0a136fbaf..0e283e77d48 100644 --- a/tools/pnnx/tests/onnx/CMakeLists.txt +++ b/tools/pnnx/tests/onnx/CMakeLists.txt @@ -29,16 +29,27 @@ pnnx_onnx_add_test(F_layer_norm) pnnx_onnx_add_test(F_leaky_relu) pnnx_onnx_add_test(F_linear) pnnx_onnx_add_test(F_local_response_norm) +pnnx_onnx_add_test(F_logsigmoid) +pnnx_onnx_add_test(F_log_softmax) pnnx_onnx_add_test(F_max_pool1d) pnnx_onnx_add_test(F_max_pool2d) pnnx_onnx_add_test(F_max_pool3d) +pnnx_onnx_add_test(F_mish) pnnx_onnx_add_test(F_pad) pnnx_onnx_add_test(F_prelu) pnnx_onnx_add_test(F_relu) pnnx_onnx_add_test(F_relu6) pnnx_onnx_add_test(F_scaled_dot_product_attention) +pnnx_onnx_add_test(F_selu) pnnx_onnx_add_test(F_sigmoid) +pnnx_onnx_add_test(F_silu) pnnx_onnx_add_test(F_softmax) +pnnx_onnx_add_test(F_softmin) +pnnx_onnx_add_test(F_softplus) +pnnx_onnx_add_test(F_softshrink) +pnnx_onnx_add_test(F_softsign) +pnnx_onnx_add_test(F_tanh) +pnnx_onnx_add_test(F_tanhshrink) pnnx_onnx_add_test(F_upsample_bilinear) pnnx_onnx_add_test(F_upsample_nearest) pnnx_onnx_add_test(F_upsample) @@ -74,10 +85,13 @@ pnnx_onnx_add_test(nn_LayerNorm) pnnx_onnx_add_test(nn_LeakyReLU) pnnx_onnx_add_test(nn_Linear) pnnx_onnx_add_test(nn_LocalResponseNorm) +pnnx_onnx_add_test(nn_LogSigmoid) +pnnx_onnx_add_test(nn_LogSoftmax) pnnx_onnx_add_test(nn_LSTM) pnnx_onnx_add_test(nn_MaxPool1d) pnnx_onnx_add_test(nn_MaxPool2d) pnnx_onnx_add_test(nn_MaxPool3d) +pnnx_onnx_add_test(nn_Mish) pnnx_onnx_add_test(nn_MultiheadAttention) pnnx_onnx_add_test(nn_PReLU) pnnx_onnx_add_test(nn_ReflectionPad1d) @@ -88,8 +102,16 @@ pnnx_onnx_add_test(nn_ReplicationPad1d) pnnx_onnx_add_test(nn_ReplicationPad2d) pnnx_onnx_add_test(nn_ReplicationPad3d) pnnx_onnx_add_test(nn_RNN) +pnnx_onnx_add_test(nn_SELU) pnnx_onnx_add_test(nn_Sigmoid) +pnnx_onnx_add_test(nn_SiLU) pnnx_onnx_add_test(nn_Softmax) +pnnx_onnx_add_test(nn_Softmin) +pnnx_onnx_add_test(nn_Softplus) +pnnx_onnx_add_test(nn_Softshrink) +pnnx_onnx_add_test(nn_Softsign) +pnnx_onnx_add_test(nn_Tanh) +pnnx_onnx_add_test(nn_Tanhshrink) pnnx_onnx_add_test(nn_Upsample) pnnx_onnx_add_test(nn_UpsamplingBilinear2d) pnnx_onnx_add_test(nn_UpsamplingNearest2d) diff --git a/tools/pnnx/tests/onnx/test_F_log_softmax.py b/tools/pnnx/tests/onnx/test_F_log_softmax.py new file mode 100644 index 00000000000..8bc657c6778 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_log_softmax.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.log_softmax(x, 1) + y = F.log_softmax(y, 0) + z = F.log_softmax(z, 2) + w = F.log_softmax(w, 3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_log_softmax.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_log_softmax.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_log_softmax_pnnx + b = test_F_log_softmax_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_logsigmoid.py b/tools/pnnx/tests/onnx/test_F_logsigmoid.py new file mode 100644 index 00000000000..a731936a109 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_logsigmoid.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.logsigmoid(x) + y = F.logsigmoid(y) + z = F.logsigmoid(z) + w = F.logsigmoid(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_logsigmoid.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_logsigmoid.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_logsigmoid_pnnx + b = test_F_logsigmoid_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_mish.py b/tools/pnnx/tests/onnx/test_F_mish.py new file mode 100644 index 00000000000..69026d38b2b --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_mish.py @@ -0,0 +1,76 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +def mish_forward_0(x): + return x * F.softplus(x).tanh() + +def mish_forward_1(x): + return x.mul(torch.tanh(F.softplus(x))) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.mish(x) + y = F.mish(y) + z = mish_forward_0(z) + w = mish_forward_1(w) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.9'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_mish.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_mish.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_mish_pnnx + b = test_F_mish_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_selu.py b/tools/pnnx/tests/onnx/test_F_selu.py new file mode 100644 index 00000000000..e70f9344191 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_selu.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.selu(x) + y = F.selu(y) + z = F.selu(z) + w = F.selu(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_selu.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_selu.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_selu_pnnx + b = test_F_selu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_sigmoid.py b/tools/pnnx/tests/onnx/test_F_sigmoid.py index 684a7ab48d9..c90e570e005 100644 --- a/tools/pnnx/tests/onnx/test_F_sigmoid.py +++ b/tools/pnnx/tests/onnx/test_F_sigmoid.py @@ -41,7 +41,7 @@ def test(): z = torch.rand(1, 3, 12, 16) w = torch.rand(1, 5, 7, 9, 11) - a0, a1, a2, a3 = net(x, y, z, w) + a = net(x, y, z, w) # export onnx torch.onnx.export(net, (x, y, z, w), "test_F_sigmoid.onnx") @@ -52,9 +52,12 @@ def test(): # pnnx inference import test_F_sigmoid_pnnx - b0, b1, b2, b3 = test_F_sigmoid_pnnx.test_inference() + b = test_F_sigmoid_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/onnx/test_F_silu.py b/tools/pnnx/tests/onnx/test_F_silu.py new file mode 100644 index 00000000000..d6cc987262e --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_silu.py @@ -0,0 +1,69 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def silu_forward_0(x): + return x * torch.sigmoid(x) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.silu(x) + y = F.silu(y) + z = F.silu(z) + w = silu_forward_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_silu.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_silu.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_silu_pnnx + b = test_F_silu_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_softmin.py b/tools/pnnx/tests/onnx/test_F_softmin.py new file mode 100644 index 00000000000..88a82fea00a --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_softmin.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.softmin(x, 1) + y = F.softmin(y, 0) + z = F.softmin(z, 2) + w = F.softmin(w, 3) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_softmin.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_softmin.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softmin_pnnx + b = test_F_softmin_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_softplus.py b/tools/pnnx/tests/onnx/test_F_softplus.py new file mode 100644 index 00000000000..c261f58d67c --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_softplus.py @@ -0,0 +1,70 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.softplus(x) + y = F.softplus(y, 2, 5.2) + z = F.softplus(z, -0.7, 15) + w = F.softplus(w, 0.1, 0.3) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.11'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_softplus.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_softplus.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softplus_pnnx + b = test_F_softplus_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_softshrink.py b/tools/pnnx/tests/onnx/test_F_softshrink.py new file mode 100644 index 00000000000..7f1fb883807 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_softshrink.py @@ -0,0 +1,70 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.softshrink(x) + y = F.softshrink(y, 0.1) + z = F.softshrink(z, 0.22) + w = F.softshrink(w, 0) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.11'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_softshrink.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_softshrink.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softshrink_pnnx + b = test_F_softshrink_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_softsign.py b/tools/pnnx/tests/onnx/test_F_softsign.py new file mode 100644 index 00000000000..27164f3dfc1 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_softsign.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.softsign(x) + y = F.softsign(y) + z = F.softsign(z) + w = F.softsign(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_softsign.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_softsign.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_softsign_pnnx + b = test_F_softsign_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_tanh.py b/tools/pnnx/tests/onnx/test_F_tanh.py new file mode 100644 index 00000000000..b56d513f655 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_tanh.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.tanh(x) + y = F.tanh(y) + z = F.tanh(z) + w = F.tanh(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_tanh.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_tanh.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_tanh_pnnx + b = test_F_tanh_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_F_tanhshrink.py b/tools/pnnx/tests/onnx/test_F_tanhshrink.py new file mode 100644 index 00000000000..7be2bf57cb1 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_F_tanhshrink.py @@ -0,0 +1,66 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = F.tanhshrink(x) + y = F.tanhshrink(y) + z = F.tanhshrink(z) + w = F.tanhshrink(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(12, 2, 16) + z = torch.rand(1, 3, 12, 16) + w = torch.rand(1, 5, 7, 9, 11) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_F_tanhshrink.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_F_tanhshrink.onnx inputshape=[1,16],[12,2,16],[1,3,12,16],[1,5,7,9,11]") + + # pnnx inference + import test_F_tanhshrink_pnnx + b = test_F_tanhshrink_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_LogSigmoid.py b/tools/pnnx/tests/onnx/test_nn_LogSigmoid.py new file mode 100644 index 00000000000..ddb44cbf442 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_LogSigmoid.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LogSigmoid() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_LogSigmoid.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_LogSigmoid.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_LogSigmoid_pnnx + b = test_nn_LogSigmoid_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_LogSoftmax.py b/tools/pnnx/tests/onnx/test_nn_LogSoftmax.py new file mode 100644 index 00000000000..dbe8dc96d82 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_LogSoftmax.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.LogSoftmax(dim=1) + self.act_1 = nn.LogSoftmax(dim=1) + self.act_2 = nn.LogSoftmax(dim=0) + self.act_3 = nn.LogSoftmax(dim=2) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_1(y) + z = self.act_2(z) + w = self.act_3(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_LogSoftmax.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_LogSoftmax.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_LogSoftmax_pnnx + b = test_nn_LogSoftmax_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Mish.py b/tools/pnnx/tests/onnx/test_nn_Mish.py new file mode 100644 index 00000000000..481ba718111 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Mish.py @@ -0,0 +1,72 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Mish() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.9'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Mish.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Mish.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Mish_pnnx + b = test_nn_Mish_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_SELU.py b/tools/pnnx/tests/onnx/test_nn_SELU.py new file mode 100644 index 00000000000..a78c9e2336f --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_SELU.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.SELU() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_SELU.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_SELU.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_SELU_pnnx + b = test_nn_SELU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_SiLU.py b/tools/pnnx/tests/onnx/test_nn_SiLU.py new file mode 100644 index 00000000000..e509ddb6754 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_SiLU.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.SiLU() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_SiLU.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_SiLU.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_SiLU_pnnx + b = test_nn_SiLU_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Sigmoid.py b/tools/pnnx/tests/onnx/test_nn_Sigmoid.py index 5b9cfc9a2be..72d5d798ef4 100644 --- a/tools/pnnx/tests/onnx/test_nn_Sigmoid.py +++ b/tools/pnnx/tests/onnx/test_nn_Sigmoid.py @@ -43,7 +43,7 @@ def test(): z = torch.rand(1, 12, 24, 64) w = torch.rand(1, 12, 24, 32, 64) - a0, a1, a2, a3 = net(x, y, z, w) + a = net(x, y, z, w) # export onnx torch.onnx.export(net, (x, y, z, w), "test_nn_Sigmoid.onnx") @@ -54,9 +54,12 @@ def test(): # pnnx inference import test_nn_Sigmoid_pnnx - b0, b1, b2, b3 = test_nn_Sigmoid_pnnx.test_inference() + b = test_nn_Sigmoid_pnnx.test_inference() - return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/onnx/test_nn_Softmin.py b/tools/pnnx/tests/onnx/test_nn_Softmin.py new file mode 100644 index 00000000000..9cb8417f2f6 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Softmin.py @@ -0,0 +1,71 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softmin(dim=1) + self.act_1 = nn.Softmin(dim=1) + self.act_2 = nn.Softmin(dim=0) + self.act_3 = nn.Softmin(dim=2) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_1(y) + z = self.act_2(z) + w = self.act_3(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Softmin.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Softmin.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softmin_pnnx + b = test_nn_Softmin_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Softplus.py b/tools/pnnx/tests/onnx/test_nn_Softplus.py new file mode 100644 index 00000000000..445c6341b29 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Softplus.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softplus() + self.act_1 = nn.Softplus(beta=0.7, threshold=15) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.11'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Softplus.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Softplus.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softplus_pnnx + b = test_nn_Softplus_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Softshrink.py b/tools/pnnx/tests/onnx/test_nn_Softshrink.py new file mode 100644 index 00000000000..b86e9239c16 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Softshrink.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softshrink() + self.act_1 = nn.Softshrink(lambd=1.3) + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_1(z) + w = self.act_1(w) + return x, y, z, w + +def test(): + if version.parse(torch.__version__) < version.parse('1.11'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Softshrink.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Softshrink.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softshrink_pnnx + b = test_nn_Softshrink_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Softsign.py b/tools/pnnx/tests/onnx/test_nn_Softsign.py new file mode 100644 index 00000000000..da86752ca67 --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Softsign.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Softsign() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Softsign.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Softsign.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Softsign_pnnx + b = test_nn_Softsign_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Tanh.py b/tools/pnnx/tests/onnx/test_nn_Tanh.py new file mode 100644 index 00000000000..083275d277f --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Tanh.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Tanh() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Tanh.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Tanh.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Tanh_pnnx + b = test_nn_Tanh_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/onnx/test_nn_Tanhshrink.py b/tools/pnnx/tests/onnx/test_nn_Tanhshrink.py new file mode 100644 index 00000000000..20cabe2559a --- /dev/null +++ b/tools/pnnx/tests/onnx/test_nn_Tanhshrink.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.act_0 = nn.Tanhshrink() + + def forward(self, x, y, z, w): + x = x * 2 - 1 + y = y * 2 - 1 + z = z * 2 - 1 + w = w * 2 - 1 + x = self.act_0(x) + y = self.act_0(y) + z = self.act_0(z) + w = self.act_0(w) + return x, y, z, w + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12) + y = torch.rand(1, 12, 64) + z = torch.rand(1, 12, 24, 64) + w = torch.rand(1, 12, 24, 32, 64) + + a = net(x, y, z, w) + + # export onnx + torch.onnx.export(net, (x, y, z, w), "test_nn_Tanhshrink.onnx") + + # onnx to pnnx + import os + os.system("../../src/pnnx test_nn_Tanhshrink.onnx inputshape=[1,12],[1,12,64],[1,12,24,64],[1,12,24,32,64]") + + # pnnx inference + import test_nn_Tanhshrink_pnnx + b = test_nn_Tanhshrink_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)