Skip to content

Commit

Permalink
wip
Browse files: browse the repository at this point in the history
  • Loading branch information
nihui committed Jul 19, 2023
1 parent 205f746 commit 0dab232
Showing 1 changed file with 121 additions and 130 deletions.
251 changes: 121 additions & 130 deletions src/layer/x86/convolution_packed_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -3981,14 +3981,8 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
const int j2 = (ij + 2) % outw;
const int j3 = (ij + 3) % outw;

int sum00 = 0;
int sum01 = 0;
int sum02 = 0;
int sum03 = 0;
int sum10 = 0;
int sum11 = 0;
int sum12 = 0;
int sum13 = 0;
__m128i _sum0 = _mm_setzero_si128();
__m128i _sum1 = _mm_setzero_si128();

#if __AVX512F__
const signed char* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2);
Expand All @@ -4001,10 +3995,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
int q = 0;
#if __AVX512F__
{
__m512i _sum0 = _mm512_setzero_si512();
__m512i _sum1 = _mm512_setzero_si512();
__m512i _sum2 = _mm512_setzero_si512();
__m512i _sum3 = _mm512_setzero_si512();
__m512i _sum00 = _mm512_setzero_si512();
__m512i _sum11 = _mm512_setzero_si512();
__m512i _sum22 = _mm512_setzero_si512();
__m512i _sum33 = _mm512_setzero_si512();
for (; q + 15 < inch; q += 16)
{
const signed char* r0 = bottom_blob.channel(q / elempack).row<const signed char>(i0 * stride_h) + j0 * stride_w * elempack;
Expand Down Expand Up @@ -4067,45 +4061,38 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
__m512i _w = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*)kptr));

#if __AVX512VNNI__
_sum0 = _mm512_dpwssd_epi32(_sum0, _valval0, _w);
_sum1 = _mm512_dpwssd_epi32(_sum1, _valval1, _w);
_sum2 = _mm512_dpwssd_epi32(_sum2, _valval2, _w);
_sum3 = _mm512_dpwssd_epi32(_sum3, _valval3, _w);
_sum00 = _mm512_dpwssd_epi32(_sum00, _valval0, _w);
_sum11 = _mm512_dpwssd_epi32(_sum11, _valval1, _w);
_sum22 = _mm512_dpwssd_epi32(_sum22, _valval2, _w);
_sum33 = _mm512_dpwssd_epi32(_sum33, _valval3, _w);
#else
_sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_valval0, _w));
_sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_valval1, _w));
_sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_valval2, _w));
_sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_valval3, _w));
_sum00 = _mm512_add_epi32(_sum00, _mm512_madd_epi16(_valval0, _w));
_sum11 = _mm512_add_epi32(_sum11, _mm512_madd_epi16(_valval1, _w));
_sum22 = _mm512_add_epi32(_sum22, _mm512_madd_epi16(_valval2, _w));
_sum33 = _mm512_add_epi32(_sum33, _mm512_madd_epi16(_valval3, _w));
#endif // __AVX512VNNI__

kptr += 32;
}
}
__m256i _sum010 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum0, 0), _mm512_extracti64x4_epi64(_sum0, 1));
__m256i _sum011 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum1, 0), _mm512_extracti64x4_epi64(_sum1, 1));
__m256i _sum012 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum2, 0), _mm512_extracti64x4_epi64(_sum2, 1));
__m256i _sum013 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum3, 0), _mm512_extracti64x4_epi64(_sum3, 1));
__m128i _ss0 = _mm_add_epi32(_mm256_extracti128_si256(_sum010, 0), _mm256_extracti128_si256(_sum010, 1));
__m128i _ss1 = _mm_add_epi32(_mm256_extracti128_si256(_sum011, 0), _mm256_extracti128_si256(_sum011, 1));
__m128i _ss2 = _mm_add_epi32(_mm256_extracti128_si256(_sum012, 0), _mm256_extracti128_si256(_sum012, 1));
__m128i _ss3 = _mm_add_epi32(_mm256_extracti128_si256(_sum013, 0), _mm256_extracti128_si256(_sum013, 1));

