From 88e9e9163dcf9e48d1432bfd127355824c7a41f6 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 11 Dec 2024 03:41:54 +0000 Subject: [PATCH] w --- src/layer/x86/gemm_int8.h | 836 ++++++++-------------------------- src/layer/x86/x86_usability.h | 33 +- 2 files changed, 219 insertions(+), 650 deletions(-) diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index c933824c773..7e668a267c9 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -200,67 +200,44 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in for (; ii + 3 < max_ii; ii += 4) { const signed char* p0 = A.row(i + ii) + k; + +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(A.w)); +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ +#else const signed char* p1 = A.row(i + ii + 1) + k; const signed char* p2 = A.row(i + ii + 2) + k; const signed char* p3 = A.row(i + ii + 3) + k; +#endif // __AVX2__ int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; + __m128i _w_shift = _mm_setzero_si128(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0]; - pp[1] = p0[1]; - pp[2] = p0[2]; - pp[3] = p0[3]; - pp[4] = p1[0]; - pp[5] = p1[1]; - pp[6] = p1[2]; - pp[7] = p1[3]; - pp[8] = p2[0]; - pp[9] = p2[1]; - pp[10] = p2[2]; - pp[11] = p2[3]; - pp[12] = p3[0]; - pp[13] = p3[1]; - pp[14] = p3[2]; - pp[15] = p3[3]; - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; + __m128i _p = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _p); + _mm_storeu_si128((__m128i*)pp, _p); pp += 16; p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; + _mm_storeu_si128((__m128i*)pp, _w_shift); pp += 16; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi16(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0 += 2; +#else pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -274,9 +251,16 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in p1 += 2; p2 += 2; p3 += 2; +#endif // __AVX2__ } for (; kk < max_kk; kk++) { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi8(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storeu_si32(pp, _p); + pp += 4; + p0++; +#else pp[0] = p0[0]; pp[1] = p1[0]; pp[2] = p2[0]; @@ -286,6 +270,7 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in p1++; p2++; p3++; +#endif // __AVX2__ } } #endif // __SSE2__ @@ -425,238 +410,47 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, { const signed char* p0 = A.row(k) + (i + ii); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ + int kk = 0; #if __AVX512VNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 
= 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; + __m512i _w_shift = _mm512_setzero_si512(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0]; - pp[1] = p0[A_hstep]; - pp[2] = p0[A_hstep * 2]; - pp[3] = p0[A_hstep * 3]; - pp[4] = p0[1]; - pp[5] = p0[A_hstep + 1]; - pp[6] = p0[A_hstep * 2 + 1]; - pp[7] = p0[A_hstep * 3 + 1]; - pp[8] = p0[2]; - pp[9] = p0[A_hstep + 2]; - pp[10] = p0[A_hstep * 2 + 2]; - pp[11] = p0[A_hstep * 3 + 2]; - pp[12] = p0[3]; - pp[13] = p0[A_hstep + 3]; - pp[14] = p0[A_hstep * 2 + 3]; - pp[15] = p0[A_hstep * 3 + 3]; - pp[16] = p0[4]; - pp[17] = p0[A_hstep + 4]; - pp[18] = p0[A_hstep * 2 + 4]; - pp[19] = p0[A_hstep * 3 + 4]; - pp[20] = p0[5]; - pp[21] = p0[A_hstep + 5]; - pp[22] = p0[A_hstep * 2 + 5]; - pp[23] = p0[A_hstep * 3 + 5]; - pp[24] = p0[6]; - pp[25] = p0[A_hstep + 6]; - pp[26] = p0[A_hstep * 2 + 6]; - pp[27] = p0[A_hstep * 3 + 6]; - pp[28] = p0[7]; - pp[29] = p0[A_hstep + 7]; - pp[30] = p0[A_hstep * 2 + 7]; - pp[31] = p0[A_hstep * 3 + 7]; - - pp[32 + 0] = p0[8]; - pp[32 + 1] = p0[A_hstep + 8]; - pp[32 + 2] = p0[A_hstep * 2 + 8]; - pp[32 + 3] = p0[A_hstep * 3 + 8]; - pp[32 + 4] = p0[9]; - pp[32 + 5] = p0[A_hstep + 9]; - pp[32 + 6] = p0[A_hstep * 2 + 9]; - pp[32 + 7] = p0[A_hstep * 3 + 9]; - pp[32 + 8] = p0[10]; - pp[32 + 9] = p0[A_hstep + 10]; - pp[32 + 10] = p0[A_hstep * 2 + 10]; - pp[32 + 11] = p0[A_hstep * 3 + 10]; - pp[32 + 12] = p0[11]; - pp[32 + 13] = p0[A_hstep + 11]; - pp[32 + 14] = p0[A_hstep * 2 + 11]; - pp[32 + 15] = p0[A_hstep * 3 + 11]; - pp[32 + 16] = p0[12]; - pp[32 + 17] = p0[A_hstep + 12]; - pp[32 + 18] = p0[A_hstep * 2 + 12]; - pp[32 + 19] = p0[A_hstep * 3 + 12]; - pp[32 + 20] = p0[13]; - pp[32 + 21] = p0[A_hstep + 13]; - pp[32 + 22] = p0[A_hstep * 2 + 13]; - pp[32 + 23] = p0[A_hstep * 3 + 13]; - pp[32 + 24] = p0[14]; - pp[32 + 25] = p0[A_hstep + 14]; - pp[32 + 26] = p0[A_hstep * 2 + 14]; - pp[32 + 27] = p0[A_hstep * 3 + 14]; - pp[32 + 28] = p0[15]; - pp[32 + 29] = p0[A_hstep + 15]; - pp[32 + 30] = p0[A_hstep * 2 + 15]; - pp[32 + 31] = p0[A_hstep * 3 + 15]; - - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; - - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; 
- w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; - + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep)); + __m128i _p2 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep * 2)); + __m128i _p3 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep * 3)); + transpose16x4_epi8(_p0, _p1, _p2, _p3); + __m512i _pp = combine4x4_epi32(_p0, _p1, _p2, _p3); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += A_hstep * 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; + _mm512_storeu_si512((__m512i*)pp, _w_shift); pp += 64; } #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = p0[0]; - pp[1] = p0[A_hstep]; - pp[2] = p0[1]; - pp[3] = p0[A_hstep + 1]; - pp[4] = p0[2]; - pp[5] = p0[A_hstep + 2]; - pp[6] = p0[3]; - pp[7] = p0[A_hstep + 3]; - pp[8] = p0[4]; - pp[9] = p0[A_hstep + 4]; - pp[10] = p0[5]; - pp[11] = p0[A_hstep + 5]; - pp[12] = p0[6]; - pp[13] = p0[A_hstep + 6]; - pp[14] = p0[7]; - pp[15] = p0[A_hstep + 7]; - - pp[16 + 0] = p0[8]; - pp[16 + 1] = p0[A_hstep + 8]; - pp[16 + 2] = p0[9]; - pp[16 + 3] = p0[A_hstep + 9]; - pp[16 + 4] = p0[10]; - pp[16 + 5] = p0[A_hstep + 10]; - pp[16 + 6] = p0[11]; - pp[16 + 7] = p0[A_hstep + 11]; - pp[16 + 8] = p0[12]; - pp[16 + 9] = p0[A_hstep + 12]; - pp[16 + 10] = p0[13]; - pp[16 + 11] = p0[A_hstep + 13]; - pp[16 + 12] = p0[14]; - pp[16 + 13] = p0[A_hstep + 14]; - pp[16 + 14] = p0[15]; - pp[16 + 15] = p0[A_hstep + 15]; + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + A_hstep)); + __m128i _t0 = _mm_unpacklo_epi8(_p0, _p1); + __m128i _t1 = _mm_unpackhi_epi8(_p0, _p1); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); pp += 32; p0 += A_hstep * 2; } for (; kk < max_kk; kk++) { - pp[0] = p0[0]; - pp[1] = p0[1]; - pp[2] = p0[2]; - pp[3] = p0[3]; - pp[4] = p0[4]; - pp[5] = p0[5]; - pp[6] = p0[6]; - pp[7] = p0[7]; - pp[8] = p0[8]; - pp[9] = p0[9]; - pp[10] = p0[10]; - pp[11] = p0[11]; - pp[12] = p0[12]; - pp[13] = p0[13]; - pp[14] = p0[14]; - pp[15] = p0[15]; + __m128i _p = _mm_loadu_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _p); pp += 16; p0 += A_hstep; } @@ -666,129 +460,45 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, { const signed char* p0 = A.row(k) + (i + ii); +#if __AVX512VNNI__ || __AVXVNNI__ + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ + int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; + __m256i _w_shift = _mm256_setzero_si256(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0]; - pp[1] = p0[A_hstep]; - pp[2] = 
p0[A_hstep * 2]; - pp[3] = p0[A_hstep * 3]; - pp[4] = p0[1]; - pp[5] = p0[A_hstep + 1]; - pp[6] = p0[A_hstep * 2 + 1]; - pp[7] = p0[A_hstep * 3 + 1]; - pp[8] = p0[2]; - pp[9] = p0[A_hstep + 2]; - pp[10] = p0[A_hstep * 2 + 2]; - pp[11] = p0[A_hstep * 3 + 2]; - pp[12] = p0[3]; - pp[13] = p0[A_hstep + 3]; - pp[14] = p0[A_hstep * 2 + 3]; - pp[15] = p0[A_hstep * 3 + 3]; - pp[16] = p0[4]; - pp[17] = p0[A_hstep + 4]; - pp[18] = p0[A_hstep * 2 + 4]; - pp[19] = p0[A_hstep * 3 + 4]; - pp[20] = p0[5]; - pp[21] = p0[A_hstep + 5]; - pp[22] = p0[A_hstep * 2 + 5]; - pp[23] = p0[A_hstep * 3 + 5]; - pp[24] = p0[6]; - pp[25] = p0[A_hstep + 6]; - pp[26] = p0[A_hstep * 2 + 6]; - pp[27] = p0[A_hstep * 3 + 6]; - pp[28] = p0[7]; - pp[29] = p0[A_hstep + 7]; - pp[30] = p0[A_hstep * 2 + 7]; - pp[31] = p0[A_hstep * 3 + 7]; - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 2)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep * 3)); + transpose8x4_epi8(_p0, _p1, _p2, _p3); + __m256i _pp = combine4x2_epi32(_p0, _p1); + _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp); + _mm256_storeu_si256((__m256i*)pp, _pp); pp += 32; p0 += A_hstep * 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; + _mm256_storeu_si256((__m256i*)pp, _w_shift); pp += 32; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = p0[0]; - pp[1] = p0[A_hstep]; - pp[2] = p0[1]; - pp[3] = p0[A_hstep + 1]; - pp[4] = p0[2]; - pp[5] = p0[A_hstep + 2]; - pp[6] = p0[3]; - pp[7] = p0[A_hstep + 3]; - pp[8] = p0[4]; - pp[9] = p0[A_hstep + 4]; - pp[10] = p0[5]; - pp[11] = p0[A_hstep + 5]; - pp[12] = p0[6]; - pp[13] = p0[A_hstep + 6]; - pp[14] = p0[7]; - pp[15] = p0[A_hstep + 7]; + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + A_hstep)); + __m128i _pp = _mm_unpacklo_epi8(_p0, _p1); + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; p0 += A_hstep * 2; } for (; kk < max_kk; kk++) { - pp[0] = p0[0]; - pp[1] = p0[1]; - pp[2] = p0[2]; - pp[3] = p0[3]; - pp[4] = p0[4]; - pp[5] = p0[5]; - pp[6] = p0[6]; - pp[7] = p0[7]; + __m128i _p = _mm_loadl_epi64((const __m128i*)p0); + _mm_storel_epi64((__m128i*)pp, _p); pp += 8; p0 += A_hstep; } @@ -798,55 +508,28 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, { const signed char* p0 = A.row(k) + (i + ii); +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _v127 = _mm_set1_epi8(127); + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = 
_mm_mullo_epi32(_vindex, _mm_set1_epi32(A_hstep)); +#endif // __AVX512VNNI__ || __AVXVNNI__ + int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; + __m128i _w_shift = _mm_setzero_si128(); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0]; - pp[1] = p0[A_hstep]; - pp[2] = p0[A_hstep * 2]; - pp[3] = p0[A_hstep * 3]; - pp[4] = p0[1]; - pp[5] = p0[A_hstep + 1]; - pp[6] = p0[A_hstep * 2 + 1]; - pp[7] = p0[A_hstep * 3 + 1]; - pp[8] = p0[2]; - pp[9] = p0[A_hstep + 2]; - pp[10] = p0[A_hstep * 2 + 2]; - pp[11] = p0[A_hstep * 3 + 2]; - pp[12] = p0[3]; - pp[13] = p0[A_hstep + 3]; - pp[14] = p0[A_hstep * 2 + 3]; - pp[15] = p0[A_hstep * 3 + 3]; - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; + __m128i _pp = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + _w_shift = _mm_comp_dpbusd_epi32(_w_shift, _v127, _pp); + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; p0 += A_hstep * 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; + _mm_storeu_si128((__m128i*)pp, _w_shift); pp += 16; } #endif // __AVX512VNNI__ || __AVXVNNI__ @@ -1032,6 +715,14 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in for (; jj + 7 < max_jj; jj += 8) { const signed char* p0 = B.row(j + jj) + k; + +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(B.w)); +#if __AVX512VNNI__ || __AVXVNNI__ + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ +#else const signed char* p1 = B.row(j + jj + 1) + k; const signed char* p2 = B.row(j + jj + 2) + k; const signed char* p3 = B.row(j + jj + 3) + k; @@ -1039,56 +730,27 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in const signed char* p5 = B.row(j + jj + 5) + k; const signed char* p6 = B.row(j + jj + 6) + k; const signed char* p7 = B.row(j + jj + 7) + k; +#endif // __AVX2__ int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0] + 127; - pp[1] = p0[1] + 127; - pp[2] = p0[2] + 127; - pp[3] = p0[3] + 127; - pp[4] = p1[0] + 127; - pp[5] = p1[1] + 127; - pp[6] = p1[2] + 127; - pp[7] = p1[3] + 127; - pp[8] = p2[0] + 127; - pp[9] = p2[1] + 127; - pp[10] = p2[2] + 127; - pp[11] = p2[3] + 127; - pp[12] = p3[0] + 127; - pp[13] = p3[1] + 127; - pp[14] = p3[2] + 127; - pp[15] = p3[3] + 127; - pp[16 + 0] = p4[0] + 127; - pp[16 + 1] = p4[1] + 127; - pp[16 + 2] = p4[2] + 127; - pp[16 + 3] = p4[3] + 127; - pp[16 + 4] = p5[0] + 127; - pp[16 + 5] = p5[1] + 127; - pp[16 + 6] = p5[2] + 127; - pp[16 + 7] = p5[3] + 127; - pp[16 + 8] = p6[0] + 127; - pp[16 + 9] = p6[1] + 127; - pp[16 + 10] = p6[2] + 127; - pp[16 + 11] = p6[3] + 127; - pp[16 + 12] = p7[0] + 127; - pp[16 + 13] = p7[1] + 127; - pp[16 + 14] = p7[2] + 127; - pp[16 + 15] = p7[3] + 127; + __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _p = _mm256_add_epi8(_p, _v127); + 
_mm256_storeu_si256((__m256i*)pp, _p); pp += 32; p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; - p4 += 4; - p5 += 4; - p6 += 4; - p7 += 4; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { +#if __AVX2__ + __m128i _p = _mm256_comp_cvtepi32_epi16(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storeu_si128((__m128i*)pp, _p); + pp += 16; + p0 += 2; +#else pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -1114,9 +776,16 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in p5 += 2; p6 += 2; p7 += 2; +#endif // __AVX2__ } for (; kk < max_kk; kk++) { +#if __AVX2__ + __m128i _p = _mm256_comp_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0++; +#else pp[0] = p0[0]; pp[1] = p1[0]; pp[2] = p2[0]; @@ -1134,45 +803,45 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in p5++; p6++; p7++; +#endif // __AVX2__ } } #endif // defined(__x86_64__) || defined(_M_X64) for (; jj + 3 < max_jj; jj += 4) { const signed char* p0 = B.row(j + jj) + k; + +#if __AVX2__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B.w)); +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ +#else const signed char* p1 = B.row(j + jj + 1) + k; const signed char* p2 = B.row(j + jj + 2) + k; const signed char* p3 = B.row(j + jj + 3) + k; +#endif // __AVX2__ int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0] + 127; - pp[1] = p0[1] + 127; - pp[2] = p0[2] + 127; - pp[3] = p0[3] + 127; - pp[4] = p1[0] + 127; - pp[5] = p1[1] + 127; - pp[6] = p1[2] + 127; - pp[7] = p1[3] + 127; - pp[8] = p2[0] + 127; - pp[9] = p2[1] + 127; - pp[10] = p2[2] + 127; - pp[11] = p2[3] + 127; - pp[12] = p3[0] + 127; - pp[13] = p3[1] + 127; - pp[14] = p3[2] + 127; - pp[15] = p3[3] + 127; + __m128i _p = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + _p = _mm_add_epi8(_p, _v127); + _mm_storeu_si128((__m128i*)pp, _p); pp += 16; p0 += 4; - p1 += 4; - p2 += 4; - p3 += 4; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi16(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storel_epi64((__m128i*)pp, _p); + pp += 8; + p0 += 2; +#else pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -1186,9 +855,16 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in p1 += 2; p2 += 2; p3 += 2; +#endif // __AVX2__ } for (; kk < max_kk; kk++) { +#if __AVX2__ + __m128i _p = _mm_comp_cvtepi32_epi8(_mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char))); + _mm_storeu_si32(pp, _p); + pp += 4; + p0++; +#else pp[0] = p0[0]; pp[1] = p1[0]; pp[2] = p2[0]; @@ -1198,6 +874,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in p1++; p2++; p3++; +#endif // __AVX2__ } } #endif // __SSE2__ @@ -1303,135 +980,41 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, { const signed char* p0 = B.row(k) + (j + jj); +#if __AVX512VNNI__ + __m512i _v127 = _mm512_set1_epi8(127); +#endif // __AVX512VNNI__ + int kk = 0; #if __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0] + 127; - pp[1] = p0[B_hstep] + 127; - pp[2] = p0[B_hstep * 2] + 127; - pp[3] = p0[B_hstep * 3] + 127; - pp[4] = p0[1] + 127; - pp[5] = p0[B_hstep + 1] + 
127; - pp[6] = p0[B_hstep * 2 + 1] + 127; - pp[7] = p0[B_hstep * 3 + 1] + 127; - pp[8] = p0[2] + 127; - pp[9] = p0[B_hstep + 2] + 127; - pp[10] = p0[B_hstep * 2 + 2] + 127; - pp[11] = p0[B_hstep * 3 + 2] + 127; - pp[12] = p0[3] + 127; - pp[13] = p0[B_hstep + 3] + 127; - pp[14] = p0[B_hstep * 2 + 3] + 127; - pp[15] = p0[B_hstep * 3 + 3] + 127; - pp[16] = p0[4] + 127; - pp[17] = p0[B_hstep + 4] + 127; - pp[18] = p0[B_hstep * 2 + 4] + 127; - pp[19] = p0[B_hstep * 3 + 4] + 127; - pp[20] = p0[5] + 127; - pp[21] = p0[B_hstep + 5] + 127; - pp[22] = p0[B_hstep * 2 + 5] + 127; - pp[23] = p0[B_hstep * 3 + 5] + 127; - pp[24] = p0[6] + 127; - pp[25] = p0[B_hstep + 6] + 127; - pp[26] = p0[B_hstep * 2 + 6] + 127; - pp[27] = p0[B_hstep * 3 + 6] + 127; - pp[28] = p0[7] + 127; - pp[29] = p0[B_hstep + 7] + 127; - pp[30] = p0[B_hstep * 2 + 7] + 127; - pp[31] = p0[B_hstep * 3 + 7] + 127; - - pp[32 + 0] = p0[8] + 127; - pp[32 + 1] = p0[B_hstep + 8] + 127; - pp[32 + 2] = p0[B_hstep * 2 + 8] + 127; - pp[32 + 3] = p0[B_hstep * 3 + 8] + 127; - pp[32 + 4] = p0[9] + 127; - pp[32 + 5] = p0[B_hstep + 9] + 127; - pp[32 + 6] = p0[B_hstep * 2 + 9] + 127; - pp[32 + 7] = p0[B_hstep * 3 + 9] + 127; - pp[32 + 8] = p0[10] + 127; - pp[32 + 9] = p0[B_hstep + 10] + 127; - pp[32 + 10] = p0[B_hstep * 2 + 10] + 127; - pp[32 + 11] = p0[B_hstep * 3 + 10] + 127; - pp[32 + 12] = p0[11] + 127; - pp[32 + 13] = p0[B_hstep + 11] + 127; - pp[32 + 14] = p0[B_hstep * 2 + 11] + 127; - pp[32 + 15] = p0[B_hstep * 3 + 11] + 127; - pp[32 + 16] = p0[12] + 127; - pp[32 + 17] = p0[B_hstep + 12] + 127; - pp[32 + 18] = p0[B_hstep * 2 + 12] + 127; - pp[32 + 19] = p0[B_hstep * 3 + 12] + 127; - pp[32 + 20] = p0[13] + 127; - pp[32 + 21] = p0[B_hstep + 13] + 127; - pp[32 + 22] = p0[B_hstep * 2 + 13] + 127; - pp[32 + 23] = p0[B_hstep * 3 + 13] + 127; - pp[32 + 24] = p0[14] + 127; - pp[32 + 25] = p0[B_hstep + 14] + 127; - pp[32 + 26] = p0[B_hstep * 2 + 14] + 127; - pp[32 + 27] = p0[B_hstep * 3 + 14] + 127; - pp[32 + 28] = p0[15] + 127; - pp[32 + 29] = p0[B_hstep + 15] + 127; - pp[32 + 30] = p0[B_hstep * 2 + 15] + 127; - pp[32 + 31] = p0[B_hstep * 3 + 15] + 127; + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep)); + __m128i _p2 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep * 2)); + __m128i _p3 = _mm_loadu_si128((const __m128i*)(p0 + B_hstep * 3)); + transpose16x4_epi8(_p0, _p1, _p2, _p3); + __m512i _pp = combine4x4_epi32(_p0, _p1, _p2, _p3); + _pp = _mm512_add_epi8(_pp, _v127); + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += B_hstep * 4; } #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = p0[0]; - pp[1] = p0[B_hstep]; - pp[2] = p0[1]; - pp[3] = p0[B_hstep + 1]; - pp[4] = p0[2]; - pp[5] = p0[B_hstep + 2]; - pp[6] = p0[3]; - pp[7] = p0[B_hstep + 3]; - pp[8] = p0[4]; - pp[9] = p0[B_hstep + 4]; - pp[10] = p0[5]; - pp[11] = p0[B_hstep + 5]; - pp[12] = p0[6]; - pp[13] = p0[B_hstep + 6]; - pp[14] = p0[7]; - pp[15] = p0[B_hstep + 7]; - - pp[16 + 0] = p0[8]; - pp[16 + 1] = p0[B_hstep + 8]; - pp[16 + 2] = p0[9]; - pp[16 + 3] = p0[B_hstep + 9]; - pp[16 + 4] = p0[10]; - pp[16 + 5] = p0[B_hstep + 10]; - pp[16 + 6] = p0[11]; - pp[16 + 7] = p0[B_hstep + 11]; - pp[16 + 8] = p0[12]; - pp[16 + 9] = p0[B_hstep + 12]; - pp[16 + 10] = p0[13]; - pp[16 + 11] = p0[B_hstep + 13]; - pp[16 + 12] = p0[14]; - pp[16 + 13] = p0[B_hstep + 14]; - pp[16 + 14] = p0[15]; - pp[16 + 15] = p0[B_hstep + 15]; + __m128i _p0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _p1 = _mm_loadu_si128((const 
__m128i*)(p0 + B_hstep)); + __m128i _t0 = _mm_unpacklo_epi8(_p0, _p1); + __m128i _t1 = _mm_unpackhi_epi8(_p0, _p1); + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), _t1); pp += 32; p0 += B_hstep * 2; } for (; kk < max_kk; kk++) { - pp[0] = p0[0]; - pp[1] = p0[1]; - pp[2] = p0[2]; - pp[3] = p0[3]; - pp[4] = p0[4]; - pp[5] = p0[5]; - pp[6] = p0[6]; - pp[7] = p0[7]; - pp[8] = p0[8]; - pp[9] = p0[9]; - pp[10] = p0[10]; - pp[11] = p0[11]; - pp[12] = p0[12]; - pp[13] = p0[13]; - pp[14] = p0[14]; - pp[15] = p0[15]; + __m128i _p = _mm_loadu_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _p); pp += 16; p0 += B_hstep; } @@ -1441,77 +1024,39 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, { const signed char* p0 = B.row(k) + (j + jj); +#if __AVX512VNNI__ || __AVXVNNI__ + __m256i _v127 = _mm256_set1_epi8(127); +#endif // __AVX512VNNI__ || __AVXVNNI__ + int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0] + 127; - pp[1] = p0[B_hstep] + 127; - pp[2] = p0[B_hstep * 2] + 127; - pp[3] = p0[B_hstep * 3] + 127; - pp[4] = p0[1] + 127; - pp[5] = p0[B_hstep + 1] + 127; - pp[6] = p0[B_hstep * 2 + 1] + 127; - pp[7] = p0[B_hstep * 3 + 1] + 127; - pp[8] = p0[2] + 127; - pp[9] = p0[B_hstep + 2] + 127; - pp[10] = p0[B_hstep * 2 + 2] + 127; - pp[11] = p0[B_hstep * 3 + 2] + 127; - pp[12] = p0[3] + 127; - pp[13] = p0[B_hstep + 3] + 127; - pp[14] = p0[B_hstep * 2 + 3] + 127; - pp[15] = p0[B_hstep * 3 + 3] + 127; - pp[16] = p0[4] + 127; - pp[17] = p0[B_hstep + 4] + 127; - pp[18] = p0[B_hstep * 2 + 4] + 127; - pp[19] = p0[B_hstep * 3 + 4] + 127; - pp[20] = p0[5] + 127; - pp[21] = p0[B_hstep + 5] + 127; - pp[22] = p0[B_hstep * 2 + 5] + 127; - pp[23] = p0[B_hstep * 3 + 5] + 127; - pp[24] = p0[6] + 127; - pp[25] = p0[B_hstep + 6] + 127; - pp[26] = p0[B_hstep * 2 + 6] + 127; - pp[27] = p0[B_hstep * 3 + 6] + 127; - pp[28] = p0[7] + 127; - pp[29] = p0[B_hstep + 7] + 127; - pp[30] = p0[B_hstep * 2 + 7] + 127; - pp[31] = p0[B_hstep * 3 + 7] + 127; + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep)); + __m128i _p2 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 2)); + __m128i _p3 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep * 3)); + transpose8x4_epi8(_p0, _p1, _p2, _p3); + __m256i _pp = combine4x2_epi32(_p0, _p1); + _pp = _mm256_add_epi8(_pp, _v127); + _mm256_storeu_si256((__m256i*)pp, _pp); pp += 32; p0 += B_hstep * 4; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = p0[0]; - pp[1] = p0[B_hstep]; - pp[2] = p0[1]; - pp[3] = p0[B_hstep + 1]; - pp[4] = p0[2]; - pp[5] = p0[B_hstep + 2]; - pp[6] = p0[3]; - pp[7] = p0[B_hstep + 3]; - pp[8] = p0[4]; - pp[9] = p0[B_hstep + 4]; - pp[10] = p0[5]; - pp[11] = p0[B_hstep + 5]; - pp[12] = p0[6]; - pp[13] = p0[B_hstep + 6]; - pp[14] = p0[7]; - pp[15] = p0[B_hstep + 7]; + __m128i _p0 = _mm_loadl_epi64((const __m128i*)p0); + __m128i _p1 = _mm_loadl_epi64((const __m128i*)(p0 + B_hstep)); + __m128i _pp = _mm_unpacklo_epi8(_p0, _p1); + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; p0 += B_hstep * 2; } for (; kk < max_kk; kk++) { - pp[0] = p0[0]; - pp[1] = p0[1]; - pp[2] = p0[2]; - pp[3] = p0[3]; - pp[4] = p0[4]; - pp[5] = p0[5]; - pp[6] = p0[6]; - pp[7] = p0[7]; + __m128i _p = _mm_loadl_epi64((const __m128i*)p0); + _mm_storel_epi64((__m128i*)pp, _p); pp += 8; p0 += B_hstep; } @@ -1521,26 +1066,21 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int 
max_jj, { const signed char* p0 = B.row(k) + (j + jj); +#if __AVX512VNNI__ || __AVXVNNI__ + __m128i _v127 = _mm_set1_epi8(127); + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(B_hstep)); +#endif // __AVX512VNNI__ || __AVXVNNI__ + int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = p0[0] + 127; - pp[1] = p0[B_hstep] + 127; - pp[2] = p0[B_hstep * 2] + 127; - pp[3] = p0[B_hstep * 3] + 127; - pp[4] = p0[1] + 127; - pp[5] = p0[B_hstep + 1] + 127; - pp[6] = p0[B_hstep * 2 + 1] + 127; - pp[7] = p0[B_hstep * 3 + 1] + 127; - pp[8] = p0[2] + 127; - pp[9] = p0[B_hstep + 2] + 127; - pp[10] = p0[B_hstep * 2 + 2] + 127; - pp[11] = p0[B_hstep * 3 + 2] + 127; - pp[12] = p0[3] + 127; - pp[13] = p0[B_hstep + 3] + 127; - pp[14] = p0[B_hstep * 2 + 3] + 127; - pp[15] = p0[B_hstep * 3 + 3] + 127; + __m128i _pp = _mm_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)); + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + _pp = _mm_shuffle_epi8(_pp, _si); + _pp = _mm_add_epi8(_pp, _v127); + _mm_storeu_si128((__m128i*)pp, _pp); pp += 16; p0 += B_hstep * 4; } @@ -4825,8 +4365,8 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int w_shift1 += pp[6]; w_shift1 += pp[7]; #else // __AVX512VNNI__ || __AVXVNNI__ - __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); - __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); int64_t v = float2int8_sse(_t0, _t1); *(int64_t*)pp = v; #endif // __AVX512VNNI__ || __AVXVNNI__ @@ -7324,8 +6864,8 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[6] += 127; pp[7] += 127; #else // __AVX512VNNI__ || __AVXVNNI__ - __m128 _t0 = _mm_unpacklo_ps(_p0, _p1); - __m128 _t1 = _mm_unpackhi_ps(_p0, _p1); + __m128 _t0 = _mm_castpd_ps(_mm_unpacklo_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); + __m128 _t1 = _mm_castpd_ps(_mm_unpackhi_pd(_mm_castps_pd(_p0), _mm_castps_pd(_p1))); int64_t v = float2int8_sse(_t0, _t1); *(int64_t*)pp = v; #endif // __AVX512VNNI__ || __AVXVNNI__ diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index 8c73bd39d65..f25b06745e8 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -122,6 +122,15 @@ static NCNN_FORCEINLINE void transpose8x4_epi16(__m128i& _r0, __m128i& _r1, __m1 _r3 = _mm_unpackhi_epi32(_tmp1, _tmp3); } +static NCNN_FORCEINLINE void transpose8x4_epi8(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpacklo_epi8(_r2, _r3); + + _r0 = _mm_unpacklo_epi16(_tmp0, _tmp1); + _r1 = _mm_unpackhi_epi16(_tmp0, _tmp1); +} + static NCNN_FORCEINLINE void transpose16x4_epi8(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) { __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); @@ -908,6 +917,16 @@ static NCNN_FORCEINLINE __m256i _mm256_comp_dpbusd_epi32(__m256i src, __m256i a, } #endif // __AVX512VNNI__ || __AVXVNNI__ +static NCNN_FORCEINLINE __m128i _mm_comp_cvtepi32_epi16(__m128i a) +{ +#if __AVX512F__ + return _mm_cvtepi32_epi16(a); +#else + __m128i _si = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0); + return _mm_shuffle_epi8(a, _si); +#endif +} + static NCNN_FORCEINLINE __m128i _mm256_comp_cvtepi32_epi16(__m256i a) { #if __AVX512F__ @@ -920,15 +939,25 @@ static NCNN_FORCEINLINE 
__m128i _mm256_comp_cvtepi32_epi16(__m256i a) #endif } +static NCNN_FORCEINLINE __m128i _mm_comp_cvtepi32_epi8(__m128i a) +{ +#if __AVX512F__ + return _mm_cvtepi32_epi8(a); +#else + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + return _mm_shuffle_epi8(a, _si); +#endif +} + static NCNN_FORCEINLINE __m128i _mm256_comp_cvtepi32_epi8(__m256i a) { #if __AVX512F__ return _mm256_cvtepi32_epi8(a); #else - __m128i _si = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 0, 0, 0, 0, 0, 0, 0, 0); + __m128i _si = _mm_setr_epi8(0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); __m256i _t = _mm256_shuffle_epi8(a, combine4x2_epi32(_si, _si)); _t = _mm256_permute4x64_epi64(_t, _MM_SHUFFLE(3, 1, 2, 0)); - return _mm256_castsi256_si128(_t); + return _mm_shuffle_epi32(_mm256_castsi256_si128(_t), _MM_SHUFFLE(3, 1, 2, 0)); #endif }
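
Two recurring tricks in this patch may be worth spelling out: (1) the vpdpbusd-based w_shift bookkeeping that replaces the scalar w_shiftN accumulators in the pack_A paths, and (2) the SSSE3 pshufb emulation of the AVX-512 narrowing moves (vpmovdw/vpmovdb) behind the new _mm_comp_cvtepi32_epi16 / _mm_comp_cvtepi32_epi8 helpers. Below are two minimal standalone sketches of those ideas; names and test values are illustrative and are not taken from ncnn.

// Sketch 1 (assumption-labelled, plain C++): why storing w_shift = 127 * sum(A row)
// lets the int8 kernel use a u8 x s8 dot product. vpdpbusd needs one unsigned operand,
// so the packed B data gets +127 added per byte (see pack_B_tile_int8), while the A
// tile stores the compensation term that this patch now computes with
// _mm_comp_dpbusd_epi32(_w_shift, _v127, _p) instead of scalar sums.
// Identity demonstrated: dot(b + 127, a) - 127 * sum(a) == dot(b, a)
#include <cstdio>

int main()
{
    const int K = 8; // one vpdpbusd lane covers K = 4; any multiple behaves the same
    signed char a[K] = {1, -2, 3, 4, -5, 6, -7, 8};   // "A" values, packed as-is
    signed char b[K] = {-3, 5, 0, -120, 7, 9, -1, 2}; // "B" values, packed with +127 bias

    int ref = 0, biased = 0, w_shift = 0;
    for (int k = 0; k < K; k++)
    {
        ref += b[k] * a[k];
        biased += (b[k] + 127) * a[k]; // what the u8 x s8 dot product accumulates
        w_shift += 127 * a[k];         // compensation stored alongside the packed A tile
    }
    printf("ref=%d  biased-w_shift=%d\n", ref, biased - w_shift); // both print the same value
    return 0;
}

// Sketch 2 (illustrative names, not the ncnn ones; compile with -mssse3 or higher):
// the pre-AVX-512 path of the new narrowing helpers is a single pshufb that keeps the
// low bytes of each 32-bit lane, truncating exactly like vpmovdw/vpmovdb.
#include <immintrin.h>
#include <cstdio>

static inline __m128i narrow_epi32_to_epi16(__m128i a)
{
    const __m128i mask = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13,
                                       -1, -1, -1, -1, -1, -1, -1, -1);
    return _mm_shuffle_epi8(a, mask); // four int16 results in the low 64 bits
}

static inline __m128i narrow_epi32_to_epi8(__m128i a)
{
    const __m128i mask = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1,
                                       -1, -1, -1, -1, -1, -1, -1, -1);
    return _mm_shuffle_epi8(a, mask); // four int8 results in the low 32 bits
}

int main()
{
    __m128i v = _mm_setr_epi32(1, -2, 300, -128);

    short s[8];
    _mm_storeu_si128((__m128i*)s, narrow_epi32_to_epi16(v));

    signed char c[16];
    _mm_storeu_si128((__m128i*)c, narrow_epi32_to_epi8(v));

    printf("epi16: %d %d %d %d\n", s[0], s[1], s[2], s[3]); // 1 -2 300 -128
    printf("epi8 : %d %d %d %d\n", c[0], c[1], c[2], c[3]); // 1 -2 44 -128 (300 truncates)
    return 0;
}

The x86_usability.h helpers differ from these sketches only in using 0 instead of -1 for the don't-care shuffle lanes, which is fine because only the low half of the result is ever stored. The same byte-selection mask (0, 4, 8, 12) is what the corrected _mm256_comp_cvtepi32_epi8 now uses per 128-bit half, with a final _mm_shuffle_epi32 to pull the two 32-bit halves together.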