multiheadattention scale param
nihui committed Jun 21, 2024
1 parent f2a34ee commit ed0e494
Showing 7 changed files with 14 additions and 19 deletions.
9 changes: 5 additions & 4 deletions docs/developer-guide/operators.md
@@ -1266,21 +1266,22 @@ y = affine(out)
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | embed_dim | int | 0 | |
| 1 | num_heads | int | 1 | |
-| 2 | weight_data_size| int | 0 | |
+| 2 | weight_data_size| int | 0 | qdim = weight_data_size / embed_dim |
| 3 | kdim | int | embed_dim | |
| 4 | vdim | int | embed_dim | |
| 5 | attn_mask | int | 0 | |
+| 6 | scale | float | 1.f / sqrt(embed_dim / num_heads) | |

| weight | type | shape |
| ------------- | ----- | --------------------- |
-| q_weight_data | float/fp16/int8 | [weight_data_size] |
+| q_weight_data | float/fp16/int8 | [embed_dim * qdim] |
| q_bias_data | float | [embed_dim] |
| k_weight_data | float/fp16/int8 | [embed_dim * kdim] |
| k_bias_data | float | [embed_dim] |
| v_weight_data | float/fp16/int8 | [embed_dim * vdim] |
| v_bias_data | float | [embed_dim] |
-| out_weight_data| float/fp16/int8 | [weight_data_size] |
-| out_bias_data | float | [embed_dim] |
+| out_weight_data| float/fp16/int8 | [qdim * embed_dim] |
+| out_bias_data | float | [qdim] |

# MVN
```
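For context (not part of this commit): a minimal sketch of how a caller might populate these parameters through ncnn's ParamDict. Shapes and the 0.125f value are hypothetical; param ids 0, 1, 2 and the new 6 follow the table above, and omitting id 6 keeps the 1.f / sqrt(embed_dim / num_heads) default.

```
// Hedged sketch, not from this commit: configuring MultiHeadAttention with
// the new scale parameter. Shapes and the 0.125f value are hypothetical.
#include "layer.h"     // ncnn::Layer, ncnn::create_layer
#include "paramdict.h" // ncnn::ParamDict

static ncnn::Layer* make_mha_example()
{
    const int embed_dim = 64;
    const int num_heads = 4;
    const int qdim = 64; // query feature dim, so weight_data_size = embed_dim * qdim

    ncnn::ParamDict pd;
    pd.set(0, embed_dim);        // 0: embed_dim
    pd.set(1, num_heads);        // 1: num_heads
    pd.set(2, embed_dim * qdim); // 2: weight_data_size
    pd.set(6, 0.125f);           // 6: scale (new); omit to keep the default
                                 //    1.f / sqrt(embed_dim / num_heads)

    ncnn::Layer* mha = ncnn::create_layer("MultiHeadAttention");
    mha->load_param(pd);
    // load_model(...) and create_pipeline(...) would follow as usual
    return mha;
}
```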
5 changes: 1 addition & 4 deletions src/layer/arm/multiheadattention_arm.cpp
@@ -60,12 +60,9 @@ int MultiHeadAttention_arm::create_pipeline(const Option& _opt)
const int qdim = weight_data_size / embed_dim;

{
-const int embed_dim_per_head = embed_dim / num_heads;
-const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
q_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
ncnn::ParamDict pd;
-pd.set(0, inv_sqrt_embed_dim_per_head);
+pd.set(0, scale);
pd.set(1, 1.f);
pd.set(2, 0); // transA
pd.set(3, 1); // transB
8 changes: 4 additions & 4 deletions src/layer/multiheadattention.cpp
@@ -15,6 +15,7 @@
#include "multiheadattention.h"

#include <float.h>
+#include <math.h>

namespace ncnn {

@@ -30,6 +31,7 @@ int MultiHeadAttention::load_param(const ParamDict& pd)
kdim = pd.get(3, embed_dim);
vdim = pd.get(4, embed_dim);
attn_mask = pd.get(5, 0);
+scale = pd.get(6, 1.f / sqrtf(embed_dim / num_heads));

return 0;
}
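A quick worked example of that default (hypothetical shape, not taken from the commit); note that embed_dim / num_heads is integer division before the sqrtf:

```
#include <math.h>
#include <stdio.h>

int main()
{
    // hypothetical shape, for illustration only
    const int embed_dim = 64;
    const int num_heads = 4;

    // same expression as the default above; embed_dim / num_heads = 16 (integer division)
    const float scale = 1.f / sqrtf(embed_dim / num_heads);

    printf("default scale = %f\n", scale); // prints 0.250000
    return 0;
}
```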
@@ -111,12 +113,10 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
if (xqkv.empty())
return -100;

-const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_heads; q++)
{
-// xq = affine(q) * inv_sqrt_embed_dim_per_head
+// xq = affine(q) * scale
{
Mat outm = xq.channel(q);

@@ -135,7 +135,7 @@ int MultiHeadAttention::forward(const std::vector<Mat>& bottom_blobs, std::vecto
sum += *ptr++ * *kptr++;
}

-outptr[j] = sum * inv_sqrt_embed_dim_per_head;
+outptr[j] = sum * scale;
}
}
}
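For orientation, a standalone sketch (assumed names, not the ncnn kernel) of where a configurable factor lands in scaled dot-product attention scores. The layer itself folds the same factor into xq = affine(q) * scale, which is equivalent since the factor is linear, and the optimized arm/x86/vulkan backends in this commit pass it to their Gemm via pd.set(0, scale).

```
#include <vector>

// q, k: row-major [seqlen x d] for one head; scores out: [seqlen x seqlen]
static void attention_scores(const std::vector<float>& q,
                             const std::vector<float>& k,
                             std::vector<float>& scores,
                             int seqlen, int d, float scale)
{
    scores.assign(seqlen * seqlen, 0.f);
    for (int i = 0; i < seqlen; i++)
    {
        for (int j = 0; j < seqlen; j++)
        {
            float sum = 0.f;
            for (int c = 0; c < d; c++)
                sum += q[i * d + c] * k[j * d + c];

            // the configurable scale replaces the hard-coded
            // 1 / sqrt(embed_dim_per_head)
            scores[i * seqlen + j] = sum * scale;
        }
    }
}
```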
1 change: 1 addition & 0 deletions src/layer/multiheadattention.h
@@ -37,6 +37,7 @@ class MultiHeadAttention : public Layer
int kdim;
int vdim;
int attn_mask;
+float scale;

Mat q_weight_data;
Mat q_bias_data;
4 changes: 1 addition & 3 deletions src/layer/vulkan/multiheadattention_vulkan.cpp
@@ -48,12 +48,10 @@ int MultiHeadAttention_vulkan::create_pipeline(const Option& opt)
const int embed_dim_per_head = embed_dim / num_heads;
const int qdim = weight_data_size / embed_dim;
{
-const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
q_gemm = ncnn::create_layer_vulkan(ncnn::LayerType::Gemm);
q_gemm->vkdev = vkdev;
ncnn::ParamDict pd;
-pd.set(0, inv_sqrt_embed_dim_per_head);
+pd.set(0, scale);
pd.set(1, 1.f);
pd.set(2, 0); // transA
pd.set(3, 1); // transB
5 changes: 1 addition & 4 deletions src/layer/x86/multiheadattention_x86.cpp
@@ -41,12 +41,9 @@ int MultiHeadAttention_x86::create_pipeline(const Option& opt)
const int qdim = weight_data_size / embed_dim;

{
-const int embed_dim_per_head = embed_dim / num_heads;
-const float inv_sqrt_embed_dim_per_head = 1.f / sqrtf(embed_dim_per_head);
-
q_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm);
ncnn::ParamDict pd;
-pd.set(0, inv_sqrt_embed_dim_per_head);
+pd.set(0, scale);
pd.set(1, 1.f);
pd.set(2, 0); // transA
pd.set(3, 1); // transB
1 change: 1 addition & 0 deletions tests/test_multiheadattention.cpp
@@ -106,6 +106,7 @@ static int test_multiheadattention_sameqkv(const ncnn::Mat& a, int embed_dim, in
pd.set(2, embed_dim * qdim);
pd.set(3, qdim);
pd.set(4, qdim);
+pd.set(6, 0.6f);

std::vector<ncnn::Mat> weights(8);
weights[0] = RandomMat(embed_dim * qdim);
