diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp
index a9493f414cfa..f5826ddae9d7 100644
--- a/src/layer/arm/multiheadattention_arm.cpp
+++ b/src/layer/arm/multiheadattention_arm.cpp
@@ -57,6 +57,8 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
         qk_softmax->create_pipeline(opt);
     }
 
+    const int qdim = weight_data_size / embed_dim;
+
     {
         const int embed_dim_per_head = embed_dim / num_heads;
         const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
@@ -72,7 +74,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
         pd.set(6, 1); // constantC
         pd.set(7, embed_dim); // M
         pd.set(8, 0); // N
-        pd.set(9, embed_dim); // K
+        pd.set(9, qdim); // K
         pd.set(10, 1); // constant_broadcast_type_C
         pd.set(11, 0); // output_N1M
         pd.set(12, 1); // output_elempack
@@ -158,7 +160,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
         pd.set(5, 1); // constantB
         pd.set(6, 1); // constantC
         pd.set(7, 0); // M = outch
-        pd.set(8, embed_dim); // N = size
+        pd.set(8, qdim); // N = size
         pd.set(9, embed_dim); // K = maxk*inch
         pd.set(10, 4); // constant_broadcast_type_C = null
         pd.set(11, 0); // output_N1M
diff --git a/src/layer/multiheadattention.cpp b/src/layer/multiheadattention.cpp
index fa06a105ad2e..284801a2c7d3 100644
--- a/src/layer/multiheadattention.cpp
+++ b/src/layer/multiheadattention.cpp
@@ -36,7 +36,9 @@ int MultiHeadAttention::load_param(const ParamDict& pd)
 
 int MultiHeadAttention::load_model(const ModelBin& mb)
 {
-    q_weight_data = mb.load(weight_data_size, 0);
+    const int qdim = weight_data_size / embed_dim;
+
+    q_weight_data = mb.load(embed_dim * qdim, 0);
     if (q_weight_data.empty())
         return -100;
 
@@ -60,11 +62,11 @@ int MultiHeadAttention::load_model(const ModelBin& mb)
     if (v_bias_data.empty())
         return -100;
 
-    out_weight_data = mb.load(weight_data_size, 0);
+    out_weight_data = mb.load(qdim * embed_dim, 0);
     if (out_weight_data.empty())
         return -100;
 
-    out_bias_data = mb.load(embed_dim, 1);
+    out_bias_data = mb.load(qdim, 1);
     if (out_bias_data.empty())
         return -100;
 
@@ -82,21 +84,32 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
     const int src_seqlen = q_blob.h;
     const int dst_seqlen = k_blob.h;
     const int embed_dim_per_head = embed_dim / num_heads;
+    const int qdim = weight_data_size / embed_dim;
 
     // assert k_blob.h == v_blob.h
 
     Mat& top_blob = top_blobs[0];
-    top_blob.create(embed_dim, src_seqlen, 4u, opt.blob_allocator);
+    top_blob.create(qdim, src_seqlen, 4u, opt.blob_allocator);
     if (top_blob.empty())
-        return -1;
+        return -100;
 
     Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator);
+    if (xq.empty())
+        return -100;
     Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator);
+    if (xk.empty())
+        return -100;
     Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator);
+    if (xv.empty())
+        return -100;
     Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
+    if (xqk.empty())
+        return -100;
     Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator);
+    if (xqkv.empty())
+        return -100;
 
     const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
 