sum00 += _mm_extract_epi32(_ss0, 0) + _mm_extract_epi32(_ss0, 1);
sum10 += _mm_extract_epi32(_ss0, 2) + _mm_extract_epi32(_ss0, 3);
sum01 += _mm_extract_epi32(_ss1, 0) + _mm_extract_epi32(_ss1, 1);
sum11 += _mm_extract_epi32(_ss1, 2) + _mm_extract_epi32(_ss1, 3);
sum02 += _mm_extract_epi32(_ss2, 0) + _mm_extract_epi32(_ss2, 1);
sum12 += _mm_extract_epi32(_ss2, 2) + _mm_extract_epi32(_ss2, 3);
sum03 += _mm_extract_epi32(_ss3, 0) + _mm_extract_epi32(_ss3, 1);
sum13 += _mm_extract_epi32(_ss3, 2) + _mm_extract_epi32(_ss3, 3);
__m256i _sum010 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum00, 0), _mm512_extracti64x4_epi64(_sum11, 0));
__m256i _sum230 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum22, 0), _mm512_extracti64x4_epi64(_sum33, 0));
__m256i _sum011 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum00, 1), _mm512_extracti64x4_epi64(_sum11, 1));
__m256i _sum231 = _mm256_hadd_epi32(_mm512_extracti64x4_epi64(_sum22, 1), _mm512_extracti64x4_epi64(_sum33, 1));
__m256i _ss0 = _mm256_hadd_epi32(_sum010, _sum230);
__m256i _ss1 = _mm256_hadd_epi32(_sum011, _sum231);
_sum0 = _mm_add_epi32(_sum0, _mm256_extracti128_si256(_ss0, 0));
_sum0 = _mm_add_epi32(_sum0, _mm256_extracti128_si256(_ss0, 1));
_sum1 = _mm_add_epi32(_sum1, _mm256_extracti128_si256(_ss1, 0));
_sum1 = _mm_add_epi32(_sum1, _mm256_extracti128_si256(_ss1, 1));
}
#endif // __AVX512F__
{
#if __AVX2__
__m256i _sum0 = _mm256_setzero_si256();
__m256i _sum1 = _mm256_setzero_si256();
__m256i _sum2 = _mm256_setzero_si256();
__m256i _sum3 = _mm256_setzero_si256();
__m256i _sum00 = _mm256_setzero_si256();
__m256i _sum11 = _mm256_setzero_si256();
__m256i _sum22 = _mm256_setzero_si256();
__m256i _sum33 = _mm256_setzero_si256();
#else
__m128i _sum00 = _mm_setzero_si128();
__m128i _sum10 = _mm_setzero_si128();
Expand Down Expand Up @@ -4193,15 +4180,15 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
__m256i _w = _mm256_cvtepi8_epi16(_w01);

#if __AVXVNNI__ || __AVX512VNNI__
_sum0 = _mm256_dpwssd_epi32(_sum0, _valval0, _w);
_sum1 = _mm256_dpwssd_epi32(_sum1, _valval1, _w);
_sum2 = _mm256_dpwssd_epi32(_sum2, _valval2, _w);
_sum3 = _mm256_dpwssd_epi32(_sum3, _valval3, _w);
_sum00 = _mm256_dpwssd_epi32(_sum00, _valval0, _w);
_sum11 = _mm256_dpwssd_epi32(_sum11, _valval1, _w);
_sum22 = _mm256_dpwssd_epi32(_sum22, _valval2, _w);
_sum33 = _mm256_dpwssd_epi32(_sum33, _valval3, _w);
#else
_sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_valval0, _w));
_sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_valval1, _w));
_sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_valval2, _w));
_sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_valval3, _w));
_sum00 = _mm256_add_epi32(_sum00, _mm256_madd_epi16(_valval0, _w));
_sum11 = _mm256_add_epi32(_sum11, _mm256_madd_epi16(_valval1, _w));
_sum22 = _mm256_add_epi32(_sum22, _mm256_madd_epi16(_valval2, _w));
_sum33 = _mm256_add_epi32(_sum33, _mm256_madd_epi16(_valval3, _w));
#endif
#else // __AVX2__
__m128i _extw01 = _mm_cmpgt_epi8(_mm_setzero_si128(), _w01);
Expand Down Expand Up @@ -4233,23 +4220,37 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
}
}
#if __AVX2__
sum00 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum0, 0));
sum10 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum0, 1));
sum01 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum1, 0));
sum11 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum1, 1));
sum02 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum2, 0));
sum12 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum2, 1));
sum03 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum3, 0));
sum13 += _mm_reduce_add_epi32(_mm256_extracti128_si256(_sum3, 1));
__m256i _sum01 = _mm256_hadd_epi32(_sum00, _sum11);
__m256i _sum23 = _mm256_hadd_epi32(_sum22, _sum33);
__m256i _ss = _mm256_hadd_epi32(_sum01, _sum23);
_sum0 = _mm_add_epi32(_sum0, _mm256_extracti128_si256(_ss, 0));
_sum1 = _mm_add_epi32(_sum1, _mm256_extracti128_si256(_ss, 1));
#else
sum00 += _mm_reduce_add_epi32(_sum00);
sum10 += _mm_reduce_add_epi32(_sum10);
sum01 += _mm_reduce_add_epi32(_sum01);
sum11 += _mm_reduce_add_epi32(_sum11);
sum02 += _mm_reduce_add_epi32(_sum02);
sum12 += _mm_reduce_add_epi32(_sum12);
sum03 += _mm_reduce_add_epi32(_sum03);
sum13 += _mm_reduce_add_epi32(_sum13);
// transpose 4x4
__m128i _tmp00 = _mm_unpacklo_epi32(_sum00, _sum01);
__m128i _tmp01 = _mm_unpacklo_epi32(_sum02, _sum03);
__m128i _tmp02 = _mm_unpackhi_epi32(_sum00, _sum01);
__m128i _tmp03 = _mm_unpackhi_epi32(_sum02, _sum03);
__m128i _tmp10 = _mm_unpacklo_epi32(_sum10, _sum11);
__m128i _tmp11 = _mm_unpacklo_epi32(_sum12, _sum13);
__m128i _tmp12 = _mm_unpackhi_epi32(_sum10, _sum11);
__m128i _tmp13 = _mm_unpackhi_epi32(_sum12, _sum13);
_sum00 = _mm_unpacklo_epi64(_tmp00, _tmp01);
_sum01 = _mm_unpackhi_epi64(_tmp00, _tmp01);
_sum02 = _mm_unpacklo_epi64(_tmp02, _tmp03);
_sum03 = _mm_unpackhi_epi64(_tmp02, _tmp03);
_sum10 = _mm_unpacklo_epi64(_tmp10, _tmp11);
_sum11 = _mm_unpackhi_epi64(_tmp10, _tmp11);
_sum12 = _mm_unpacklo_epi64(_tmp12, _tmp13);
_sum13 = _mm_unpackhi_epi64(_tmp12, _tmp13);
_sum00 = _mm_add_epi32(_sum00, _sum01);
_sum02 = _mm_add_epi32(_sum02, _sum03);
_sum10 = _mm_add_epi32(_sum10, _sum11);
_sum12 = _mm_add_epi32(_sum12, _sum13);
_sum0 = _mm_add_epi32(_sum0, _sum00);
_sum0 = _mm_add_epi32(_sum0, _sum02);
_sum1 = _mm_add_epi32(_sum1, _sum10);
_sum1 = _mm_add_epi32(_sum1, _sum12);
#endif // __AVX2__
}
for (; q + 1 < inch; q += 2)
Expand All @@ -4268,22 +4269,12 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const

