mha allow qdim differs from embed_dim (#5519)
* test mha oom
nihui authored Jun 19, 2024
1 parent 2828e7a commit 8235cad
Showing 6 changed files with 150 additions and 53 deletions.
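In brief: this commit lets the query width (qdim) differ from embed_dim. Because the q projection weight holds embed_dim * qdim elements, every backend now recovers qdim from the serialized weight size instead of assuming qdim == embed_dim. A condensed sketch of the shape bookkeeping, using the variable names from the files below (not a complete implementation):

    // q projection weight:   embed_dim x qdim  ->  weight_data_size = embed_dim * qdim
    // out projection weight: qdim x embed_dim,  out projection bias: qdim
    const int qdim = weight_data_size / embed_dim;
    // q_blob   : w = qdim, h = src_seqlen
    // k_blob   : w = kdim, h = dst_seqlen
    // v_blob   : w = vdim, h = dst_seqlen
    // top_blob : w = qdim, h = src_seqlen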
6 changes: 4 additions & 2 deletions src/layer/arm/multiheadattention_arm.cpp
@@ -57,6 +57,8 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
qk_softmax->create_pipeline(opt);
}

const int qdim = weight_data_size / embed_dim;

{
const int embed_dim_per_head = embed_dim / num_heads;
const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
@@ -72,7 +74,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
pd.set(9, embed_dim); // K
pd.set(9, qdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
@@ -158,7 +160,7 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
pd.set(5, 1); // constantB
pd.set(6, 1); // constantC
pd.set(7, 0); // M = outch
pd.set(8, embed_dim); // N = size
pd.set(8, qdim); // N = size
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C = null
pd.set(11, 0); // output_N1M
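In the arm, x86 and vulkan backends the projections are built as constant Gemm layers, so only their shape parameters change. Roughly, in terms of the parameter ids commented above (7 = M, 8 = N, 9 = K) and with the other parameters as in the diff:

    // q projection: the constant A operand is M x K = embed_dim x qdim,
    // so it consumes a qdim-wide query row and produces embed_dim values
    pd.set(7, embed_dim); // M
    pd.set(9, qdim);      // K, previously embed_dim
    // output projection: the constant B operand is logically K x N = embed_dim x qdim,
    // mapping embed_dim attention features per token back to qdim outputs
    pd.set(8, qdim);      // N, previously embed_dim
    pd.set(9, embed_dim); // K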
29 changes: 21 additions & 8 deletions src/layer/multiheadattention.cpp
@@ -36,7 +36,9 @@ int MultiHeadAttention::load_param(const ParamDict& pd)

int MultiHeadAttention::load_model(const ModelBin& mb)
{
q_weight_data = mb.load(weight_data_size, 0);
const int qdim = weight_data_size / embed_dim;

q_weight_data = mb.load(embed_dim * qdim, 0);
if (q_weight_data.empty())
return -100;

@@ -60,11 +62,11 @@ int MultiHeadAttention::load_model(const ModelBin& mb)
if (v_bias_data.empty())
return -100;

out_weight_data = mb.load(weight_data_size, 0);
out_weight_data = mb.load(qdim * embed_dim, 0);
if (out_weight_data.empty())
return -100;

out_bias_data = mb.load(embed_dim, 1);
out_bias_data = mb.load(qdim, 1);
if (out_bias_data.empty())
return -100;

@@ -82,21 +84,32 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
const int src_seqlen = q_blob.h;
const int dst_seqlen = k_blob.h;
const int embed_dim_per_head = embed_dim / num_heads;
const int qdim = weight_data_size / embed_dim;

// assert k_blob.h == v_blob.h

Mat& top_blob = top_blobs[0];
top_blob.create(embed_dim, src_seqlen, 4u, opt.blob_allocator);
top_blob.create(qdim, src_seqlen, 4u, opt.blob_allocator);
if (top_blob.empty())
return -1;
return -100;

Mat xq(embed_dim_per_head, src_seqlen, num_heads, 4u, opt.workspace_allocator);
if (xq.empty())
return -100;
Mat xk(embed_dim_per_head, dst_seqlen, num_heads, 4u, opt.workspace_allocator);
if (xk.empty())
return -100;
Mat xv(dst_seqlen, embed_dim_per_head, num_heads, 4u, opt.workspace_allocator);
if (xv.empty())
return -100;

Mat xqk(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator);
if (xqk.empty())
return -100;

Mat xqkv(embed_dim_per_head, num_heads, src_seqlen, 4u, opt.workspace_allocator);
if (xqkv.empty())
return -100;

const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);

@@ -114,10 +127,10 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
for (int j = 0; j < embed_dim_per_head; j++)
{
const float* ptr = q_blob.row(i);
const float* kptr = (const float*)q_weight_data + embed_dim * (q * embed_dim_per_head + j);
const float* kptr = (const float*)q_weight_data + qdim * (q * embed_dim_per_head + j);

float sum = q_bias_data[q * embed_dim_per_head + j];
for (int k = 0; k < embed_dim; k++)
for (int k = 0; k < qdim; k++)
{
sum += *ptr++ * *kptr++;
}
@@ -286,7 +299,7 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
{
float* outptr = top_blob.row(i);

for (int j = 0; j < embed_dim; j++)
for (int j = 0; j < qdim; j++)
{
const float* ptr = xqkv.channel(i);
const float* kptr = (const float*)out_weight_data + embed_dim * j;
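Two things change in the reference implementation above: shapes tied to the query width now use qdim, and every intermediate Mat allocation is checked so an out-of-memory condition returns -100 (the "test mha oom" item in the commit message). A condensed view, reusing the names from that file:

    const int qdim = weight_data_size / embed_dim;
    // q projection: each of the embed_dim outputs per token is a dot product
    // over qdim query features (q_weight_data is walked in rows of length qdim)
    // output projection: each top_blob row holds qdim values, each a dot product
    // of the token's embed_dim attention features with a row of out_weight_data
    top_blob.create(qdim, src_seqlen, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;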
5 changes: 3 additions & 2 deletions src/layer/vulkan/multiheadattention_vulkan.cpp
@@ -46,6 +46,7 @@ MultiHeadAttention_vulkan::MultiHeadAttention_vulkan()
int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
{
const int embed_dim_per_head = embed_dim / num_heads;
const int qdim = weight_data_size / embed_dim;
{
const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);

@@ -61,7 +62,7 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
pd.set(9, embed_dim); // K
pd.set(9, qdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
// pd.set(12, 1); // output_elempack
@@ -220,7 +221,7 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
pd.set(5, 1); // constantB
pd.set(6, 1); // constantC
pd.set(7, 0); // M = outch
pd.set(8, embed_dim); // N = size
pd.set(8, qdim); // N = size
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
6 changes: 4 additions & 2 deletions src/layer/x86/multiheadattention_x86.cpp
@@ -38,6 +38,8 @@ MultiHeadAttention_x86::MultiHeadAttention_x86()

int MultiHeadAttention_x86::create_pipeline(const Option& opt)
{
const int qdim = weight_data_size / embed_dim;

{
const int embed_dim_per_head = embed_dim / num_heads;
const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
@@ -53,7 +55,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(6, 1); // constantC
pd.set(7, embed_dim); // M
pd.set(8, 0); // N
pd.set(9, embed_dim); // K
pd.set(9, qdim); // K
pd.set(10, 1); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
pd.set(12, 1); // output_elempack
@@ -191,7 +193,7 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
pd.set(5, 1); // constantB
pd.set(6, 1); // constantC
pd.set(7, 0); // M = outch
pd.set(8, embed_dim); // N = size
pd.set(8, qdim); // N = size
pd.set(9, embed_dim); // K = maxk*inch
pd.set(10, 4); // constant_broadcast_type_C
pd.set(11, 0); // output_N1M
83 changes: 44 additions & 39 deletions tests/test_multiheadattention.cpp
@@ -14,27 +14,29 @@

#include "testutil.h"

static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int num_heads, int kdim, int vdim, int attn_mask)
static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int embed_dim, int num_heads, int attn_mask)
{
int embed_dim = q.w;
const int qdim = q.w;
const int kdim = k.w;
const int vdim = v.w;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
pd.set(2, embed_dim * embed_dim);
pd.set(2, embed_dim * qdim);
pd.set(3, kdim);
pd.set(4, vdim);
pd.set(5, attn_mask);

std::vector<ncnn::Mat> weights(8);
weights[0] = RandomMat(embed_dim * embed_dim);
weights[0] = RandomMat(embed_dim * qdim);
weights[1] = RandomMat(embed_dim);
weights[2] = RandomMat(embed_dim * kdim);
weights[3] = RandomMat(embed_dim);
weights[4] = RandomMat(embed_dim * vdim);
weights[5] = RandomMat(embed_dim);
weights[6] = RandomMat(embed_dim * embed_dim);
weights[7] = RandomMat(embed_dim);
weights[6] = RandomMat(qdim * embed_dim);
weights[7] = RandomMat(qdim);

std::vector<ncnn::Mat> as(3);
as[0] = q;
@@ -51,32 +53,33 @@ static int test_multiheadattention(const ncnn::Mat& q, const ncnn::Mat& k, const
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, num_heads, kdim, vdim, attn_mask);
fprintf(stderr, "test_multiheadattention failed q=(%d %d) k=(%d %d) v=(%d %d) embed_dim=%d num_heads=%d kdim=%d vdim=%d attn_mask=%d\n", q.w, q.h, k.w, k.h, v.w, v.h, embed_dim, num_heads, kdim, vdim, attn_mask);
}

return ret;
}

static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int num_heads, int kvdim)
static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& kv, int embed_dim, int num_heads)
{
int embed_dim = q.w;
const int qdim = q.w;
const int kvdim = kv.w;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
pd.set(2, embed_dim * embed_dim);
pd.set(2, embed_dim * qdim);
pd.set(3, kvdim);
pd.set(4, kvdim);

std::vector<ncnn::Mat> weights(8);
weights[0] = RandomMat(embed_dim * embed_dim);
weights[0] = RandomMat(embed_dim * qdim);
weights[1] = RandomMat(embed_dim);
weights[2] = RandomMat(embed_dim * kvdim);
weights[3] = RandomMat(embed_dim);
weights[4] = RandomMat(embed_dim * kvdim);
weights[5] = RandomMat(embed_dim);
weights[6] = RandomMat(embed_dim * embed_dim);
weights[7] = RandomMat(embed_dim);
weights[6] = RandomMat(qdim * embed_dim);
weights[7] = RandomMat(qdim);

std::vector<ncnn::Mat> as(2);
as[0] = q;
@@ -87,30 +90,32 @@ static int test_multiheadattention_samekv(const ncnn::Mat& q, const ncnn::Mat& k
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, num_heads, kvdim);
fprintf(stderr, "test_multiheadattention_samekv failed q=(%d %d) kv=(%d %d) embed_dim=%d num_heads=%d kvdim=%d\n", q.w, q.h, kv.w, kv.h, embed_dim, num_heads, kvdim);
}

return ret;
}

static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, int num_heads)
{
int embed_dim = a.w;
const int qdim = a.w;

ncnn::ParamDict pd;
pd.set(0, embed_dim);
pd.set(1, num_heads);
pd.set(2, embed_dim * embed_dim);
pd.set(2, embed_dim * qdim);
pd.set(3, qdim);
pd.set(4, qdim);

std::vector<ncnn::Mat> weights(8);
weights[0] = RandomMat(embed_dim * embed_dim);
weights[0] = RandomMat(embed_dim * qdim);
weights[1] = RandomMat(embed_dim);
weights[2] = RandomMat(embed_dim * embed_dim);
weights[2] = RandomMat(embed_dim * qdim);
weights[3] = RandomMat(embed_dim);
weights[4] = RandomMat(embed_dim * embed_dim);
weights[4] = RandomMat(embed_dim * qdim);
weights[5] = RandomMat(embed_dim);
weights[6] = RandomMat(embed_dim * embed_dim);
weights[7] = RandomMat(embed_dim);
weights[6] = RandomMat(qdim * embed_dim);
weights[7] = RandomMat(qdim);

std::vector<ncnn::Mat> as(1);
as[0] = a;
@@ -120,7 +125,7 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
int ret = test_layer("MultiHeadAttention", pd, weights, as, 1, epsilon);
if (ret != 0)
{
fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) num_heads=%d\n", a.w, a.h, num_heads);
fprintf(stderr, "test_multiheadattention_sameqkv failed a=(%d %d) embed_dim=%d num_heads=%d\n", a.w, a.h, embed_dim, num_heads);
}

return ret;
@@ -129,32 +134,32 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int num_heads)
static int test_multiheadattention_0()
{
return 0
|| test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 2, 32, 20, 0)
|| test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 2, 32, 18, 1)
|| test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 4, 64, 64, 0)
|| test_multiheadattention(RandomMat(64, 127), RandomMat(64, 127), RandomMat(64, 127), 16, 64, 64, 1)
|| test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 2, 44, 55, 0)
|| test_multiheadattention(RandomMat(16, 128), RandomMat(44, 127), RandomMat(55, 127), 4, 44, 55, 1)
|| test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 3, 28, 32, 0)
|| test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 3, 28, 11, 1);
|| test_multiheadattention(RandomMat(62, 66), RandomMat(32, 66), RandomMat(20, 66), 62, 2, 0)
|| test_multiheadattention(RandomMat(26, 64), RandomMat(32, 64), RandomMat(18, 64), 26, 2, 1)
|| test_multiheadattention(RandomMat(64, 128), RandomMat(64, 128), RandomMat(64, 128), 64, 4, 0)
|| test_multiheadattention(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1)
|| test_multiheadattention(RandomMat(16, 128), RandomMat(44, 128), RandomMat(55, 128), 16, 2, 0)
|| test_multiheadattention(RandomMat(12, 128), RandomMat(44, 127), RandomMat(55, 127), 16, 4, 1)
|| test_multiheadattention(RandomMat(12, 17), RandomMat(28, 127), RandomMat(32, 127), 12, 3, 0)
|| test_multiheadattention(RandomMat(12, 17), RandomMat(28, 32), RandomMat(11, 32), 12, 3, 1);
}

static int test_multiheadattention_1()
{
return 0
|| test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 4, 64)
|| test_multiheadattention_samekv(RandomMat(64, 127), RandomMat(64, 127), 16, 64)
|| test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 2, 44)
|| test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(22, 127), 4, 22)
|| test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 3, 28)
|| test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 3, 11);
|| test_multiheadattention_samekv(RandomMat(64, 128), RandomMat(64, 128), 64, 4)
|| test_multiheadattention_samekv(RandomMat(48, 127), RandomMat(64, 127), 64, 16)
|| test_multiheadattention_samekv(RandomMat(16, 128), RandomMat(44, 128), 16, 2)
|| test_multiheadattention_samekv(RandomMat(12, 128), RandomMat(22, 127), 16, 4)
|| test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(28, 127), 12, 3)
|| test_multiheadattention_samekv(RandomMat(12, 17), RandomMat(11, 32), 12, 3);
}

static int test_multiheadattention_2()
{
return 0
|| test_multiheadattention_sameqkv(RandomMat(64, 128), 4)
|| test_multiheadattention_sameqkv(RandomMat(64, 127), 8);
|| test_multiheadattention_sameqkv(RandomMat(64, 128), 64, 4)
|| test_multiheadattention_sameqkv(RandomMat(48, 127), 64, 8);
}

int main()
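The updated test helpers take embed_dim and num_heads explicitly and read qdim, kdim and vdim from the input blob widths, so covering qdim != embed_dim only requires a narrower q input. One of the cases above, spelled out (the helper derives qdim = q.w = 48 while embed_dim is 64):

    // q is 48 wide, k/v are 64 wide, embed_dim = 64, num_heads = 16, attn_mask enabled
    test_multiheadattention(RandomMat(48, 127), RandomMat(64, 127), RandomMat(64, 127), 64, 16, 1);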