@@ -114,10 +127,10 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
             for (int j = 0; j < embed_dim_per_head; j++)
             {
                 const float* ptr = q_blob.row(i);
-                const float* kptr = (const float*)q_weight_data + embed_dim * (q * embed_dim_per_head + j);
+                const float* kptr = (const float*)q_weight_data + qdim * (q * embed_dim_per_head + j);
 
                 float sum = q_bias_data[q * embed_dim_per_head + j];
-                for (int k = 0; k < embed_dim; k++)
+                for (int k = 0; k < qdim; k++)
                 {
                     sum += *ptr++ * *kptr++;
                 }
@@ -286,7 +299,7 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
         {
             float* outptr = top_blob.row(i);
 
-            for (int j = 0; j < embed_dim; j++)
+            for (int j = 0; j < qdim; j++)
             {
                 const float* ptr = xqkv.channel(i);
                 const float* kptr = (const float*)out_weight_data + embed_dim * j;
diff --git a/src/layer/vulkan/multiheadattention_vulkan.cpp b/src/layer/vulkan/multiheadattention_vulkan.cpp
index 48967de36977..f1d7ce3636b2 100644
--- a/src/layer/vulkan/multiheadattention_vulkan.cpp
+++ b/src/layer/vulkan/multiheadattention_vulkan.cpp
@@ -46,6 +46,7 @@ MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
 int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
 {
     const int embed_dim_per_head = embed_dim / num_heads;
+    const int qdim = weight_data_size / embed_dim;
 
     {
         const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
@@ -61,7 +62,7 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
         pd.set(6, 1); // constantC
         pd.set(7, embed_dim); // M
         pd.set(8, 0); // N
-        pd.set(9, embed_dim); // K
+        pd.set(9, qdim); // K
         pd.set(10, 1); // constant_broadcast_type_C
         pd.set(11, 0); // output_N1M
         // pd.set(12, 1); // output_elempack
@@ -220,7 +221,7 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
         pd.set(5, 1); // constantB
         pd.set(6, 1); // constantC
         pd.set(7, 0); // M = outch
-        pd.set(8, embed_dim); // N = size
+        pd.set(8, qdim); // N = size
         pd.set(9, embed_dim); // K = maxk*inch
         pd.set(10, 4); // constant_broadcast_type_C
         pd.set(11, 0); // output_N1M
diff --git a/src/layer/x86/multiheadattention_x86.cpp b/src/layer/x86/multiheadattention_x86.cpp
index 7d6dbd80c95b..db5f730aec2a 100644
--- a/src/layer/x86/multiheadattention_x86.cpp
+++ b/src/layer/x86/multiheadattention_x86.cpp
@@ -38,6 +38,8 @@ MultiHeadAttention_x86::MultiHeadAttention_x86()
 int MultiHeadAttention_x86::create_pipeline(const Option& opt)
 {
+    const int qdim = weight_data_size / embed_dim;
+
     {
         const int embed_dim_per_head = embed_dim / num_heads;
         const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
@@ -53,7 +55,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         pd.set(6, 1); // constantC
         pd.set(7, embed_dim); // M
         pd.set(8, 0); // N
-        pd.set(9, embed_dim); // K
+        pd.set(9, qdim); // K
         pd.set(10, 1); // constant_broadcast_type_C
         pd.set(11, 0); // output_N1M
         pd.set(12, 1); // output_elempack
@@ -191,7 +193,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
         pd.set(5, 1); // constantB
         pd.set(6, 1); // constantC
         pd.set(7, 0); // M = outch
-        pd.set(8, embed_dim); // N = size
+        pd.set(8, qdim); // N = size
         pd.set(9, embed_dim); // K = maxk*inch
         pd.set(10, 4); // constant_broadcast_type_C
         pd.set(11, 0); // output_N1M
diff --git a/tests/test_multiheadattention.cpp b/tests/test_multiheadattention.cpp
index ad29c6b98b0a..c509f8156e8f 100644
--- a/tests/test_multiheadattention.cpp
+++ b/tests/test_multiheadattention.cpp
@@ -14,27 +14,29 @@
 
 #include "testutil.h"
 
-static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int num_heads, int kdim, int vdim, int attn_mask)
+static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
 {
-    int embed_dim = q.w;
+    const int qdim = q.w;
+    const int kdim = k.w;
+    const int vdim = v.w;
 
     ncnn::ParamDict pd;
     pd.set(0, embed_dim);
     pd.set(1, num_heads);
-    pd.set(2, embed_dim * embed_dim);
+    pd.set(2, embed_dim * qdim);
     pd.set(3, kdim);
     pd.set(4, vdim);
     pd.set(5, attn_mask);
 
     std::vector<ncnn::Mat> weights(8);
-    weights[0] = RandomMat(embed_dim * embed_dim);
+    weights[0] = RandomMat(embed_dim * qdim);
     weights[1] = RandomMat(embed_dim);
     weights[2] = RandomMat(embed_dim * kdim);
     weights[3] = RandomMat(embed_dim);
     weights[4] = RandomMat(embed_dim * vdim);
     weights[5] = RandomMat(embed_dim);
-    weights[6] = RandomMat(embed_dim * embed_dim);
-    weights[7] = RandomMat(embed_dim);
+    weights[6] = RandomMat(qdim * embed_dim);
+    weights[7] = RandomMat(qdim);
 
     std::vector<ncnn::Mat> as(3);
     as[0] = q;
@@ -51,32 +53,33 @@ static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int num_heads, int kdim, int vdim, int attn_mask)
     int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
     if (ret != 0)
     {
-        fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, num_heads, kdim, vdim, attn_mask);
+        fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
     }
 
     return ret;
 }
 
-static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int num_heads, int kvdim)
+static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads)
 {
-    int embed_dim = q.w;
+    const int qdim = q.w;
+    const int kvdim = kv.w;
 
     ncnn::ParamDict pd;
     pd.set(0, embed_dim);
     pd.set(1, num_heads);
-    pd.set(2, embed_dim * embed_dim);
+    pd.set(2, embed_dim * qdim);
     pd.set(3, kvdim);
     pd.set(4, kvdim);
 
     std::vector<ncnn::Mat> weights(8);
-    weights[0] = RandomMat(embed_dim * embed_dim);
+    weights[0] = RandomMat(embed_dim * qdim);
     weights[1] = RandomMat(embed_dim);
     weights[2] = RandomMat(embed_dim * kvdim);
     weights[3] = RandomMat(embed_dim);
     weights[4] = RandomMat(embed_dim * kvdim);
     weights[5] = RandomMat(embed_dim);
-    weights[6] = RandomMat(embed_dim * embed_dim);
-    weights[7] = RandomMat(embed_dim);
+    weights[6] = RandomMat(qdim * embed_dim);
+    weights[7] = RandomMat(qdim);
 
     std::vector<ncnn::Mat> as(2);
     as[0] = q;
@@ -87,30 +90,32 @@ static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int num_heads, int kvdim)
     int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
    if (ret != 0)
     {
-        fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, num_heads, kvdim);
+        fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim);
     }
 
     return ret;
 }
 