// if (elempack == 1)
{
sum00 += r0s[0] * kptr[0];
sum10 += r0s[0] * kptr[1];
sum00 += r0s[N] * kptr[2];
sum10 += r0s[N] * kptr[3];
sum01 += r1s[0] * kptr[0];
sum11 += r1s[0] * kptr[1];
sum01 += r1s[N] * kptr[2];
sum11 += r1s[N] * kptr[3];
sum02 += r2s[0] * kptr[0];
sum12 += r2s[0] * kptr[1];
sum02 += r2s[N] * kptr[2];
sum12 += r2s[N] * kptr[3];
sum03 += r3s[0] * kptr[0];
sum13 += r3s[0] * kptr[1];
sum03 += r3s[N] * kptr[2];
sum13 += r3s[N] * kptr[3];
__m128i _r = _mm_setr_epi16(r0s[0], r0s[N], r1s[0], r1s[N], r2s[0], r2s[N], r3s[0], r3s[N]);
__m128i _w0 = _mm_setr_epi16(kptr[0], kptr[2], kptr[0], kptr[2], kptr[0], kptr[2], kptr[0], kptr[2]);
__m128i _w1 = _mm_setr_epi16(kptr[1], kptr[3], kptr[1], kptr[3], kptr[1], kptr[3], kptr[1], kptr[3]);

_sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_r, _w0));
_sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_r, _w1));

