diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 2c814bd486cd..7743a8ae453e 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -369,6 +369,7 @@ set(pnnx_pass_level5_SRCS pass_level5/fuse_pixel_unshuffle.cpp pass_level5/fuse_layernorm.cpp pass_level5/fuse_multiheadattention.cpp + pass_level5/fuse_rmsnorm.cpp pass_level5/fuse_scaled_dot_product_attention.cpp pass_level5/fuse_select_to_unbind.cpp pass_level5/fuse_silu.cpp diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index 8bb3270aa2c3..5f08b80f5ef9 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -44,6 +44,7 @@ #include "pass_level5/fuse_multiheadattention.h" #include "pass_level5/fuse_pad_conv1d.h" #include "pass_level5/fuse_pad_conv2d.h" +#include "pass_level5/fuse_rmsnorm.h" #include "pass_level5/fuse_scaled_dot_product_attention.h" #include "pass_level5/fuse_select_to_unbind.h" #include "pass_level5/fuse_silu.h" @@ -145,6 +146,7 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_channel_shuffle(g); fuse_layernorm(g); + fuse_rmsnorm(g); fuse_multiheadattention(g); fuse_scaled_dot_product_attention(g); diff --git a/tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp b/tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp new file mode 100644 index 000000000000..7b99770ed6ed --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp @@ -0,0 +1,97 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_rmsnorm.h" + +#include "pass_level2.h" + +#include +#include + +namespace pnnx { + +class fuse_rmsnorm_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32 +pnnx.Expression op_1 1 1 input sq expr=pow(@0,2) +torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True +pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,rsqrt(add(@2,%eps)))) +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +nn.RMSNorm rmsnorm 1 1 input out elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_rmsnorm_pass_1 : public fuse_rmsnorm_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32 +pnnx.Expression op_1 1 1 input sq expr=pow(@0,2.000000e+00) +torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True +pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,reciprocal(sqrt(add(@2,%eps))))) +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +class fuse_rmsnorm_pass_onnx : public fuse_rmsnorm_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32 +pnnx.Expression op_1 1 1 input sq expr=pow(@0,2.000000e+00) +torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True +pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,div(1.000000e+00,sqrt(add(@2,%eps))))) +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +void fuse_rmsnorm(Graph& graph) +{ + fuse_rmsnorm_pass a; + fuse_rmsnorm_pass_1 a1; + fuse_rmsnorm_pass_onnx b; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &a1, opindex); + pnnx_graph_rewrite(graph, &b, opindex); +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/fuse_rmsnorm.h b/tools/pnnx/src/pass_level5/fuse_rmsnorm.h new file mode 100644 index 000000000000..0ba18e37f61b --- /dev/null +++ b/tools/pnnx/src/pass_level5/fuse_rmsnorm.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void fuse_rmsnorm(Graph& graph); + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index daf5501e9d8b..0dd566c37b58 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -346,6 +346,7 @@ pnnx_add_test(pnnx_fuse_input_unpack) pnnx_add_test(pnnx_fuse_layernorm) pnnx_add_test(pnnx_fuse_linear_batchnorm1d) pnnx_add_test(pnnx_fuse_multiheadattention) +pnnx_add_test(pnnx_fuse_rmsnorm) pnnx_add_test(pnnx_fuse_scaled_dot_product_attention) pnnx_add_test(pnnx_fuse_select_to_unbind) pnnx_add_test(pnnx_fuse_slice_to_tensor_split) diff --git a/tools/pnnx/tests/ncnn/test_F_rms_norm.py b/tools/pnnx/tests/ncnn/test_F_rms_norm.py index 4e60d9314aae..f30f72f9ac45 100644 --- a/tools/pnnx/tests/ncnn/test_F_rms_norm.py +++ b/tools/pnnx/tests/ncnn/test_F_rms_norm.py @@ -57,7 +57,7 @@ def test(): b = test_F_rms_norm_ncnn.test_inference() for a0, b0 in zip(a, b): - if not torch.allclose(a0, b0, 1e-4, 1e-4): + if not torch.allclose(a0, b0, 1e-3, 1e-3): return False return True diff --git a/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py index 0d5efa211e4d..e69ad1220bc1 100644 --- a/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py +++ b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py @@ -57,7 +57,7 @@ def test(): b = test_nn_RMSNorm_ncnn.test_inference() for a0, b0 in zip(a, b): - if not torch.allclose(a0, b0, 1e-4, 1e-4): + if not torch.allclose(a0, b0, 1e-3, 1e-3): return False return True diff --git a/tools/pnnx/tests/test_pnnx_fuse_rmsnorm.py b/tools/pnnx/tests/test_pnnx_fuse_rmsnorm.py new file mode 100644 index 000000000000..b04fa93442fa --- /dev/null +++ b/tools/pnnx/tests/test_pnnx_fuse_rmsnorm.py @@ -0,0 +1,77 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.rand(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * x + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.rmsn_0 = T5LayerNorm(26) + self.rmsn_1 = T5LayerNorm(21) + + def forward(self, x, y): + x = self.rmsn_0(x) + y = self.rmsn_1(y) + return x, y + +def test(): + if version.parse(torch.__version__) < version.parse('2.4'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 64, 26) + y = torch.rand(3, 15, 15, 21) + + a0, a1 = net(x, y) + + # export onnx + torch.onnx.export(net, (x,y), "test.onnx") + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_pnnx_fuse_rmsnorm.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_pnnx_fuse_rmsnorm.pt inputshape=[1,64,26],[3,15,15,21]") + + # pnnx inference + import test_pnnx_fuse_rmsnorm_pnnx + b0, b1 = test_pnnx_fuse_rmsnorm_pnnx.test_inference() + + return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)