-static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
+static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads)
 {
-    int embed_dim = a.w;
+    const int qdim = a.w;
 
     ncnn::ParamDict pd;
     pd.set(0, embed_dim);
     pd.set(1, num_heads);
-    pd.set(2, embed_dim * embed_dim);
+    pd.set(2, embed_dim * qdim);
+    pd.set(3, qdim);
+    pd.set(4, qdim);
 
     std::vector<ncnn::Mat> weights(8);
-    weights[0] = RandomMat(embed_dim * embed_dim);
+    weights[0] = RandomMat(embed_dim * qdim);
     weights[1] = RandomMat(embed_dim);
-    weights[2] = RandomMat(embed_dim * embed_dim);
+    weights[2] = RandomMat(embed_dim * qdim);
     weights[3] = RandomMat(embed_dim);
-    weights[4] = RandomMat(embed_dim * embed_dim);
+    weights[4] = RandomMat(embed_dim * qdim);
     weights[5] = RandomMat(embed_dim);
-    weights[6] = RandomMat(embed_dim * embed_dim);
-    weights[7] = RandomMat(embed_dim);
+    weights[6] = RandomMat(qdim * embed_dim);
+    weights[7] = RandomMat(qdim);
 
     std::vector<ncnn::Mat> as(1);
     as[0] = a;
@@ -120,7 +125,7 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
     int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
     if (ret != 0)
     {
-        fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) num_heads=%d\n", a.w, a.h, num_heads);
+        fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads);
     }
 
     return ret;