kptr += 4;
}
Expand All @@ -4305,28 +4296,24 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const

// if (elempack == 1)
{
sum00 += r0s[0] * kptr[0];
sum10 += r0s[0] * kptr[1];
sum01 += r1s[0] * kptr[0];
sum11 += r1s[0] * kptr[1];
sum02 += r2s[0] * kptr[0];
sum12 += r2s[0] * kptr[1];
sum03 += r3s[0] * kptr[0];
sum13 += r3s[0] * kptr[1];
__m128i _r = _mm_setr_epi16(r0s[0], r1s[0], r2s[0], r3s[0], r0s[0], r1s[0], r2s[0], r3s[0]);
__m128i _w = _mm_setr_epi16(kptr[0], kptr[0], kptr[0], kptr[0], kptr[1], kptr[1], kptr[1], kptr[1]);

__m128i _sl = _mm_mullo_epi16(_r, _w);
__m128i _sh = _mm_mulhi_epi16(_r, _w);
__m128i _s0 = _mm_unpacklo_epi16(_sl, _sh);
__m128i _s1 = _mm_unpackhi_epi16(_sl, _sh);

_sum0 = _mm_add_epi32(_sum0, _s0);
_sum1 = _mm_add_epi32(_sum1, _s1);

kptr += 2;
}
}
}

outptr0[0] = sum00;
outptr0[1] = sum01;
outptr0[2] = sum02;
outptr0[3] = sum03;
outptr1[0] = sum10;
outptr1[1] = sum11;
outptr1[2] = sum12;
outptr1[3] = sum13;
_mm_store_si128((__m128i*)outptr0, _sum0);
_mm_store_si128((__m128i*)outptr1, _sum1);
outptr0 += 4;
outptr1 += 4;
}
Expand Down Expand Up @@ -4793,10 +4780,7 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
const int j2 = (ij + 2) % outw;
const int j3 = (ij + 3) % outw;

int sum0 = 0;
int sum1 = 0;
int sum2 = 0;
int sum3 = 0;
__m128i _sum = _mm_setzero_si128();

#if __AVX512F__
const signed char* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2);
Expand Down Expand Up @@ -4886,15 +4870,11 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
kptr += 16;
}
}

