diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md
index 6056c277b1f..05996f8d735 100644
--- a/docs/developer-guide/operators.md
+++ b/docs/developer-guide/operators.md
@@ -1266,21 +1266,22 @@ y = affine(out)
 | --------- | ------------- | ----- | --------- | ----------------- |
 | 0         | embed_dim     | int   | 0         |                   |
 | 1         | num_heads     | int   | 1         |                   |
-| 2         | weight_data_size| int | 0         |                   |
+| 2         | weight_data_size| int | 0         | qdim = weight_data_size / embed_dim |
 | 3         | kdim          | int   | embed_dim |                   |
 | 4         | vdim          | int   | embed_dim |                   |
 | 5         | attn_mask     | int   | 0         |                   |
+| 6         | scale         | float | 1.f / sqrt(embed_dim / num_heads) | |
 
 | weight        | type  | shape                 |
 | ------------- | ----- | --------------------- |
-| q_weight_data | float/fp16/int8 | [weight_data_size] |
+| q_weight_data | float/fp16/int8 | [embed_dim * qdim] |
 | q_bias_data   | float | [embed_dim] |
 | k_weight_data | float/fp16/int8 | [embed_dim * kdim] |
 | k_bias_data   | float | [embed_dim] |
 | v_weight_data | float/fp16/int8 | [embed_dim * vdim] |
 | v_bias_data   | float | [embed_dim] |
-| out_weight_data| float/fp16/int8 | [weight_data_size] |
-| out_bias_data | float | [embed_dim] |
+| out_weight_data| float/fp16/int8 | [qdim * embed_dim] |
+| out_bias_data | float | [qdim] |
 
 # MVN
 ```
diff --git a/src/layer/arm/multiheadattention_arm.cpp b/src/layer/arm/multiheadattention_arm.cpp
index f5826ddae9d..9fedf8b16d7 100644
--- a/src/layer/arm/multiheadattention_arm.cpp
+++ b/src/layer/arm/multiheadattention_arm.cpp
@@ -60,12 +60,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
     const int qdim = weight_data_size / embed_dim;
 
     {
-        const int embed_dim_per_head = embed_dim / num_heads;
-        const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
         q_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
         ncnn::ParamDict pd;
-        pd.set(0, inv_sqrt_embed_dim_per_head);
+        pd.set(0, scale);
         pd.set(1, 1.f);
         pd.set(2, 0); // transA
         pd.set(3, 1); // transB
diff --git a/src/layer/multiheadattention.cpp b/src/layer/multiheadattention.cpp
index 284801a2c7d..a9592af4093 100644
--- a/src/layer/multiheadattention.cpp
+++ b/src/layer/multiheadattention.cpp
@@ -15,6 +15,7 @@
 #include "multiheadattention.h"
 
 #include <float.h>
+#include <math.h>
 
 namespace ncnn {
 
@@ -30,6 +31,7 @@ int MultiHeadAttention::load_param(const ParamDict& pd)
     kdim = pd.get(3, embed_dim);
     vdim = pd.get(4, embed_dim);
     attn_mask = pd.get(5, 0);
+    scale = pd.get(6, 1.f / sqrtf(embed_dim / num_heads));
 
     return 0;
 }
@@ -111,12 +113,10 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
     if (xqkv.empty())
         return -100;
 
-    const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
     #pragma omp parallel for num_threads(opt.num_threads)
     for (int q = 0; q < num_heads; q++)
     {
-        // xq = affine(q) * inv_sqrt_embed_dim_per_head
+        // xq = affine(q) * scale
         {
             Mat outm = xq.channel(q);
 
@@ -135,7 +135,7 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
                         sum += *ptr++ * *kptr++;
                     }
 
-                    outptr[j] = sum * inv_sqrt_embed_dim_per_head;
+                    outptr[j] = sum * scale;
                 }
             }
         }
diff --git a/src/layer/multiheadattention.h b/src/layer/multiheadattention.h
index 50c8549ac9a..55764bd9c64 100644
--- a/src/layer/multiheadattention.h
+++ b/src/layer/multiheadattention.h
@@ -37,6 +37,7 @@ class MultiHeadAttention : public Layer
     int kdim;
     int vdim;
     int attn_mask;
+    float scale;
 
     Mat q_weight_data;
     Mat q_bias_data;
diff --git a/src/layer/vulkan/multiheadattention_vulkan.cpp b/src/layer/vulkan/multiheadattention_vulkan.cpp
index f1d7ce3636b..1abc09c30e6 100644
--- a/src/layer/vulkan/multiheadattention_vulkan.cpp
+++ b/src/layer/vulkan/multiheadattention_vulkan.cpp
@@ -48,12 +48,10 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
     const int embed_dim_per_head = embed_dim / num_heads;
     const int qdim = weight_data_size / embed_dim;
     {
-        const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
         q_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
         q_gemm->vkdev = vkdev;
         ncnn::ParamDict pd;
-        pd.set(0, inv_sqrt_embed_dim_per_head);
+        pd.set(0, scale);
         pd.set(1, 1.f);
         pd.set(2, 0); // transA
         pd.set(3, 1); // transB
diff --git a/src/layer/x86/multiheadattention_x86.cpp b/src/layer/x86/multiheadattention_x86.cpp
index db5f730aec2..9bddb3a78ef 100644
--- a/src/layer/x86/multiheadattention_x86.cpp
+++ b/src/layer/x86/multiheadattention_x86.cpp
@@ -41,12 +41,9 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
     const int qdim = weight_data_size / embed_dim;
 
     {
-        const int embed_dim_per_head = embed_dim / num_heads;
-        const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
         q_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
         ncnn::ParamDict pd;
-        pd.set(0, inv_sqrt_embed_dim_per_head);
+        pd.set(0, scale);
         pd.set(1, 1.f);
         pd.set(2, 0); // transA
         pd.set(3, 1); // transB
diff --git a/tests/test_multiheadattention.cpp b/tests/test_multiheadattention.cpp
index c509f8156e8..9cf211b5ed5 100644
--- a/tests/test_multiheadattention.cpp
+++ b/tests/test_multiheadattention.cpp
@@ -106,6 +106,7 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, in
     pd.set(2, embed_dim * qdim);
     pd.set(3, qdim);
     pd.set(4, qdim);
+    pd.set(6, 0.6f);
 
     std::vector<ncnn::Mat> weights(8);
     weights[0] = RandomMat(embed_dim * qdim);
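
For anyone who wants to try the new parameter outside the test harness, here is a minimal standalone sketch. It is not part of this PR: it assumes a default ncnn build and the public `create_layer` / `ParamDict` / `ModelBinFromMatArray` APIs, and the dimensions, the `0.125f` scale and the fill values are arbitrary placeholders. It drives the plain MultiHeadAttention layer with param id 6 set explicitly; dropping the `pd.set(6, ...)` line falls back to the `1.f / sqrt(embed_dim / num_heads)` default.

```cpp
// Hypothetical driver (not part of this PR): exercises the new
// MultiHeadAttention scale param (id 6) through the public layer API.
#include "layer.h"
#include "mat.h"
#include "modelbin.h"
#include "option.h"
#include "paramdict.h"

#include <stdio.h>
#include <vector>

int main()
{
    const int embed_dim = 64;
    const int num_heads = 4;
    const int seqlen = 16;

    ncnn::Layer* mha = ncnn::create_layer("MultiHeadAttention");

    ncnn::ParamDict pd;
    pd.set(0, embed_dim);             // 0 = embed_dim
    pd.set(1, num_heads);             // 1 = num_heads
    pd.set(2, embed_dim * embed_dim); // 2 = weight_data_size, so qdim == embed_dim
    pd.set(6, 0.125f);                // 6 = scale; omit for 1.f / sqrt(embed_dim / num_heads)
    mha->load_param(pd);

    // q/k/v/out weight+bias pairs, sized per the operators.md table above
    // (kdim and vdim default to embed_dim here, so every weight is embed_dim^2)
    std::vector<ncnn::Mat> weights(8);
    for (int i = 0; i < 8; i++)
    {
        const bool is_bias = (i % 2 == 1);
        weights[i] = ncnn::Mat(is_bias ? embed_dim : embed_dim * embed_dim);
        weights[i].fill(0.01f);
    }
    mha->load_model(ncnn::ModelBinFromMatArray(weights.data()));

    ncnn::Option opt;
    mha->create_pipeline(opt);

    // a single bottom blob is reused as q, k and v
    ncnn::Mat a(embed_dim, seqlen);
    a.fill(1.f);
    std::vector<ncnn::Mat> bottoms(1, a);
    std::vector<ncnn::Mat> tops(1);
    mha->forward(bottoms, tops, opt);

    fprintf(stderr, "out %d x %d\n", tops[0].w, tops[0].h);

    mha->destroy_pipeline(opt);
    delete mha;
    return 0;
}
```

Note the design choice visible in the backend hunks: the CPU and Vulkan paths fold `scale` into the alpha (param 0) of the q-projection Gemm at create_pipeline time, so a custom scale adds no per-inference cost over the old hardcoded `1/sqrt(embed_dim_per_head)`.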