Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 28, 2024
1 parent d0d9a90 commit 4adbfeb
Show file tree
Hide file tree
Showing 7 changed files with 1,114 additions and 1,534 deletions.
2 changes: 1 addition & 1 deletion src/layer/arm/gru_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1752,7 +1752,7 @@ int GRU_arm::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>
{
if (elemtype == 1)
{
hidden = bottom_blobs[1].clone();
hidden = bottom_blobs[1].clone(hidden_allocator);
}
if (elemtype == 2)
{
Expand Down
46 changes: 18 additions & 28 deletions src/layer/arm/gru_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,29 +416,25 @@ static void gru_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc
{
kptr[0] = weight_xc_R[i];
kptr[1] = weight_xc_U[i];

kptr += 2;
}

for (int i = 0; i < num_output; i++)
{
kptr[0] = weight_hc_R[i];
kptr[1] = weight_hc_U[i];

kptr += 2;
}

for (int i = 0; i < num_output; i++)
{
kptr[0] = weight_hc_N[i];

kptr += 1;
}

for (int i = 0; i < size; i++)
{
kptr[0] = weight_xc_N[i];

kptr += 1;
}

Expand Down Expand Up @@ -537,8 +533,8 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int32x4_t _gru_Ux0 = vdupq_n_s32(0);
int i = 0;
#if __ARM_FEATURE_DOTPROD
int32x4_t _sum1 = vdupq_n_s32(0);
int32x4_t _sum2 = vdupq_n_s32(0);
int32x4_t _sum3 = vdupq_n_s32(0);
for (; i + 7 < size; i += 8)
{
int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
Expand All @@ -550,13 +546,13 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int8x16_t _w3 = vld1q_s8(kptr + 48);
_gru_Rx0 = vdotq_s32(_gru_Rx0, _w0, _xi0);
_gru_Ux0 = vdotq_s32(_gru_Ux0, _w1, _xi0);
_sum2 = vdotq_s32(_sum2, _w2, _xi1);
_sum3 = vdotq_s32(_sum3, _w3, _xi1);
_sum1 = vdotq_s32(_sum1, _w2, _xi1);
_sum2 = vdotq_s32(_sum2, _w3, _xi1);

kptr += 64;
}
_gru_Rx0 = vaddq_s32(_gru_Rx0, _sum2);
_gru_Ux0 = vaddq_s32(_gru_Ux0, _sum3);
_gru_Rx0 = vaddq_s32(_gru_Rx0, _sum1);
_gru_Ux0 = vaddq_s32(_gru_Ux0, _sum2);
#endif // __ARM_FEATURE_DOTPROD
for (; i + 3 < size; i += 4)
{
Expand Down Expand Up @@ -613,8 +609,8 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int32x4_t _gru_Uh0 = vdupq_n_s32(0);
i = 0;
#if __ARM_FEATURE_DOTPROD
_sum1 = vdupq_n_s32(0);
_sum2 = vdupq_n_s32(0);
_sum3 = vdupq_n_s32(0);
for (; i + 7 < num_output; i += 8)
{
int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
Expand All @@ -626,13 +622,13 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int8x16_t _w3 = vld1q_s8(kptr + 48);
_gru_Rh0 = vdotq_s32(_gru_Rh0, _w0, _h_cont0);
_gru_Uh0 = vdotq_s32(_gru_Uh0, _w1, _h_cont0);
_sum2 = vdotq_s32(_sum2, _w2, _h_cont1);
_sum3 = vdotq_s32(_sum3, _w3, _h_cont1);
_sum1 = vdotq_s32(_sum1, _w2, _h_cont1);
_sum2 = vdotq_s32(_sum2, _w3, _h_cont1);

kptr += 64;
}
_gru_Rh0 = vaddq_s32(_gru_Rh0, _sum2);
_gru_Uh0 = vaddq_s32(_gru_Uh0, _sum3);
_gru_Rh0 = vaddq_s32(_gru_Rh0, _sum1);
_gru_Uh0 = vaddq_s32(_gru_Uh0, _sum2);
#endif // __ARM_FEATURE_DOTPROD
for (; i + 3 < num_output; i += 4)
{
Expand Down Expand Up @@ -713,7 +709,7 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int32x4_t _gru_Nh0 = vdupq_n_s32(0);
i = 0;
#if __ARM_FEATURE_DOTPROD
_sum2 = vdupq_n_s32(0);
_sum1 = vdupq_n_s32(0);
for (; i + 7 < num_output; i += 8)
{
int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i));
Expand All @@ -722,11 +718,11 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_gru_Nh0 = vdotq_s32(_gru_Nh0, _w0, _h_cont0);
_sum2 = vdotq_s32(_sum2, _w1, _h_cont1);
_sum1 = vdotq_s32(_sum1, _w1, _h_cont1);

kptr += 32;
}
_gru_Nh0 = vaddq_s32(_gru_Nh0, _sum2);
_gru_Nh0 = vaddq_s32(_gru_Nh0, _sum1);
#endif // __ARM_FEATURE_DOTPROD
for (; i + 3 < num_output; i += 4)
{
Expand Down Expand Up @@ -771,7 +767,7 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int32x4_t _gru_Nx0 = vdupq_n_s32(0);
i = 0;
#if __ARM_FEATURE_DOTPROD
_sum2 = vdupq_n_s32(0);
_sum1 = vdupq_n_s32(0);
for (; i + 7 < size; i += 8)
{
int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i));
Expand All @@ -780,11 +776,11 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int8x16_t _w0 = vld1q_s8(kptr);
int8x16_t _w1 = vld1q_s8(kptr + 16);
_gru_Nx0 = vdotq_s32(_gru_Nx0, _w0, _xi0);
_sum2 = vdotq_s32(_sum2, _w1, _xi1);
_sum1 = vdotq_s32(_sum1, _w1, _xi1);

kptr += 32;
}
_gru_Nx0 = vaddq_s32(_gru_Nx0, _sum2);
_gru_Nx0 = vaddq_s32(_gru_Nx0, _sum1);
#endif // __ARM_FEATURE_DOTPROD
for (; i + 3 < size; i += 4)
{
Expand Down Expand Up @@ -910,20 +906,14 @@ static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_de
int Nh = 0;
for (int i = 0; i < num_output; i++)
{
signed char h_cont = hs[i];

Nh += kptr[0] * h_cont;

Nh += kptr[0] * hs[i];
kptr += 1;
}

int Nx = 0;
for (int i = 0; i < size; i++)
{
signed char xi = x[i];

Nx += kptr[0] * xi;

Nx += kptr[0] * x[i];
kptr += 1;
}

Expand Down
Loading

0 comments on commit 4adbfeb

Please sign in to comment.