pnnx fuse onnx sdpa pattern and ncnn qdim mha fusion (#5589)
nihui committed Jul 18, 2024
1 parent 997c892 commit f825d3a
Showing 3 changed files with 306 additions and 2 deletions.
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
@@ -472,6 +472,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/F_prelu.cpp
pass_ncnn/F_relu.cpp
pass_ncnn/F_relu6.cpp
+pass_ncnn/F_scaled_dot_product_attention.cpp
pass_ncnn/F_selu.cpp
pass_ncnn/F_sigmoid.cpp
pass_ncnn/F_silu.cpp
84 changes: 82 additions & 2 deletions tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp
@@ -62,7 +62,7 @@ pnnx.Output output 1 0 out
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
-F.scaled_dot_product_attention op_0 3 1 query key value out attn_mask=None dropout_p=0.0 is_causal=False
+F.scaled_dot_product_attention sdpa 3 1 query key value out attn_mask=None dropout_p=0.0 is_causal=False
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -114,7 +114,7 @@ pnnx.Input input_Rh 0 1 Rh
pnnx.Input input_Rw 0 1 Rw
pnnx.Expression RhRw 2 1 Rh Rw RhRw expr=add(@0,@1) #RhRw=(%batch,%h,%w,%h,%w)f32
Tensor.reshape attn_mask 1 1 RhRw attn_mask shape=(%batch,%qsize,%qsize) #attn_mask=(%batch,%qsize,%qsize)f32
-F.scaled_dot_product_attention op_0 4 1 query key value attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
+F.scaled_dot_product_attention sdpa 4 1 query key value attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -137,15 +137,95 @@ pnnx.Output output 1 0 out
}
};

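// onnx exporters lower scaled dot product attention into permute + matmul + add(mask) + softmax + matmul;
// match that subgraph and fold it back into a single F.scaled_dot_product_attention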
class fuse_scaled_dot_product_attention_pass_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
12 11
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_3 0 1 attn_mask
Tensor.permute op_0 1 1 query 13 dims=(0,2,1,3)
Tensor.permute op_1 1 1 key 20 dims=(0,2,3,1)
Tensor.permute op_2 1 1 value 19 dims=(0,2,1,3)
torch.matmul op_3 2 1 13 20 21
pnnx.Expression op_4 2 1 21 attn_mask 23 expr=add(@0,@1)
F.softmax softmax 1 1 23 24 dim=%softmax_dim
torch.matmul op_6 2 1 24 19 out
pnnx.Output output 1 0 out
)PNNXIR";
}

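// the matched pattern pre-transposes key with dims=(0,2,3,1) for the matmul; sdpa transposes key
// internally, so the replacement permutes key with dims=(0,2,1,3) just like query and value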
const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
9 8
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_3 0 1 attn_mask
Tensor.permute op_0 1 1 query q dims=(0,2,1,3)
Tensor.permute op_1 1 1 key k dims=(0,2,1,3)
Tensor.permute op_2 1 1 value v dims=(0,2,1,3)
F.scaled_dot_product_attention sdpa 4 1 q k v attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
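// sdpa softmaxes over the last dimension; reject any other softmax axis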
const int softmax_dim = captured_params.at("softmax_dim").i;

int softmax_input_rank = (int)matched_operators.at("softmax")->inputs[0]->shape.size();
if (softmax_dim != -1 && softmax_dim != softmax_input_rank - 1)
return false;

return true;
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
Operator* op = ops.at("sdpa");

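// the matched onnx subgraph applies no 1/sqrt(d) scaling, so force scale=1 to keep the math identical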
op->params["scale"] = 1.f;

// rewrite qkv shape
{
std::vector<int> q_shape = ops.at("op_0")->inputs[0]->shape;
std::vector<int> k_shape = ops.at("op_1")->inputs[0]->shape;
std::vector<int> v_shape = ops.at("op_2")->inputs[0]->shape;

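// the inserted permutes use dims=(0,2,1,3), which swaps axes 1 and 2; mirror that swap in the recorded output shapes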
if (!q_shape.empty())
std::swap(q_shape[1], q_shape[2]);
if (!k_shape.empty())
std::swap(k_shape[1], k_shape[2]);
if (!v_shape.empty())
std::swap(v_shape[1], v_shape[2]);

ops.at("op_0")->outputs[0]->shape = q_shape;
ops.at("op_0")->outputs[0]->type = ops.at("op_0")->inputs[0]->type;
ops.at("op_1")->outputs[0]->shape = k_shape;
ops.at("op_1")->outputs[0]->type = ops.at("op_1")->inputs[0]->type;
ops.at("op_2")->outputs[0]->shape = v_shape;
ops.at("op_2")->outputs[0]->type = ops.at("op_2")->inputs[0]->type;
}
}
};

void fuse_scaled_dot_product_attention(Graph& graph)
{
#if TORCH_VERSION_MAJOR >= 2
fuse_scaled_dot_product_attention_pass a;
fuse_scaled_dot_product_attention_pass_1 b;
fuse_scaled_dot_product_attention_pass_onnx onnx0;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &onnx0, opindex);
#endif
}

223 changes: 223 additions & 0 deletions tools/pnnx/src/pass_ncnn/F_scaled_dot_product_attention.cpp
@@ -0,0 +1,223 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

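// fold linear q/k/v projections + reshape/permute + sdpa + out_proj into a single ncnn MultiHeadAttention
// layer, covering the case where the model dimension qdim differs from embed_dim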
class F_scaled_dot_product_attention : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
16 15
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 attn_mask
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "MultiHeadAttention";
}

const char* name_str() const
{
return "sdpa_attention";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
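// ncnn MultiHeadAttention param ids: 0=embed_dim, 1=num_heads, 2=weight_data_size,
// 3=kdim, 4=vdim, 5=has attn_mask, 6=scale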
op->params["0"] = captured_params.at("embed_dim");
op->params["1"] = captured_params.at("num_heads");

const int embed_dim = captured_params.at("embed_dim").i;
const int qdim = captured_params.at("qdim").i;
const int kdim = captured_params.at("kdim").i;
const int vdim = captured_params.at("vdim").i;

op->params["2"] = embed_dim * qdim;
op->params["3"] = kdim;
op->params["4"] = vdim;
op->params["5"] = 1;
op->params["6"] = captured_params.at("scale");

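// each weight blob is preceded by a 4-byte quantize tag; all zeros marks raw fp32 data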
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = captured_attrs.at("op_0.weight");
if (captured_params.at("qbias").b)
{
op->attrs["2"] = captured_attrs.at("op_0.bias");
}
else
{
op->attrs["2"] = Attribute({embed_dim}, std::vector<float>(embed_dim, 0.f));
}
op->attrs["3"] = Attribute();
op->attrs["3"].data = {0, 0, 0, 0};
op->attrs["4"] = captured_attrs.at("op_1.weight");
if (captured_params.at("kbias").b)
{
op->attrs["5"] = captured_attrs.at("op_1.bias");
}
else
{
op->attrs["5"] = Attribute({embed_dim}, std::vector<float>(embed_dim, 0.f));
}
op->attrs["6"] = Attribute();
op->attrs["6"].data = {0, 0, 0, 0};
op->attrs["7"] = captured_attrs.at("op_2.weight");
if (captured_params.at("vbias").b)
{
op->attrs["8"] = captured_attrs.at("op_2.bias");
}
else
{
op->attrs["8"] = Attribute({embed_dim}, std::vector<float>(embed_dim, 0.f));
}
op->attrs["9"] = Attribute();
op->attrs["9"].data = {0, 0, 0, 0};
op->attrs["a"] = captured_attrs.at("out_proj.weight");
if (captured_params.at("outbias").b)
{
op->attrs["b"] = captured_attrs.at("out_proj.bias");
}
else
{
op->attrs["b"] = Attribute({qdim}, std::vector<float>(qdim, 0.f));
}
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10)

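// variant: cross attention, where query is projected from input while key/value come from a separate kv tensor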
class F_scaled_dot_product_attention_1 : public F_scaled_dot_product_attention
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
17 16
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 kv
pnnx.Input input_2 0 1 attn_mask
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%qsize,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)

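// variant: self attention without attn_mask; write() clears param 5 accordingly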
class F_scaled_dot_product_attention_2 : public F_scaled_dot_product_attention
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
15 14
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
F_scaled_dot_product_attention::write(op, captured_params, captured_attrs);
op->params["5"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_2, 10)

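// variant: cross attention without attn_mask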
class F_scaled_dot_product_attention_3 : public F_scaled_dot_product_attention
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
16 15
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 kv
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%qsize,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
F_scaled_dot_product_attention::write(op, captured_params, captured_attrs);
op->params["5"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_3, 10)

} // namespace ncnn

} // namespace pnnx