__m128i _ss0 = _mm_add_epi32(_mm256_extracti128_si256(_sum0, 0), _mm256_extracti128_si256(_sum0, 1));
__m128i _ss1 = _mm_add_epi32(_mm256_extracti128_si256(_sum1, 0), _mm256_extracti128_si256(_sum1, 1));
__m128i _ss2 = _mm_add_epi32(_mm256_extracti128_si256(_sum2, 0), _mm256_extracti128_si256(_sum2, 1));
__m128i _ss3 = _mm_add_epi32(_mm256_extracti128_si256(_sum3, 0), _mm256_extracti128_si256(_sum3, 1));
sum0 += _mm_reduce_add_epi32(_ss0);
sum1 += _mm_reduce_add_epi32(_ss1);
sum2 += _mm_reduce_add_epi32(_ss2);
sum3 += _mm_reduce_add_epi32(_ss3);
_sum0 = _mm256_hadd_epi32(_sum0, _sum1);
_sum2 = _mm256_hadd_epi32(_sum2, _sum3);
_sum0 = _mm256_hadd_epi32(_sum0, _sum2);
_sum = _mm_add_epi32(_sum, _mm256_extracti128_si256(_sum0, 0));
_sum = _mm_add_epi32(_sum, _mm256_extracti128_si256(_sum0, 1));
}
#endif // __AVX512F__
{
Expand Down Expand Up @@ -4997,10 +4977,24 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const
kptr += 8;
}
}
sum0 += _mm_reduce_add_epi32(_sum0);
sum1 += _mm_reduce_add_epi32(_sum1);
sum2 += _mm_reduce_add_epi32(_sum2);
sum3 += _mm_reduce_add_epi32(_sum3);
#if __SSSE3__
__m128i _ss = _mm_hadd_epi32(_mm_hadd_epi32(_sum0, _sum1), _mm_hadd_epi32(_sum2, _sum3));
_sum = _mm_add_epi32(_sum, _ss);
#else
// transpose 4x4
__m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1);
__m128i _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3);
__m128i _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1);
__m128i _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3);
_sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1);
_sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1);
_sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3);
_sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3);
_sum0 = _mm_add_epi32(_sum0, _sum1);
_sum2 = _mm_add_epi32(_sum2, _sum3);
_sum = _mm_add_epi32(_sum, _sum0);
_sum = _mm_add_epi32(_sum, _sum2);
#endif // __SSSE3__
}
for (; q + 1 < inch; q += 2)
{
Expand All @@ -5018,14 +5012,10 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const

// if (elempack == 1)
{
sum0 += r0s[0] * kptr[0];
sum0 += r0s[N] * kptr[1];
sum1 += r1s[0] * kptr[0];
sum1 += r1s[N] * kptr[1];
sum2 += r2s[0] * kptr[0];
sum2 += r2s[N] * kptr[1];
sum3 += r3s[0] * kptr[0];
sum3 += r3s[N] * kptr[1];
__m128i _r = _mm_setr_epi16(r0s[0], r0s[N], r1s[0], r1s[N], r2s[0], r2s[N], r3s[0], r3s[N]);
__m128i _w = _mm_setr_epi16(kptr[0], kptr[1], kptr[0], kptr[1], kptr[0], kptr[1], kptr[0], kptr[1]);

_sum = _mm_add_epi32(_sum, _mm_madd_epi16(_r, _w));

kptr += 2;
}
Expand All @@ -5047,20 +5037,21 @@ static void convolution_packed_int8(const Mat& bottom_blob, Mat& top_blob, const

// if (elempack == 1)
{
sum0 += r0s[0] * kptr[0];
sum1 += r1s[0] * kptr[0];
sum2 += r2s[0] * kptr[0];
sum3 += r3s[0] * kptr[0];
__m128i _r = _mm_setr_epi16(r0s[0], r1s[0], r2s[0], r3s[0], 0, 0, 0, 0);
__m128i _w = _mm_set1_epi16(kptr[0]);

__m128i _sl = _mm_mullo_epi16(_r, _w);
__m128i _sh = _mm_mulhi_epi16(_r, _w);
__m128i _s0 = _mm_unpacklo_epi16(_sl, _sh);

_sum = _mm_add_epi32(_sum, _s0);

kptr += 1;
}
}
}

outptr[0] = sum0;
outptr[1] = sum1;
outptr[2] = sum2;
outptr[3] = sum3;
_mm_store_si128((__m128i*)outptr, _sum);
outptr += 4;
}
#endif // __SSE2__
Expand Down

0 comments on commit 0dab232

Please sign in to comment.