Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Aug 15, 2024
1 parent c887cd0 commit 22b3c25
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level1/nn_RMSNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class RMSNorm : public FuseModulePass

op->params["normalized_shape"] = rmsn->namedInput("normalized_shape");
op->params["eps"] = rmsn->namedInput("eps");
op->params["elementwise_affine"] = mod.hasattr("weight") && mod.hasattr("bias");
op->params["elementwise_affine"] = mod.hasattr("weight");

if (mod.hasattr("weight"))
{
Expand Down
4 changes: 3 additions & 1 deletion tools/pnnx/src/pass_ncnn/F_rms_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ pnnx.Output output 1 0 out
affine_size *= normalized_shape[i];
}

const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f;

op->params["0"] = affine_size;
op->params["1"] = captured_params.at("eps");
op->params["1"] = eps;
op->params["2"] = 0;
}
};
Expand Down
4 changes: 3 additions & 1 deletion tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ pnnx.Output output 1 0 out
affine_size *= normalized_shape[i];
}

const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f;

op->params["0"] = affine_size;
op->params["1"] = captured_params.at("eps");
op->params["1"] = eps;
op->params["2"] = captured_params.at("elementwise_affine").b ? 1 : 0;

if (captured_params.at("elementwise_affine").b)
Expand Down
6 changes: 3 additions & 3 deletions tools/pnnx/tests/ncnn/test_F_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def test():
net.eval()

torch.manual_seed(0)
x = torch.rand(12, 24)
y = torch.rand(3, 12, 16)
x = torch.rand(1, 12, 24)
y = torch.rand(1, 3, 12, 16)

a = net(x, y)

Expand All @@ -48,7 +48,7 @@ def test():

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[12,24],[3,12,16]")
os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[1,12,24],[1,3,12,16]")

# ncnn inference
import test_F_layer_norm_ncnn
Expand Down
8 changes: 5 additions & 3 deletions tools/pnnx/tests/ncnn/test_F_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def test():
net.eval()

torch.manual_seed(0)
x = torch.rand(12, 24)
y = torch.rand(3, 12, 16)
x = torch.rand(1, 12, 24)
y = torch.rand(1, 3, 12, 16)

a = net(x, y)

Expand All @@ -50,14 +50,16 @@ def test():

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_rms_norm.pt inputshape=[12,24],[3,12,16]")
os.system("../../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[1,3,12,16]")

# ncnn inference
import test_F_rms_norm_ncnn
b = test_F_rms_norm_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
print(a0)
print(b0)
return False
return True

Expand Down
6 changes: 3 additions & 3 deletions tools/pnnx/tests/ncnn/test_nn_LayerNorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def test():
net.eval()

torch.manual_seed(0)
x = torch.rand(24, 64)
y = torch.rand(12, 24, 64)
x = torch.rand(1, 24, 64)
y = torch.rand(1, 12, 24, 64)

a = net(x, y)

Expand All @@ -47,7 +47,7 @@ def test():

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[24,64],[12,24,64]")
os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[1,24,64],[1,12,24,64]")

# ncnn inference
import test_nn_LayerNorm_ncnn
Expand Down
8 changes: 3 additions & 5 deletions tools/pnnx/tests/ncnn/test_nn_RMSNorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def test():
net.eval()

torch.manual_seed(0)
x = torch.rand(24, 64)
y = torch.rand(12, 24, 64)
x = torch.rand(1, 24, 64)
y = torch.rand(1, 12, 24, 64)

a = net(x, y)

Expand All @@ -50,16 +50,14 @@ def test():

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_RMSNorm.pt inputshape=[24,64],[12,24,64]")
os.system("../../src/pnnx test_nn_RMSNorm.pt inputshape=[1,24,64],[1,12,24,64]")

# ncnn inference
import test_nn_RMSNorm_ncnn
b = test_nn_RMSNorm_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
print(a0)
print(b0)
return False
return True

Expand Down

0 comments on commit 22b3c25

Please sign in to comment.