Skip to content

Commit

Permalink
pnnx fuse t5-layernorm as rmsnorm
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 3, 2024
1 parent 204583b commit 94eb6b9
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 0 deletions.
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_pixel_unshuffle.cpp
pass_level5/fuse_layernorm.cpp
pass_level5/fuse_multiheadattention.cpp
pass_level5/fuse_rmsnorm.cpp
pass_level5/fuse_scaled_dot_product_attention.cpp
pass_level5/fuse_select_to_unbind.cpp
pass_level5/fuse_silu.cpp
Expand Down
2 changes: 2 additions & 0 deletions tools/pnnx/src/pass_level5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "pass_level5/fuse_multiheadattention.h"
#include "pass_level5/fuse_pad_conv1d.h"
#include "pass_level5/fuse_pad_conv2d.h"
#include "pass_level5/fuse_rmsnorm.h"
#include "pass_level5/fuse_scaled_dot_product_attention.h"
#include "pass_level5/fuse_select_to_unbind.h"
#include "pass_level5/fuse_silu.h"
Expand Down Expand Up @@ -145,6 +146,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_channel_shuffle(g);
fuse_layernorm(g);
fuse_rmsnorm(g);
fuse_multiheadattention(g);
fuse_scaled_dot_product_attention(g);

Expand Down
79 changes: 79 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_rmsnorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_rmsnorm.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_rmsnorm_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32
pnnx.Expression op_1 1 1 input sq expr=pow(@0,2)
torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True
pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,rsqrt(add(@2,%eps))))
pnnx.Output output 1 0 out
)PNNXIR";
}


const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.RMSNorm rmsnorm 1 1 input out elementwise_affine=True eps=%eps normalized_shape=(%c) @weight=%op_0.data
pnnx.Output output 1 0 out
)PNNXIR";
}
};

class fuse_rmsnorm_pass_onnx : public fuse_rmsnorm_pass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
pnnx.Attribute op_0 0 1 weight @data #weight=(%c)f32
pnnx.Expression op_1 1 1 input sq expr=pow(@0,2.000000e+00)
torch.mean op_2 1 1 sq sqmean dim=(-1) keepdim=True
pnnx.Expression op_3 3 1 weight input sqmean out expr=mul(@0,mul(@1,div(1.000000e+00,sqrt(add(@2,%eps)))))
pnnx.Output output 1 0 out
)PNNXIR";
}
};

void fuse_rmsnorm(Graph& graph)
{
fuse_rmsnorm_pass a;
fuse_rmsnorm_pass_onnx b;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
}

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_rmsnorm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_rmsnorm(Graph& graph);

} // namespace pnnx
1 change: 1 addition & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ pnnx_add_test(pnnx_fuse_input_unpack)
pnnx_add_test(pnnx_fuse_layernorm)
pnnx_add_test(pnnx_fuse_linear_batchnorm1d)
pnnx_add_test(pnnx_fuse_multiheadattention)
pnnx_add_test(pnnx_fuse_rmsnorm)
pnnx_add_test(pnnx_fuse_scaled_dot_product_attention)
pnnx_add_test(pnnx_fuse_select_to_unbind)
pnnx_add_test(pnnx_fuse_slice_to_tensor_split)
Expand Down
73 changes: 73 additions & 0 deletions tools/pnnx/tests/test_pnnx_fuse_rmsnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.rand(hidden_size))
self.variance_epsilon = eps

def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * x

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.rmsn_0 = T5LayerNorm(26)
self.rmsn_1 = T5LayerNorm(21)

def forward(self, x, y):
x = self.rmsn_0(x)
y = self.rmsn_1(y)
return x, y

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 64, 26)
y = torch.rand(3, 15, 15, 21)

a0, a1 = net(x, y)

# export onnx
torch.onnx.export(net, (x,y), "test.onnx")

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_pnnx_fuse_rmsnorm.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_pnnx_fuse_rmsnorm.pt inputshape=[1,64,26],[3,15,15,21]")

# pnnx inference
import test_pnnx_fuse_rmsnorm_pnnx
b0, b1 = test_pnnx_fuse_rmsnorm_pnnx.test_inference()

return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

0 comments on commit 94eb6b9

Please sign in to comment.