diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 113b7a962ec7..c15ba0973d67 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -154,6 +154,7 @@ set(pnnx_pass_level2_SRCS pass_level2/F_mish.cpp pass_level2/F_normalize.cpp pass_level2/F_pad.cpp + pass_level2/F_pairwise_distance.cpp pass_level2/F_pixel_shuffle.cpp pass_level2/F_pixel_unshuffle.cpp pass_level2/F_prelu.cpp diff --git a/tools/pnnx/src/pass_level2/F_pairwise_distance.cpp b/tools/pnnx/src/pass_level2/F_pairwise_distance.cpp new file mode 100644 index 000000000000..8177b25d52f1 --- /dev/null +++ b/tools/pnnx/src/pass_level2/F_pairwise_distance.cpp @@ -0,0 +1,44 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class F_pairwise_distance : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input_0 0 1 x1 +pnnx.Input input_1 0 1 x2 +prim::Constant op_0 0 1 p value=%p +prim::Constant op_1 0 1 eps value=%eps +prim::Constant op_2 0 1 keepdim value=%keepdim +aten::pairwise_distance op_3 5 1 x1 x2 p eps keepdim out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pairwise_distance"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pairwise_distance, 10) + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index e8cb3972b0c7..fab12342b58f 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -54,6 +54,7 @@ pnnx_add_test(F_max_pool2d) pnnx_add_test(F_max_pool3d) pnnx_add_test(F_normalize) pnnx_add_test(F_pad) +pnnx_add_test(F_pairwise_distance) pnnx_add_test(F_pixel_shuffle) pnnx_add_test(F_pixel_unshuffle) pnnx_add_test(F_prelu) diff --git a/tools/pnnx/tests/test_F_pairwise_distance.py b/tools/pnnx/tests/test_F_pairwise_distance.py new file mode 100644 index 000000000000..243f61e1b0e0 --- /dev/null +++ b/tools/pnnx/tests/test_F_pairwise_distance.py @@ -0,0 +1,58 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + z1 = F.pairwise_distance(x,y,p=1,keepdim=False) + z2 = F.pairwise_distance(x,y,p=2,keepdim=True) + z3 = F.pairwise_distance(x,y) + z4 = F.pairwise_distance(x,y,eps = 1e-3) + return z1,z2,z3,z4 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(12, 128, 128) + y = torch.rand(12, 128, 128) + + a0,a1,a2,a3 = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_F_pairwise_distance.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_F_pairwise_distance.pt inputshape=[12,128,128],[12,128,128]") + + # pnnx inference + import test_F_pairwise_distance_pnnx + b0,b1,b2,b3 = test_F_pairwise_distance_pnnx.test_inference() + + return torch.equal(a0,b0) and torch.equal(a1,b1) and torch.equal(a2,b2) and torch.equal(a3,b3) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)