@@ -129,32 +134,32 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
 static int test_multiheadattention_0()
 {
     return 0
-           || test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 2, 32, 20, 0)
-           || test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 2, 32, 18, 1)
-           || test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 4, 64, 64, 0)
-           || test_multiheadattention(RandomMat(64, 127), RandomMat(64, 127), RandomMat(64, 127), 16, 64, 64, 1)
-           || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 2, 44, 55, 0)
-           || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 127), RandomMat(55, 127), 4, 44, 55, 1)
-           || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 3, 28, 32, 0)
-           || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 3, 28, 11, 1);
+           || test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
+           || test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
+           || test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0)
+           || test_multiheadattention(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1)
+           || test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0)
+           || test_multiheadattention(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1)
+           || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
+           || test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
 }
 
 static int test_multiheadattention_1()
 {
     return 0
-           || test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 4, 64)
-           || test_multiheadattention_samekv(RandomMat(64, 127), RandomMat(64, 127), 16, 64)
-           || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 2, 44)
-           || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(22, 127), 4, 22)
-           || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 3, 28)
-           || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 3, 11);
+           || test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4)
+           || test_multiheadattention_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16)
+           || test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2)
+           || test_multiheadattention_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4)
+           || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3)
+           || test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3);
 }
 
 static int test_multiheadattention_2()
 {
     return 0
-           || test_multiheadattention_sameqkv(RandomMat(64, 128), 4)
-           || test_multiheadattention_sameqkv(RandomMat(64, 127), 8);
+           || test_multiheadattention_sameqkv(RandomMat(64, 128), 64, 4)
+           || test_multiheadattention_sameqkv(RandomMat(48, 127), 64, 8);
 }
 
 int main()
diff --git a/tests/test_multiheadattention_oom.cpp b/tests/test_multiheadattention_oom.cpp
new file mode 100644
index 000000000000..6b1d6ccbdc0f
--- /dev/null
+++ b/tests/test_multiheadattention_oom.cpp
@@ -0,0 +1,74 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "testutil.h"
+
+static int test_multiheadattention_oom(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
+{
+    const int qdim = q.w;
+    const int kdim = k.w;
+    const int vdim = v.w;
+
+    ncnn::ParamDict pd;
+    pd.set(0, embed_dim);
+    pd.set(1, num_heads);
+    pd.set(2, embed_dim * qdim);
+    pd.set(3, kdim);
+    pd.set(4, vdim);
+    pd.set(5, attn_mask);
+
+    std::vector<ncnn::Mat> weights(8);
+    weights[0] = RandomMat(embed_dim * qdim);
+    weights[1] = RandomMat(embed_dim);
+    weights[2] = RandomMat(embed_dim * kdim);
+    weights[3] = RandomMat(embed_dim);
+    weights[4] = RandomMat(embed_dim * vdim);
+    weights[5] = RandomMat(embed_dim);
+    weights[6] = RandomMat(qdim * embed_dim);
+    weights[7] = RandomMat(qdim);
+
+    std::vector<ncnn::Mat> as(3);
+    as[0] = q;
+    as[1] = k;
+    as[2] = v;
+
+    if (attn_mask)
+    {
+        as.push_back(RandomMat(k.h, q.h));
+    }
+
+    int ret = test_layer_oom("MultiHeadAttention", pd, weights, as, 1);
+    if (ret != 0)
+    {
+        fprintf(stderr, "test_multiheadattention_oom failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
+    }
+
+    return ret;
+}
+
+static int test_multiheadattention_0()
+{
+    return 0
+           || test_multiheadattention_oom(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
+           || test_multiheadattention_oom(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
+           || test_multiheadattention_oom(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
+           || test_multiheadattention_oom(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    return test_multiheadattention_0();
+}