From 22b3c2569b0f8057ab0a3f7efd873b2af081da1f Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 15 Aug 2024 15:05:00 +0800 Subject: [PATCH] fix --- tools/pnnx/src/pass_level1/nn_RMSNorm.cpp | 2 +- tools/pnnx/src/pass_ncnn/F_rms_norm.cpp | 4 +++- tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp | 4 +++- tools/pnnx/tests/ncnn/test_F_layer_norm.py | 6 +++--- tools/pnnx/tests/ncnn/test_F_rms_norm.py | 8 +++++--- tools/pnnx/tests/ncnn/test_nn_LayerNorm.py | 6 +++--- tools/pnnx/tests/ncnn/test_nn_RMSNorm.py | 8 +++----- 7 files changed, 21 insertions(+), 17 deletions(-) diff --git a/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp b/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp index 4433f5989353..498f0453c14f 100644 --- a/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp +++ b/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp @@ -37,7 +37,7 @@ class RMSNorm : public FuseModulePass op->params["normalized_shape"] = rmsn->namedInput("normalized_shape"); op->params["eps"] = rmsn->namedInput("eps"); - op->params["elementwise_affine"] = mod.hasattr("weight") && mod.hasattr("bias"); + op->params["elementwise_affine"] = mod.hasattr("weight"); if (mod.hasattr("weight")) { diff --git a/tools/pnnx/src/pass_ncnn/F_rms_norm.cpp b/tools/pnnx/src/pass_ncnn/F_rms_norm.cpp index 14f592ff59b7..8230168312c2 100644 --- a/tools/pnnx/src/pass_ncnn/F_rms_norm.cpp +++ b/tools/pnnx/src/pass_ncnn/F_rms_norm.cpp @@ -50,8 +50,10 @@ pnnx.Output output 1 0 out affine_size *= normalized_shape[i]; } + const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f; + op->params["0"] = affine_size; - op->params["1"] = captured_params.at("eps"); + op->params["1"] = eps; op->params["2"] = 0; } }; diff --git a/tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp b/tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp index ef7552b8b6b9..7fda637c5cac 100644 --- a/tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp +++ b/tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp @@ -50,8 +50,10 @@ pnnx.Output output 1 0 out affine_size *= normalized_shape[i]; } + const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f; + op->params["0"] = affine_size; - op->params["1"] = captured_params.at("eps"); + op->params["1"] = eps; op->params["2"] = captured_params.at("elementwise_affine").b ? 1 : 0; if (captured_params.at("elementwise_affine").b) diff --git a/tools/pnnx/tests/ncnn/test_F_layer_norm.py b/tools/pnnx/tests/ncnn/test_F_layer_norm.py index 92244f179104..9d590aa76dda 100644 --- a/tools/pnnx/tests/ncnn/test_F_layer_norm.py +++ b/tools/pnnx/tests/ncnn/test_F_layer_norm.py @@ -37,8 +37,8 @@ def test(): net.eval() torch.manual_seed(0) - x = torch.rand(12, 24) - y = torch.rand(3, 12, 16) + x = torch.rand(1, 12, 24) + y = torch.rand(1, 3, 12, 16) a = net(x, y) @@ -48,7 +48,7 @@ def test(): # torchscript to pnnx import os - os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[12,24],[3,12,16]") + os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[1,12,24],[1,3,12,16]") # ncnn inference import test_F_layer_norm_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_rms_norm.py b/tools/pnnx/tests/ncnn/test_F_rms_norm.py index 4e737dfa34aa..637390e32836 100644 --- a/tools/pnnx/tests/ncnn/test_F_rms_norm.py +++ b/tools/pnnx/tests/ncnn/test_F_rms_norm.py @@ -39,8 +39,8 @@ def test(): net.eval() torch.manual_seed(0) - x = torch.rand(12, 24) - y = torch.rand(3, 12, 16) + x = torch.rand(1, 12, 24) + y = torch.rand(1, 3, 12, 16) a = net(x, y) @@ -50,7 +50,7 @@ def test(): # torchscript to pnnx import os - os.system("../../src/pnnx test_F_rms_norm.pt inputshape=[12,24],[3,12,16]") + os.system("../../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[1,3,12,16]") # ncnn inference import test_F_rms_norm_ncnn @@ -58,6 +58,8 @@ def test(): for a0, b0 in zip(a, b): if not torch.allclose(a0, b0, 1e-4, 1e-4): + print(a0) + print(b0) return False return True diff --git a/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py index a45444060d04..d409bdfba3a1 100644 --- a/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py +++ b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py @@ -36,8 +36,8 @@ def test(): net.eval() torch.manual_seed(0) - x = torch.rand(24, 64) - y = torch.rand(12, 24, 64) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) a = net(x, y) @@ -47,7 +47,7 @@ def test(): # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[24,64],[12,24,64]") + os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[1,24,64],[1,12,24,64]") # ncnn inference import test_nn_LayerNorm_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py index 235c06011991..0d5efa211e4d 100644 --- a/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py +++ b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py @@ -39,8 +39,8 @@ def test(): net.eval() torch.manual_seed(0) - x = torch.rand(24, 64) - y = torch.rand(12, 24, 64) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) a = net(x, y) @@ -50,7 +50,7 @@ def test(): # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_RMSNorm.pt inputshape=[24,64],[12,24,64]") + os.system("../../src/pnnx test_nn_RMSNorm.pt inputshape=[1,24,64],[1,12,24,64]") # ncnn inference import test_nn_RMSNorm_ncnn @@ -58,8 +58,6 @@ def test(): for a0, b0 in zip(a, b): if not torch.allclose(a0, b0, 1e-4, 1e-4): - print(a0) - print(b0) return False return True