From e13ab3e75d140e16cce972ada7c0dc1fa5df931d Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 29 Nov 2024 11:48:42 +0000 Subject: [PATCH] opt 8x8 --- src/layer/x86/gemm_int8.h | 585 +++++++++++++------------------------- 1 file changed, 204 insertions(+), 381 deletions(-) diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index 0ff42303ffb..ce3189b0ba7 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -14697,12 +14697,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& // from // 00 11 22 33 44 55 66 77 // 01 12 23 30 45 56 67 74 - // 60 71 42 53 24 35 06 17 - // 61 72 43 50 25 36 07 14 - // 02 13 20 31 46 57 64 75 - // 03 10 21 32 47 54 65 76 - // 62 73 40 51 26 37 04 15 - // 63 70 41 52 27 34 05 16 + // 20 31 02 13 64 75 46 57 + // 21 32 03 10 65 76 47 54 + // 04 15 26 37 40 51 62 73 + // 05 16 27 34 41 52 63 70 + // 24 35 06 17 60 71 42 53 + // 25 36 07 14 61 72 43 50 // to // 00 10 20 30 40 50 60 70 @@ -14718,79 +14718,42 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& __m256 _tmp1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); __m256 _tmp2 = _f2; __m256 _tmp3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); - __m256 _tmp4 = _mm256_shuffle_ps(_f4, _f4, _MM_SHUFFLE(1, 0, 3, 2)); - __m256 _tmp5 = _mm256_shuffle_ps(_f5, _f5, _MM_SHUFFLE(0, 3, 2, 1)); - __m256 _tmp6 = _mm256_shuffle_ps(_f6, _f6, _MM_SHUFFLE(1, 0, 3, 2)); - __m256 _tmp7 = _mm256_shuffle_ps(_f7, _f7, _MM_SHUFFLE(0, 3, 2, 1)); - - // 00 11 22 33 44 55 66 77 - // 30 01 12 23 74 45 56 67 - // 60 71 42 53 24 35 06 17 - // 50 61 72 43 14 25 36 07 - // 20 31 02 13 64 75 46 57 - // 10 21 32 03 54 65 76 47 - // 40 51 62 73 04 15 26 37 - // 70 41 52 63 34 05 16 27 - - _f0 = _mm256_permute2f128_ps(_tmp0, _tmp6, _MM_SHUFFLE(0, 2, 0, 0)); - _f1 = _mm256_permute2f128_ps(_tmp5, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); - _f2 = _mm256_permute2f128_ps(_tmp4, _tmp2, _MM_SHUFFLE(0, 2, 0, 0)); - _f3 = _mm256_permute2f128_ps(_tmp1, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); - _f4 = _mm256_permute2f128_ps(_tmp6, _tmp0, _MM_SHUFFLE(0, 3, 0, 1)); - _f5 = _mm256_permute2f128_ps(_tmp3, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); - _f6 = _mm256_permute2f128_ps(_tmp2, _tmp4, _MM_SHUFFLE(0, 3, 0, 1)); - _f7 = _mm256_permute2f128_ps(_tmp7, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); - - // 00 11 22 33 40 51 62 73 - // 10 21 32 03 50 61 72 43 - // 20 31 02 13 60 71 42 53 - // 30 01 12 23 70 41 52 63 - // 04 15 26 37 44 55 66 77 - // 14 25 36 07 54 65 76 47 - // 24 35 06 17 64 75 46 57 - // 34 05 16 27 74 45 56 67 - - _tmp0 = _mm256_unpacklo_ps(_f0, _f1); - _tmp1 = _mm256_unpacklo_ps(_f2, _f3); - _tmp2 = _mm256_unpackhi_ps(_f2, _f3); - _tmp3 = _mm256_unpackhi_ps(_f0, _f1); - _tmp4 = _mm256_unpacklo_ps(_f4, _f5); - _tmp5 = _mm256_unpacklo_ps(_f6, _f7); - _tmp6 = _mm256_unpackhi_ps(_f6, _f7); - _tmp7 = _mm256_unpackhi_ps(_f4, _f5); - - // 00 10 11 21 40 50 51 61 - // 20 30 31 01 60 70 71 41 - // 02 12 13 23 42 52 53 63 - // 22 32 33 03 62 72 73 43 - - // 04 14 15 25 44 54 55 65 - // 24 34 35 05 64 74 75 45 - // 06 16 17 27 46 56 57 67 - // 26 36 37 07 66 76 77 47 - - _f0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp1))); - _f1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp0), _mm256_castps_pd(_tmp1))); - _f2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp2), _mm256_castps_pd(_tmp3))); - _f3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp2), _mm256_castps_pd(_tmp3))); - _f4 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp4), _mm256_castps_pd(_tmp5))); - _f5 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp4), _mm256_castps_pd(_tmp5))); - _f6 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_tmp6), _mm256_castps_pd(_tmp7))); - _f7 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_tmp6), _mm256_castps_pd(_tmp7))); - - // 00 10 20 30 40 50 60 70 - // 11 21 31 01 51 61 71 41 - // 02 12 22 32 42 52 62 72 - // 13 23 33 03 53 63 73 43 - // 04 14 24 34 44 54 64 74 - // 15 25 35 05 55 65 75 45 - // 06 16 26 36 46 56 66 76 - // 17 27 37 07 57 67 77 47 + __m256 _tmp4 = _f4; + __m256 _tmp5 = _mm256_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); + __m256 _tmp6 = _f6; + __m256 _tmp7 = _mm256_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); - _f1 = _mm256_shuffle_ps(_f1, _f1, _MM_SHUFFLE(2, 1, 0, 3)); - _f3 = _mm256_shuffle_ps(_f3, _f3, _MM_SHUFFLE(2, 1, 0, 3)); - _f5 = _mm256_shuffle_ps(_f5, _f5, _MM_SHUFFLE(2, 1, 0, 3)); - _f7 = _mm256_shuffle_ps(_f7, _f7, _MM_SHUFFLE(2, 1, 0, 3)); + _f0 = _mm256_unpacklo_ps(_tmp0, _tmp3); + _f1 = _mm256_unpackhi_ps(_tmp0, _tmp3); + _f2 = _mm256_unpacklo_ps(_tmp2, _tmp1); + _f3 = _mm256_unpackhi_ps(_tmp2, _tmp1); + _f4 = _mm256_unpacklo_ps(_tmp4, _tmp7); + _f5 = _mm256_unpackhi_ps(_tmp4, _tmp7); + _f6 = _mm256_unpacklo_ps(_tmp6, _tmp5); + _f7 = _mm256_unpackhi_ps(_tmp6, _tmp5); + + _tmp0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f0), _mm256_castps_pd(_f2))); + _tmp1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f0), _mm256_castps_pd(_f2))); + _tmp2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f3), _mm256_castps_pd(_f1))); + _tmp3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f3), _mm256_castps_pd(_f1))); + _tmp4 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f4), _mm256_castps_pd(_f6))); + _tmp5 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f4), _mm256_castps_pd(_f6))); + _tmp6 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_f7), _mm256_castps_pd(_f5))); + _tmp7 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_f7), _mm256_castps_pd(_f5))); + + _tmp1 = _mm256_shuffle_ps(_tmp1, _tmp1, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp3 = _mm256_shuffle_ps(_tmp3, _tmp3, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp5 = _mm256_shuffle_ps(_tmp5, _tmp5, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp7 = _mm256_shuffle_ps(_tmp7, _tmp7, _MM_SHUFFLE(2, 1, 0, 3)); + + _f0 = _mm256_permute2f128_ps(_tmp0, _tmp4, _MM_SHUFFLE(0, 3, 0, 0)); + _f1 = _mm256_permute2f128_ps(_tmp1, _tmp5, _MM_SHUFFLE(0, 3, 0, 0)); + _f2 = _mm256_permute2f128_ps(_tmp2, _tmp6, _MM_SHUFFLE(0, 3, 0, 0)); + _f3 = _mm256_permute2f128_ps(_tmp3, _tmp7, _MM_SHUFFLE(0, 3, 0, 0)); + _f4 = _mm256_permute2f128_ps(_tmp4, _tmp0, _MM_SHUFFLE(0, 3, 0, 0)); + _f5 = _mm256_permute2f128_ps(_tmp5, _tmp1, _MM_SHUFFLE(0, 3, 0, 0)); + _f6 = _mm256_permute2f128_ps(_tmp6, _tmp2, _MM_SHUFFLE(0, 3, 0, 0)); + _f7 = _mm256_permute2f128_ps(_tmp7, _tmp3, _MM_SHUFFLE(0, 3, 0, 0)); } #else // __AVX2__ __m256 _f0 = _mm256_cvtepi32_ps(_mm256_loadu_si256((const __m256i*)pp)); @@ -17727,22 +17690,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA0, _pB2)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA0, _pB3)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); - _sum8 = _mm512_add_epi32(_sum8, _mm512_madd_epi16(_pA2, _pB0)); - _sum9 = _mm512_add_epi32(_sum9, _mm512_madd_epi16(_pA2, _pB1)); - _suma = _mm512_add_epi32(_suma, _mm512_madd_epi16(_pA3, _pB0)); - _sumb = _mm512_add_epi32(_sumb, _mm512_madd_epi16(_pA3, _pB1)); - _sumc = _mm512_add_epi32(_sumc, _mm512_madd_epi16(_pA2, _pB2)); - _sumd = _mm512_add_epi32(_sumd, _mm512_madd_epi16(_pA2, _pB3)); - _sume = _mm512_add_epi32(_sume, _mm512_madd_epi16(_pA3, _pB2)); - _sumf = _mm512_add_epi32(_sumf, _mm512_madd_epi16(_pA3, _pB3)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); + _sum8 = _mm512_comp_dpwssd_epi32(_sum8, _pA2, _pB0); + _sum9 = _mm512_comp_dpwssd_epi32(_sum9, _pA2, _pB1); + _suma = _mm512_comp_dpwssd_epi32(_suma, _pA3, _pB0); + _sumb = _mm512_comp_dpwssd_epi32(_sumb, _pA3, _pB1); + _sumc = _mm512_comp_dpwssd_epi32(_sumc, _pA2, _pB2); + _sumd = _mm512_comp_dpwssd_epi32(_sumd, _pA2, _pB3); + _sume = _mm512_comp_dpwssd_epi32(_sume, _pA3, _pB2); + _sumf = _mm512_comp_dpwssd_epi32(_sumf, _pA3, _pB3); pA += 32; pB += 32; @@ -17771,39 +17734,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); - __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); - __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); - __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); - __m512i _s4 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2)); - __m512i _s5 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3)); - __m512i _s6 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2)); - __m512i _s7 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3)); - __m512i _s8 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB0)); - __m512i _s9 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB1)); - __m512i _sa = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB0)); - __m512i _sb = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB1)); - __m512i _sc = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB2)); - __m512i _sd = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB3)); - __m512i _se = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB2)); - __m512i _sf = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB3)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); - _sum1 = _mm512_add_epi32(_sum1, _s1); - _sum2 = _mm512_add_epi32(_sum2, _s2); - _sum3 = _mm512_add_epi32(_sum3, _s3); - _sum4 = _mm512_add_epi32(_sum4, _s4); - _sum5 = _mm512_add_epi32(_sum5, _s5); - _sum6 = _mm512_add_epi32(_sum6, _s6); - _sum7 = _mm512_add_epi32(_sum7, _s7); - _sum8 = _mm512_add_epi32(_sum8, _s8); - _sum9 = _mm512_add_epi32(_sum9, _s9); - _suma = _mm512_add_epi32(_suma, _sa); - _sumb = _mm512_add_epi32(_sumb, _sb); - _sumc = _mm512_add_epi32(_sumc, _sc); - _sumd = _mm512_add_epi32(_sumd, _sd); - _sume = _mm512_add_epi32(_sume, _se); - _sumf = _mm512_add_epi32(_sumf, _sf); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); + _sum4 = _mm512_add_epi32(_sum4, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2))); + _sum5 = _mm512_add_epi32(_sum5, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3))); + _sum6 = _mm512_add_epi32(_sum6, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2))); + _sum7 = _mm512_add_epi32(_sum7, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3))); + _sum8 = _mm512_add_epi32(_sum8, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB0))); + _sum9 = _mm512_add_epi32(_sum9, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB1))); + _suma = _mm512_add_epi32(_suma, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB0))); + _sumb = _mm512_add_epi32(_sumb, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB1))); + _sumc = _mm512_add_epi32(_sumc, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB2))); + _sumd = _mm512_add_epi32(_sumd, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA2, _pB3))); + _sume = _mm512_add_epi32(_sume, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB2))); + _sumf = _mm512_add_epi32(_sumf, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA3, _pB3))); pA += 16; pB += 16; @@ -17921,14 +17867,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA0, _pB2)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA0, _pB3)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA1, _pB3); pA += 32; pB += 16; @@ -17954,23 +17900,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); - __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); - __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); - __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); - __m512i _s4 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2)); - __m512i _s5 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3)); - __m512i _s6 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2)); - __m512i _s7 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); - _sum1 = _mm512_add_epi32(_sum1, _s1); - _sum2 = _mm512_add_epi32(_sum2, _s2); - _sum3 = _mm512_add_epi32(_sum3, _s3); - _sum4 = _mm512_add_epi32(_sum4, _s4); - _sum5 = _mm512_add_epi32(_sum5, _s5); - _sum6 = _mm512_add_epi32(_sum6, _s6); - _sum7 = _mm512_add_epi32(_sum7, _s7); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); + _sum4 = _mm512_add_epi32(_sum4, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB2))); + _sum5 = _mm512_add_epi32(_sum5, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB3))); + _sum6 = _mm512_add_epi32(_sum6, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB2))); + _sum7 = _mm512_add_epi32(_sum7, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB3))); pA += 16; pB += 8; @@ -18053,10 +17990,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 1230 1230 1230 1230 __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 32; pB += 8; @@ -18077,15 +18014,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 12301230 12301230 __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); - __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); - __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); - __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); - _sum1 = _mm512_add_epi32(_sum1, _s1); - _sum2 = _mm512_add_epi32(_sum2, _s2); - _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); pA += 16; pB += 4; @@ -18149,8 +18081,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 1010 1010 1010 1010 __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); pA += 32; pB += 4; @@ -18169,11 +18101,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 10101010 10101010 __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 1, 0, 1)), _MM_SHUFFLE(0, 1, 0, 1)); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); - __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); - _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); pA += 16; pB += 2; @@ -18226,7 +18155,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0xxx0xxx0xxx0xxx -> 00000000... __m512i _pB0 = _mm512_shuffle_epi32(_pBBBB, _MM_PERM_AAAA); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); pA += 32; pB += 2; @@ -18238,9 +18167,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB))); pA += 16; pB += 1; @@ -18369,14 +18296,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); // __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB1)); - _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA00, _pB2)); - _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA00, _pB3)); - _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA11, _pB2)); - _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA11, _pB3)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA00, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA00, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA11, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA11, _pB1); + _sum4 = _mm512_comp_dpwssd_epi32(_sum4, _pA00, _pB2); + _sum5 = _mm512_comp_dpwssd_epi32(_sum5, _pA00, _pB3); + _sum6 = _mm512_comp_dpwssd_epi32(_sum6, _pA11, _pB2); + _sum7 = _mm512_comp_dpwssd_epi32(_sum7, _pA11, _pB3); pA += 16; pB += 32; @@ -18410,23 +18337,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); // __m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(2, 1, 0, 3)), _MM_SHUFFLE(2, 1, 0, 3)); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB0)); - __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB1)); - __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB0)); - __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB1)); - __m512i _s4 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB2)); - __m512i _s5 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB3)); - __m512i _s6 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB2)); - __m512i _s7 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB3)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); - _sum1 = _mm512_add_epi32(_sum1, _s1); - _sum2 = _mm512_add_epi32(_sum2, _s2); - _sum3 = _mm512_add_epi32(_sum3, _s3); - _sum4 = _mm512_add_epi32(_sum4, _s4); - _sum5 = _mm512_add_epi32(_sum5, _s5); - _sum6 = _mm512_add_epi32(_sum6, _s6); - _sum7 = _mm512_add_epi32(_sum7, _s7); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB1))); + _sum4 = _mm512_add_epi32(_sum4, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB2))); + _sum5 = _mm512_add_epi32(_sum5, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA00, _pB3))); + _sum6 = _mm512_add_epi32(_sum6, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB2))); + _sum7 = _mm512_add_epi32(_sum7, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA11, _pB3))); pA += 8; pB += 16; @@ -18486,10 +18404,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, { __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); - __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(0, 1, 2, 3)); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); - __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); _sum0 = _mm256_comp_dpbusd_epi32(_sum0, _pB0, _pA0); _sum1 = _mm256_comp_dpbusd_epi32(_sum1, _pB1, _pA0); _sum2 = _mm256_comp_dpbusd_epi32(_sum2, _pB0, _pA1); @@ -18524,78 +18442,51 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pA0 = _mm256_cvtepi8_epi16(_pA); __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); - // // 0123 4567 - // // 4567 0123 - // __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); - // 0123 4567 - // 6745 2301 - __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(0, 1, 2, 3)); + // 2301 6745 + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); // 0123 4567 // 1230 5674 - // 2301 6745 - // 3012 7456 + // 4567 0123 + // 5674 1230 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); - __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); - __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB2, _MM_SHUFFLE(0, 3, 2, 1)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); - _sum4 = _mm256_add_epi32(_sum4, _mm256_madd_epi16(_pA0, _pB2)); - _sum5 = _mm256_add_epi32(_sum5, _mm256_madd_epi16(_pA0, _pB3)); - _sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2)); - _sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3)); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = _mm256_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm256_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm256_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm256_comp_dpwssd_epi32(_sum7, _pA1, _pB3); pA += 16; pB += 16; } for (; kk < max_kk; kk += 1) { - __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); - __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); - - _pA = _mm_cvtepi8_epi16(_pA); - _pB = _mm_cvtepi8_epi16(_pB); + __m128i _pA0 = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB0 = _mm_loadl_epi64((const __m128i*)pB); - // // 0123 4567 - // // 4567 0123 - // __m128i _pA0 = _pA; - // __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(1, 0, 3, 2)); + _pA0 = _mm_cvtepi8_epi16(_pA0); + _pB0 = _mm_cvtepi8_epi16(_pB0); - // 0123 4567 - // 6745 2301 - __m128i _pA0 = _pA; - __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(0, 1, 2, 3)); + __m128i _pA1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pA0, _MM_SHUFFLE(1, 0, 3, 2)), _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB3 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - // 0123 4567 - // 1230 5674 - // 2301 6745 - // 3012 7456 - __m128i _pB0 = _pB; - __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m128i _pB2 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(2, 3, 0, 1)); - __m128i _pB3 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(2, 1, 0, 3)), _MM_SHUFFLE(2, 1, 0, 3)); - - __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0)); - __m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1)); - __m256i _s2 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0)); - __m256i _s3 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1)); - __m256i _s4 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB2)); - __m256i _s5 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB3)); - __m256i _s6 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB2)); - __m256i _s7 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB3)); - - _sum0 = _mm256_add_epi32(_sum0, _s0); - _sum1 = _mm256_add_epi32(_sum1, _s1); - _sum2 = _mm256_add_epi32(_sum2, _s2); - _sum3 = _mm256_add_epi32(_sum3, _s3); - _sum4 = _mm256_add_epi32(_sum4, _s4); - _sum5 = _mm256_add_epi32(_sum5, _s5); - _sum6 = _mm256_add_epi32(_sum6, _s6); - _sum7 = _mm256_add_epi32(_sum7, _s7); + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm256_add_epi32(_sum1, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm256_add_epi32(_sum2, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm256_add_epi32(_sum3, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1))); + _sum4 = _mm256_add_epi32(_sum4, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB2))); + _sum5 = _mm256_add_epi32(_sum5, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB3))); + _sum6 = _mm256_add_epi32(_sum6, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB2))); + _sum7 = _mm256_add_epi32(_sum7, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB3))); pA += 8; pB += 8; @@ -18680,10 +18571,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 1230 1230 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); - _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); - _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 16; pB += 8; @@ -18706,15 +18597,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB0 = _pB; __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0)); - __m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1)); - __m256i _s2 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0)); - __m256i _s3 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1)); - - _sum0 = _mm256_add_epi32(_sum0, _s0); - _sum1 = _mm256_add_epi32(_sum1, _s1); - _sum2 = _mm256_add_epi32(_sum2, _s2); - _sum3 = _mm256_add_epi32(_sum3, _s3); + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm256_add_epi32(_sum1, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm256_add_epi32(_sum2, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm256_add_epi32(_sum3, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA1, _pB1))); pA += 8; pB += 4; @@ -18779,8 +18665,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 1010 1010 __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); - _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_comp_dpwssd_epi32(_sum1, _pA0, _pB1); pA += 16; pB += 4; @@ -18800,11 +18686,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB0 = _pB; __m128i _pB1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 1, 0, 1)), _MM_SHUFFLE(0, 1, 0, 1)); - __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB0)); - __m256i _s1 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB1)); - - _sum0 = _mm256_add_epi32(_sum0, _s0); - _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB0))); + _sum1 = _mm256_add_epi32(_sum1, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB1))); pA += 8; pB += 2; @@ -18867,7 +18750,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0xxx0xxx -> 00000000 11111111 __m256i _pB0 = _mm256_shuffle_epi32(_pBB, _MM_SHUFFLE(0, 0, 0, 0)); - _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum0 = _mm256_comp_dpwssd_epi32(_sum0, _pA0, _pB0); pA += 16; pB += 2; @@ -18879,9 +18762,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, _pA = _mm_cvtepi8_epi16(_pA); - __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB)); - - _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum0 = _mm256_add_epi32(_sum0, _mm256_cvtepi16_epi32(_mm_mullo_epi16(_pA, _pB))); pA += 8; pB += 1; @@ -18974,10 +18855,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 1230 5674 9ab8 defc __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); - _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); - _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 8; pB += 32; @@ -18998,15 +18879,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 12305674 9ab8defc __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); - __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); - __m512i _s2 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0)); - __m512i _s3 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); - _sum1 = _mm512_add_epi32(_sum1, _s1); - _sum2 = _mm512_add_epi32(_sum2, _s2); - _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); + _sum2 = _mm512_add_epi32(_sum2, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB0))); + _sum3 = _mm512_add_epi32(_sum3, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA1, _pB1))); pA += 4; pB += 16; @@ -19120,25 +18996,14 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB2 = _mm_shuffle_epi32(_pBl, _MM_SHUFFLE(0, 3, 2, 1)); __m128i _pB3 = _mm_shuffle_epi32(_pBh, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); - _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); - _sum4 = _mm_maddd_epi16(_pA0, _pB2, _sum4); - _sum5 = _mm_maddd_epi16(_pA0, _pB3, _sum5); - _sum6 = _mm_maddd_epi16(_pA1, _pB2, _sum6); - _sum7 = _mm_maddd_epi16(_pA1, _pB3, _sum7); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); - _sum4 = _mm_add_epi32(_sum4, _mm_madd_epi16(_pA0, _pB2)); - _sum5 = _mm_add_epi32(_sum5, _mm_madd_epi16(_pA0, _pB3)); - _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2)); - _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); + _sum4 = _mm_comp_dpwssd_epi32(_sum4, _pA0, _pB2); + _sum5 = _mm_comp_dpwssd_epi32(_sum5, _pA0, _pB3); + _sum6 = _mm_comp_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm_comp_dpwssd_epi32(_sum7, _pA1, _pB3); pA += 8; pB += 16; @@ -19305,17 +19170,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB0 = _pB; __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); - _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 8; pB += 8; @@ -19441,13 +19299,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB0 = _pB; __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(2, 3, 0, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 8; pB += 4; @@ -19553,11 +19406,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, _pB = _mm_unpacklo_epi8(_pB, _mm_cmpgt_epi8(_mm_setzero_si128(), _pB)); #endif -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB, _sum0); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB); pA += 8; pB += 2; @@ -19662,8 +19511,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 1230 5674 9ab8 defc __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); - _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_comp_dpwssd_epi32(_sum1, _pA0, _pB1); pA += 4; pB += 32; @@ -19682,11 +19531,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 12305674 9ab8defc __m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0)); - __m512i _s1 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); - _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB0))); + _sum1 = _mm512_add_epi32(_sum1, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA0, _pB1))); pA += 2; pB += 16; @@ -19770,17 +19616,10 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0123 // 4567 -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); - _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); - _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); - _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); - _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm_comp_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm_comp_dpwssd_epi32(_sum3, _pA1, _pB1); pA += 4; pB += 16; @@ -19888,13 +19727,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB0 = _pB; __m128i _pB1 = _mm_shuffle_epi32(_pB, _MM_SHUFFLE(0, 3, 2, 1)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 4; pB += 8; @@ -20146,7 +19980,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m512i _pA0 = _mm512_cvtepi8_epi16(_pA); __m512i _pB0 = _mm512_cvtepi8_epi16(_pB); - _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum0 = _mm512_comp_dpwssd_epi32(_sum0, _pA0, _pB0); pA += 2; pB += 32; @@ -20158,9 +19992,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m256i _pB0 = _mm256_cvtepi8_epi16(_pB); - __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA, _pB0)); - - _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum0 = _mm512_add_epi32(_sum0, _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_pA, _pB0))); pA += 1; pB += 16; @@ -20223,13 +20055,8 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, __m128i _pB0 = _mm_unpacklo_epi8(_pB, _extpB); __m128i _pB1 = _mm_unpackhi_epi8(_pB, _extpB); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); - _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); - _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm_comp_dpwssd_epi32(_sum1, _pA, _pB1); pA += 2; pB += 16; @@ -20310,11 +20137,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, // 0xxx -> 0000 __m128i _pA0 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(0, 0, 0, 0)); -#if __XOP__ - _sum0 = _mm_maddd_epi16(_pA0, _pB, _sum0); -#else - _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB)); -#endif + _sum0 = _mm_comp_dpwssd_epi32(_sum0, _pA0, _pB); pA += 2; pB += 8;