pnnx fuse onnx sdpa pattern and ncnn qdim mha fusion #5589

Merged · 1 commit · Jul 18, 2024
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
@@ -472,6 +472,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/F_prelu.cpp
pass_ncnn/F_relu.cpp
pass_ncnn/F_relu6.cpp
pass_ncnn/F_scaled_dot_product_attention.cpp
pass_ncnn/F_selu.cpp
pass_ncnn/F_sigmoid.cpp
pass_ncnn/F_silu.cpp
84 changes: 82 additions & 2 deletions tools/pnnx/src/pass_level5/fuse_scaled_dot_product_attention.cpp
@@ -62,7 +62,7 @@ pnnx.Output output 1 0 out
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
F.scaled_dot_product_attention op_0 3 1 query key value out attn_mask=None dropout_p=0.0 is_causal=False
F.scaled_dot_product_attention sdpa 3 1 query key value out attn_mask=None dropout_p=0.0 is_causal=False
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -114,7 +114,7 @@ pnnx.Input input_Rh 0 1 Rh
pnnx.Input input_Rw 0 1 Rw
pnnx.Expression RhRw 2 1 Rh Rw RhRw expr=add(@0,@1) #RhRw=(%batch,%h,%w,%h,%w)f32
Tensor.reshape attn_mask 1 1 RhRw attn_mask shape=(%batch,%qsize,%qsize) #attn_mask=(%batch,%qsize,%qsize)f32
F.scaled_dot_product_attention op_0 4 1 query key value attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
F.scaled_dot_product_attention sdpa 4 1 query key value attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -137,15 +137,95 @@ pnnx.Output output 1 0 out
}
};

// re-fuse the decomposed attention subgraph (transpose q/k/v + matmul + add mask
// + softmax + matmul) commonly seen in ONNX-exported models back into
// F.scaled_dot_product_attention
class fuse_scaled_dot_product_attention_pass_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
12 11
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_3 0 1 attn_mask
Tensor.permute op_0 1 1 query 13 dims=(0,2,1,3)
Tensor.permute op_1 1 1 key 20 dims=(0,2,3,1)
Tensor.permute op_2 1 1 value 19 dims=(0,2,1,3)
torch.matmul op_3 2 1 13 20 21
pnnx.Expression op_4 2 1 21 attn_mask 23 expr=add(@0,@1)
F.softmax softmax 1 1 23 24 dim=%softmax_dim
torch.matmul op_6 2 1 24 19 out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
9 8
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_3 0 1 attn_mask
Tensor.permute op_0 1 1 query q dims=(0,2,1,3)
Tensor.permute op_1 1 1 key k dims=(0,2,1,3)
Tensor.permute op_2 1 1 value v dims=(0,2,1,3)
F.scaled_dot_product_attention sdpa 4 1 q k v attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
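// F.scaled_dot_product_attention always normalizes over the last (key) dimension,
// so only fuse when the matched softmax operates on its input's last dim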
const int softmax_dim = captured_params.at("softmax_dim").i;

int softmax_input_rank = (int)matched_operators.at("softmax")->inputs[0]->shape.size();
if (softmax_dim != -1 && softmax_dim != softmax_input_rank - 1)
return false;

return true;
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& /*captured_attrs*/) const
{
Operator* op = ops.at("sdpa");

op->params["scale"] = 1.f;

// rewrite qkv shapes: the replacement permutes q/k/v with (0,2,1,3),
// so propagate the swapped shape and element type to each new permute output
{
std::vector<int> q_shape = ops.at("op_0")->inputs[0]->shape;
std::vector<int> k_shape = ops.at("op_1")->inputs[0]->shape;
std::vector<int> v_shape = ops.at("op_2")->inputs[0]->shape;

if (!q_shape.empty())
std::swap(q_shape[1], q_shape[2]);
if (!k_shape.empty())
std::swap(k_shape[1], k_shape[2]);
if (!v_shape.empty())
std::swap(v_shape[1], v_shape[2]);

ops.at("op_0")->outputs[0]->shape = q_shape;
ops.at("op_0")->outputs[0]->type = ops.at("op_0")->inputs[0]->type;
ops.at("op_1")->outputs[0]->shape = k_shape;
ops.at("op_1")->outputs[0]->type = ops.at("op_1")->inputs[0]->type;
ops.at("op_2")->outputs[0]->shape = v_shape;
ops.at("op_2")->outputs[0]->type = ops.at("op_2")->inputs[0]->type;
}
}
};

void fuse_scaled_dot_product_attention(Graph& graph)
{
#if TORCH_VERSION_MAJOR >= 2
fuse_scaled_dot_product_attention_pass a;
fuse_scaled_dot_product_attention_pass_1 b;
fuse_scaled_dot_product_attention_pass_onnx onnx0;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
pnnx_graph_rewrite(graph, &onnx0, opindex);
#endif
}

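For context on the API these fusions rely on (not part of the PR's changes): a GraphRewriterPass declares a PNNX text pattern to match and a pattern to substitute, and is applied with pnnx_graph_rewrite() as above. Below is a minimal, hypothetical sketch that fuses sigmoid(x) * x into F.silu, assuming the same operator text format shown in this file; names such as fuse_silu_example_pass are illustrative only.

// hypothetical example, mirroring the match/replace structure used above
class fuse_silu_example_pass : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        // match x -> sigmoid -> multiply-with-x
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 x
F.sigmoid op_0 1 1 x s
pnnx.Expression op_1 2 1 x s out expr=mul(@0,@1)
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* replace_pattern_graph() const
    {
        // collapse the whole subgraph into a single F.silu
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 x
F.silu silu 1 1 x out
pnnx.Output output 1 0 out
)PNNXIR";
    }
};

// applied the same way as the passes above:
//   fuse_silu_example_pass p;
//   int opindex = 0;
//   pnnx_graph_rewrite(graph, &p, opindex);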
223 changes: 223 additions & 0 deletions tools/pnnx/src/pass_ncnn/F_scaled_dot_product_attention.cpp
@@ -0,0 +1,223 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class F_scaled_dot_product_attention : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
16 15
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 attn_mask
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "MultiHeadAttention";
}

const char* name_str() const
{
return "sdpa_attention";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
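// map the matched subgraph onto ncnn MultiHeadAttention:
//   params: 0=embed_dim, 1=num_heads, 2=q weight_data_size (embed_dim * qdim),
//           3=kdim, 4=vdim, 5=attn_mask flag, 6=scale
//   attrs:  1/2 = q weight/bias, 4/5 = k weight/bias, 7/8 = v weight/bias,
//           a/b = out_proj weight/bias; attrs 0/3/6/9 hold the 4-byte
//           data-type tag written ahead of each weight blob (zero tag = raw fp32)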
op->params["0"] = captured_params.at("embed_dim");
op->params["1"] = captured_params.at("num_heads");

const int embed_dim = captured_params.at("embed_dim").i;
const int qdim = captured_params.at("qdim").i;
const int kdim = captured_params.at("kdim").i;
const int vdim = captured_params.at("vdim").i;

op->params["2"] = embed_dim * qdim;
op->params["3"] = kdim;
op->params["4"] = vdim;
op->params["5"] = 1;
op->params["6"] = captured_params.at("scale");

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = captured_attrs.at("op_0.weight");
if (captured_params.at("qbias").b)
{
op->attrs["2"] = captured_attrs.at("op_0.bias");
}
else
{
op->attrs["2"] = Attribute({embed_dim}, std::vector<float>(embed_dim, 0.f));
}
op->attrs["3"] = Attribute();
op->attrs["3"].data = {0, 0, 0, 0};
op->attrs["4"] = captured_attrs.at("op_1.weight");
if (captured_params.at("kbias").b)
{
op->attrs["5"] = captured_attrs.at("op_1.bias");
}
else
{
op->attrs["5"] = Attribute({embed_dim}, std::vector<float>(embed_dim, 0.f));
}
op->attrs["6"] = Attribute();
op->attrs["6"].data = {0, 0, 0, 0};
op->attrs["7"] = captured_attrs.at("op_2.weight");
if (captured_params.at("vbias").b)
{
op->attrs["8"] = captured_attrs.at("op_2.bias");
}
else
{
op->attrs["8"] = Attribute({embed_dim}, std::vector<float>(embed_dim, 0.f));
}
op->attrs["9"] = Attribute();
op->attrs["9"].data = {0, 0, 0, 0};
op->attrs["a"] = captured_attrs.at("out_proj.weight");
if (captured_params.at("outbias").b)
{
op->attrs["b"] = captured_attrs.at("out_proj.bias");
}
else
{
op->attrs["b"] = Attribute({qdim}, std::vector<float>(qdim, 0.f));
}
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10)

// same fusion, but k/v are projected from a separate kv input (cross attention)
class F_scaled_dot_product_attention_1 : public F_scaled_dot_product_attention
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
17 16
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 kv
pnnx.Input input_2 0 1 attn_mask
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%qsize,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)

// self-attention variant without an attn_mask input
class F_scaled_dot_product_attention_2 : public F_scaled_dot_product_attention
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
15 14
pnnx.Input input 0 1 input
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
F_scaled_dot_product_attention::write(op, captured_params, captured_attrs);
op->params["5"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_2, 10)

// cross-attention variant without an attn_mask input
class F_scaled_dot_product_attention_3 : public F_scaled_dot_product_attention
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
16 15
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 kv
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
nn.Linear op_1 1 1 kv k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
nn.Linear op_2 1 1 kv v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
Tensor.reshape op_3 1 1 q 10 shape=(%batch,%qsize,%num_heads,%feat_per_head)
Tensor.reshape op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
F_scaled_dot_product_attention::write(op, captured_params, captured_attrs);
op->params["5"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_3, 10)

} // namespace ncnn

} // namespace pnnx
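As a rough, hypothetical illustration of the result (all values invented for this note, not taken from the PR): with qdim = kdim = vdim = 256, embed_dim = 512, num_heads = 8 and an attn_mask input, the fused operator would land in the generated ncnn.param roughly as

MultiHeadAttention sdpa_attention 2 1 in0 in1 out0 0=512 1=8 2=131072 3=256 4=256 5=1 6=1.250000e-01

where 2=131072 is embed_dim * qdim (the q projection weight size), 5=1 marks the extra attn_mask input, and 6 carries the captured SDPA scale (1/sqrt(feat_per_head) = 1/sqrt(64) here).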