diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index 43951800901b..2b76fb06dded 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -12,6 +12,30 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ +void pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void gemm_transB_packed_tile_int8_avx512vnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +void pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void gemm_transB_packed_tile_int8_avxvnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ void pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); void transpose_pack_A_tile_int8_avx2(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); @@ -40,6 +64,22 @@ static void print(__m512 x) static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_A_tile_int8_avx512vnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_A_tile_int8_avxvnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { @@ -78,6 +118,198 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in const signed char* pf = A.row(i + ii + 15) + k; int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; + + pp[16 + 0] = p4[0]; + pp[16 + 1] = p4[1]; + pp[16 + 2] = p4[2]; + pp[16 + 3] = p4[3]; + pp[16 + 4] = p5[0]; + pp[16 + 5] = p5[1]; + pp[16 + 6] = p5[2]; + pp[16 + 7] = p5[3]; + pp[16 + 8] = p6[0]; + pp[16 + 9] = p6[1]; + pp[16 + 10] = p6[2]; + pp[16 + 11] = p6[3]; + pp[16 + 12] = p7[0]; + pp[16 + 13] = p7[1]; + pp[16 + 14] = p7[2]; + pp[16 + 15] = p7[3]; + + pp[32 + 0] = p8[0]; + pp[32 + 1] = p8[1]; + pp[32 + 2] = p8[2]; + pp[32 + 3] = p8[3]; + pp[32 + 4] = p9[0]; + pp[32 + 5] = p9[1]; + pp[32 + 6] = p9[2]; + pp[32 + 7] = p9[3]; + pp[32 + 8] = pa[0]; + pp[32 + 9] = pa[1]; + pp[32 + 10] = pa[2]; + pp[32 + 11] = pa[3]; + pp[32 + 12] = pb[0]; + pp[32 + 13] = pb[1]; + pp[32 + 14] = pb[2]; + pp[32 + 15] = pb[3]; + + pp[48 + 0] = pc[0]; + pp[48 + 1] = pc[1]; + pp[48 + 2] = pc[2]; + pp[48 + 3] = pc[3]; + pp[48 + 4] = pd[0]; + pp[48 + 5] = pd[1]; + pp[48 + 6] = pd[2]; + pp[48 + 7] = pd[3]; + pp[48 + 8] = pe[0]; + pp[48 + 9] = pe[1]; + pp[48 + 10] = pe[2]; + pp[48 + 11] = pe[3]; + pp[48 + 12] = pf[0]; + pp[48 + 13] = pf[1]; + pp[48 + 14] = pf[2]; + pp[48 + 15] = pf[3]; + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + p8 += 4; + p9 += 4; + pa += 4; + pb += 4; + pc += 4; + pd += 4; + pe += 4; + pf += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -182,6 +414,104 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in const signed char* p7 = A.row(i + ii + 7) + k; int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; + pp[16] = p4[0]; + pp[17] = p4[1]; + pp[18] = p4[2]; + pp[19] = p4[3]; + pp[20] = p5[0]; + pp[21] = p5[1]; + pp[22] = p5[2]; + pp[23] = p5[3]; + pp[24] = p6[0]; + pp[25] = p6[1]; + pp[26] = p6[2]; + pp[27] = p6[3]; + pp[28] = p7[0]; + pp[29] = p7[1]; + pp[30] = p7[2]; + pp[31] = p7[3]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -240,6 +570,60 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in const signed char* p3 = A.row(i + ii + 3) + k; int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -277,6 +661,38 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + pp += 8; + p0 += 4; + p1 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -302,6 +718,27 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in const signed char* p0 = A.row(i + ii) + k; int kk = 0; +#if __AVX512VNNI__ + int w_shift = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + pp += 4; + p0 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = p0[0]; @@ -313,6 +750,22 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_A_tile_int8_avx512vnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_A_tile_int8_avxvnni(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { @@ -338,6 +791,181 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, const signed char* p0 = A.row(k) + (i + ii); int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp[8] = p0[2]; + pp[9] = p0[A_hstep + 2]; + pp[10] = p0[A_hstep * 2 + 2]; + pp[11] = p0[A_hstep * 3 + 2]; + pp[12] = p0[3]; + pp[13] = p0[A_hstep + 3]; + pp[14] = p0[A_hstep * 2 + 3]; + pp[15] = p0[A_hstep * 3 + 3]; + pp[16] = p0[4]; + pp[17] = p0[A_hstep + 4]; + pp[18] = p0[A_hstep * 2 + 4]; + pp[19] = p0[A_hstep * 3 + 4]; + pp[20] = p0[5]; + pp[21] = p0[A_hstep + 5]; + pp[22] = p0[A_hstep * 2 + 5]; + pp[23] = p0[A_hstep * 3 + 5]; + pp[24] = p0[6]; + pp[25] = p0[A_hstep + 6]; + pp[26] = p0[A_hstep * 2 + 6]; + pp[27] = p0[A_hstep * 3 + 6]; + pp[28] = p0[7]; + pp[29] = p0[A_hstep + 7]; + pp[30] = p0[A_hstep * 2 + 7]; + pp[31] = p0[A_hstep * 3 + 7]; + + pp[32 + 0] = p0[8]; + pp[32 + 1] = p0[A_hstep + 8]; + pp[32 + 2] = p0[A_hstep * 2 + 8]; + pp[32 + 3] = p0[A_hstep * 3 + 8]; + pp[32 + 4] = p0[9]; + pp[32 + 5] = p0[A_hstep + 9]; + pp[32 + 6] = p0[A_hstep * 2 + 9]; + pp[32 + 7] = p0[A_hstep * 3 + 9]; + pp[32 + 8] = p0[10]; + pp[32 + 9] = p0[A_hstep + 10]; + pp[32 + 10] = p0[A_hstep * 2 + 10]; + pp[32 + 11] = p0[A_hstep * 3 + 10]; + pp[32 + 12] = p0[11]; + pp[32 + 13] = p0[A_hstep + 11]; + pp[32 + 14] = p0[A_hstep * 2 + 11]; + pp[32 + 15] = p0[A_hstep * 3 + 11]; + pp[32 + 16] = p0[12]; + pp[32 + 17] = p0[A_hstep + 12]; + pp[32 + 18] = p0[A_hstep * 2 + 12]; + pp[32 + 19] = p0[A_hstep * 3 + 12]; + pp[32 + 20] = p0[13]; + pp[32 + 21] = p0[A_hstep + 13]; + pp[32 + 22] = p0[A_hstep * 2 + 13]; + pp[32 + 23] = p0[A_hstep * 3 + 13]; + pp[32 + 24] = p0[14]; + pp[32 + 25] = p0[A_hstep + 14]; + pp[32 + 26] = p0[A_hstep * 2 + 14]; + pp[32 + 27] = p0[A_hstep * 3 + 14]; + pp[32 + 28] = p0[15]; + pp[32 + 29] = p0[A_hstep + 15]; + pp[32 + 30] = p0[A_hstep * 2 + 15]; + pp[32 + 31] = p0[A_hstep * 3 + 15]; + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -404,6 +1032,97 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, const signed char* p0 = A.row(k) + (i + ii); int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp[8] = p0[2]; + pp[9] = p0[A_hstep + 2]; + pp[10] = p0[A_hstep * 2 + 2]; + pp[11] = p0[A_hstep * 3 + 2]; + pp[12] = p0[3]; + pp[13] = p0[A_hstep + 3]; + pp[14] = p0[A_hstep * 2 + 3]; + pp[15] = p0[A_hstep * 3 + 3]; + pp[16] = p0[4]; + pp[17] = p0[A_hstep + 4]; + pp[18] = p0[A_hstep * 2 + 4]; + pp[19] = p0[A_hstep * 3 + 4]; + pp[20] = p0[5]; + pp[21] = p0[A_hstep + 5]; + pp[22] = p0[A_hstep * 2 + 5]; + pp[23] = p0[A_hstep * 3 + 5]; + pp[24] = p0[6]; + pp[25] = p0[A_hstep + 6]; + pp[26] = p0[A_hstep * 2 + 6]; + pp[27] = p0[A_hstep * 3 + 6]; + pp[28] = p0[7]; + pp[29] = p0[A_hstep + 7]; + pp[30] = p0[A_hstep * 2 + 7]; + pp[31] = p0[A_hstep * 3 + 7]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + pp += 32; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -445,6 +1164,57 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, const signed char* p0 = A.row(k) + (i + ii); int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp[8] = p0[2]; + pp[9] = p0[A_hstep + 2]; + pp[10] = p0[A_hstep * 2 + 2]; + pp[11] = p0[A_hstep * 3 + 2]; + pp[12] = p0[3]; + pp[13] = p0[A_hstep + 3]; + pp[14] = p0[A_hstep * 2 + 3]; + pp[15] = p0[A_hstep * 3 + 3]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + pp += 16; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -475,6 +1245,37 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + pp += 8; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -498,6 +1299,27 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, const signed char* p0 = A.row(k) + (i + ii); int kk = 0; +#if __AVX512VNNI__ + int w_shift = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + pp += 4; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = p0[0]; @@ -509,6 +1331,22 @@ static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_B_tile_int8_avx512vnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_B_tile_int8_avxvnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + // NCNN_LOGE("pack_B_tile_int8"); // assert B.elempack == 1 // assert B.dims == 2 @@ -539,6 +1377,96 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in const signed char* pf = B.row(j + jj + 15) + k; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[1] + 127; + pp[2] = p0[2] + 127; + pp[3] = p0[3] + 127; + pp[4] = p1[0] + 127; + pp[5] = p1[1] + 127; + pp[6] = p1[2] + 127; + pp[7] = p1[3] + 127; + pp[8] = p2[0] + 127; + pp[9] = p2[1] + 127; + pp[10] = p2[2] + 127; + pp[11] = p2[3] + 127; + pp[12] = p3[0] + 127; + pp[13] = p3[1] + 127; + pp[14] = p3[2] + 127; + pp[15] = p3[3] + 127; + + pp[16 + 0] = p4[0] + 127; + pp[16 + 1] = p4[1] + 127; + pp[16 + 2] = p4[2] + 127; + pp[16 + 3] = p4[3] + 127; + pp[16 + 4] = p5[0] + 127; + pp[16 + 5] = p5[1] + 127; + pp[16 + 6] = p5[2] + 127; + pp[16 + 7] = p5[3] + 127; + pp[16 + 8] = p6[0] + 127; + pp[16 + 9] = p6[1] + 127; + pp[16 + 10] = p6[2] + 127; + pp[16 + 11] = p6[3] + 127; + pp[16 + 12] = p7[0] + 127; + pp[16 + 13] = p7[1] + 127; + pp[16 + 14] = p7[2] + 127; + pp[16 + 15] = p7[3] + 127; + + pp[32 + 0] = p8[0] + 127; + pp[32 + 1] = p8[1] + 127; + pp[32 + 2] = p8[2] + 127; + pp[32 + 3] = p8[3] + 127; + pp[32 + 4] = p9[0] + 127; + pp[32 + 5] = p9[1] + 127; + pp[32 + 6] = p9[2] + 127; + pp[32 + 7] = p9[3] + 127; + pp[32 + 8] = pa[0] + 127; + pp[32 + 9] = pa[1] + 127; + pp[32 + 10] = pa[2] + 127; + pp[32 + 11] = pa[3] + 127; + pp[32 + 12] = pb[0] + 127; + pp[32 + 13] = pb[1] + 127; + pp[32 + 14] = pb[2] + 127; + pp[32 + 15] = pb[3] + 127; + + pp[48 + 0] = pc[0] + 127; + pp[48 + 1] = pc[1] + 127; + pp[48 + 2] = pc[2] + 127; + pp[48 + 3] = pc[3] + 127; + pp[48 + 4] = pd[0] + 127; + pp[48 + 5] = pd[1] + 127; + pp[48 + 6] = pd[2] + 127; + pp[48 + 7] = pd[3] + 127; + pp[48 + 8] = pe[0] + 127; + pp[48 + 9] = pe[1] + 127; + pp[48 + 10] = pe[2] + 127; + pp[48 + 11] = pe[3] + 127; + pp[48 + 12] = pf[0] + 127; + pp[48 + 13] = pf[1] + 127; + pp[48 + 14] = pf[2] + 127; + pp[48 + 15] = pf[3] + 127; + + pp += 64; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + p8 += 4; + p9 += 4; + pa += 4; + pb += 4; + pc += 4; + pd += 4; + pe += 4; + pf += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -643,6 +1571,52 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in const signed char* p7 = B.row(j + jj + 7) + k; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[1] + 127; + pp[2] = p0[2] + 127; + pp[3] = p0[3] + 127; + pp[4] = p1[0] + 127; + pp[5] = p1[1] + 127; + pp[6] = p1[2] + 127; + pp[7] = p1[3] + 127; + pp[8] = p2[0] + 127; + pp[9] = p2[1] + 127; + pp[10] = p2[2] + 127; + pp[11] = p2[3] + 127; + pp[12] = p3[0] + 127; + pp[13] = p3[1] + 127; + pp[14] = p3[2] + 127; + pp[15] = p3[3] + 127; + pp[16 + 0] = p4[0] + 127; + pp[16 + 1] = p4[1] + 127; + pp[16 + 2] = p4[2] + 127; + pp[16 + 3] = p4[3] + 127; + pp[16 + 4] = p5[0] + 127; + pp[16 + 5] = p5[1] + 127; + pp[16 + 6] = p5[2] + 127; + pp[16 + 7] = p5[3] + 127; + pp[16 + 8] = p6[0] + 127; + pp[16 + 9] = p6[1] + 127; + pp[16 + 10] = p6[2] + 127; + pp[16 + 11] = p6[3] + 127; + pp[16 + 12] = p7[0] + 127; + pp[16 + 13] = p7[1] + 127; + pp[16 + 14] = p7[2] + 127; + pp[16 + 15] = p7[3] + 127; + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -701,6 +1675,32 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in const signed char* p3 = B.row(j + jj + 3) + k; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[1] + 127; + pp[2] = p0[2] + 127; + pp[3] = p0[3] + 127; + pp[4] = p1[0] + 127; + pp[5] = p1[1] + 127; + pp[6] = p1[2] + 127; + pp[7] = p1[3] + 127; + pp[8] = p2[0] + 127; + pp[9] = p2[1] + 127; + pp[10] = p2[2] + 127; + pp[11] = p2[3] + 127; + pp[12] = p3[0] + 127; + pp[13] = p3[1] + 127; + pp[14] = p3[2] + 127; + pp[15] = p3[3] + 127; + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -738,6 +1738,22 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[1] + 127; + pp[2] = p0[2] + 127; + pp[3] = p0[3] + 127; + pp[4] = p1[0] + 127; + pp[5] = p1[1] + 127; + pp[6] = p1[2] + 127; + pp[7] = p1[3] + 127; + pp += 8; + p0 += 4; + p1 += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -763,6 +1779,17 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in const signed char* p0 = B.row(j + jj) + k; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[1] + 127; + pp[2] = p0[2] + 127; + pp[3] = p0[3] + 127; + pp += 4; + p0 += 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = p0[0]; @@ -774,6 +1801,22 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_B_tile_int8_avx512vnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_B_tile_int8_avxvnni(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + // NCNN_LOGE("transpose_pack_B_tile_int8"); // assert B.elempack == 1 // assert B.dims == 2 @@ -791,6 +1834,78 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, const signed char* p0 = B.row(k) + (j + jj); int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[B_hstep] + 127; + pp[2] = p0[B_hstep * 2] + 127; + pp[3] = p0[B_hstep * 3] + 127; + pp[4] = p0[1] + 127; + pp[5] = p0[B_hstep + 1] + 127; + pp[6] = p0[B_hstep * 2 + 1] + 127; + pp[7] = p0[B_hstep * 3 + 1] + 127; + pp[8] = p0[2] + 127; + pp[9] = p0[B_hstep + 2] + 127; + pp[10] = p0[B_hstep * 2 + 2] + 127; + pp[11] = p0[B_hstep * 3 + 2] + 127; + pp[12] = p0[3] + 127; + pp[13] = p0[B_hstep + 3] + 127; + pp[14] = p0[B_hstep * 2 + 3] + 127; + pp[15] = p0[B_hstep * 3 + 3] + 127; + pp[16] = p0[4] + 127; + pp[17] = p0[B_hstep + 4] + 127; + pp[18] = p0[B_hstep * 2 + 4] + 127; + pp[19] = p0[B_hstep * 3 + 4] + 127; + pp[20] = p0[5] + 127; + pp[21] = p0[B_hstep + 5] + 127; + pp[22] = p0[B_hstep * 2 + 5] + 127; + pp[23] = p0[B_hstep * 3 + 5] + 127; + pp[24] = p0[6] + 127; + pp[25] = p0[B_hstep + 6] + 127; + pp[26] = p0[B_hstep * 2 + 6] + 127; + pp[27] = p0[B_hstep * 3 + 6] + 127; + pp[28] = p0[7] + 127; + pp[29] = p0[B_hstep + 7] + 127; + pp[30] = p0[B_hstep * 2 + 7] + 127; + pp[31] = p0[B_hstep * 3 + 7] + 127; + + pp[32 + 0] = p0[8] + 127; + pp[32 + 1] = p0[B_hstep + 8] + 127; + pp[32 + 2] = p0[B_hstep * 2 + 8] + 127; + pp[32 + 3] = p0[B_hstep * 3 + 8] + 127; + pp[32 + 4] = p0[9] + 127; + pp[32 + 5] = p0[B_hstep + 9] + 127; + pp[32 + 6] = p0[B_hstep * 2 + 9] + 127; + pp[32 + 7] = p0[B_hstep * 3 + 9] + 127; + pp[32 + 8] = p0[10] + 127; + pp[32 + 9] = p0[B_hstep + 10] + 127; + pp[32 + 10] = p0[B_hstep * 2 + 10] + 127; + pp[32 + 11] = p0[B_hstep * 3 + 10] + 127; + pp[32 + 12] = p0[11] + 127; + pp[32 + 13] = p0[B_hstep + 11] + 127; + pp[32 + 14] = p0[B_hstep * 2 + 11] + 127; + pp[32 + 15] = p0[B_hstep * 3 + 11] + 127; + pp[32 + 16] = p0[12] + 127; + pp[32 + 17] = p0[B_hstep + 12] + 127; + pp[32 + 18] = p0[B_hstep * 2 + 12] + 127; + pp[32 + 19] = p0[B_hstep * 3 + 12] + 127; + pp[32 + 20] = p0[13] + 127; + pp[32 + 21] = p0[B_hstep + 13] + 127; + pp[32 + 22] = p0[B_hstep * 2 + 13] + 127; + pp[32 + 23] = p0[B_hstep * 3 + 13] + 127; + pp[32 + 24] = p0[14] + 127; + pp[32 + 25] = p0[B_hstep + 14] + 127; + pp[32 + 26] = p0[B_hstep * 2 + 14] + 127; + pp[32 + 27] = p0[B_hstep * 3 + 14] + 127; + pp[32 + 28] = p0[15] + 127; + pp[32 + 29] = p0[B_hstep + 15] + 127; + pp[32 + 30] = p0[B_hstep * 2 + 15] + 127; + pp[32 + 31] = p0[B_hstep * 3 + 15] + 127; + pp += 64; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -857,6 +1972,45 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, const signed char* p0 = B.row(k) + (j + jj); int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[B_hstep] + 127; + pp[2] = p0[B_hstep * 2] + 127; + pp[3] = p0[B_hstep * 3] + 127; + pp[4] = p0[1] + 127; + pp[5] = p0[B_hstep + 1] + 127; + pp[6] = p0[B_hstep * 2 + 1] + 127; + pp[7] = p0[B_hstep * 3 + 1] + 127; + pp[8] = p0[2] + 127; + pp[9] = p0[B_hstep + 2] + 127; + pp[10] = p0[B_hstep * 2 + 2] + 127; + pp[11] = p0[B_hstep * 3 + 2] + 127; + pp[12] = p0[3] + 127; + pp[13] = p0[B_hstep + 3] + 127; + pp[14] = p0[B_hstep * 2 + 3] + 127; + pp[15] = p0[B_hstep * 3 + 3] + 127; + pp[16] = p0[4] + 127; + pp[17] = p0[B_hstep + 4] + 127; + pp[18] = p0[B_hstep * 2 + 4] + 127; + pp[19] = p0[B_hstep * 3 + 4] + 127; + pp[20] = p0[5] + 127; + pp[21] = p0[B_hstep + 5] + 127; + pp[22] = p0[B_hstep * 2 + 5] + 127; + pp[23] = p0[B_hstep * 3 + 5] + 127; + pp[24] = p0[6] + 127; + pp[25] = p0[B_hstep + 6] + 127; + pp[26] = p0[B_hstep * 2 + 6] + 127; + pp[27] = p0[B_hstep * 3 + 6] + 127; + pp[28] = p0[7] + 127; + pp[29] = p0[B_hstep + 7] + 127; + pp[30] = p0[B_hstep * 2 + 7] + 127; + pp[31] = p0[B_hstep * 3 + 7] + 127; + pp += 32; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -898,6 +2052,29 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, const signed char* p0 = B.row(k) + (j + jj); int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[B_hstep] + 127; + pp[2] = p0[B_hstep * 2] + 127; + pp[3] = p0[B_hstep * 3] + 127; + pp[4] = p0[1] + 127; + pp[5] = p0[B_hstep + 1] + 127; + pp[6] = p0[B_hstep * 2 + 1] + 127; + pp[7] = p0[B_hstep * 3 + 1] + 127; + pp[8] = p0[2] + 127; + pp[9] = p0[B_hstep + 2] + 127; + pp[10] = p0[B_hstep * 2 + 2] + 127; + pp[11] = p0[B_hstep * 3 + 2] + 127; + pp[12] = p0[3] + 127; + pp[13] = p0[B_hstep + 3] + 127; + pp[14] = p0[B_hstep * 2 + 3] + 127; + pp[15] = p0[B_hstep * 3 + 3] + 127; + pp += 16; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -928,6 +2105,21 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[B_hstep] + 127; + pp[2] = p0[B_hstep * 2] + 127; + pp[3] = p0[B_hstep * 3] + 127; + pp[4] = p0[1] + 127; + pp[5] = p0[B_hstep + 1] + 127; + pp[6] = p0[B_hstep * 2 + 1] + 127; + pp[7] = p0[B_hstep * 3 + 1] + 127; + pp += 8; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = p0[0]; @@ -951,6 +2143,17 @@ static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, const signed char* p0 = B.row(k) + (j + jj); int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0] + 127; + pp[1] = p0[B_hstep] + 127; + pp[2] = p0[B_hstep * 2] + 127; + pp[3] = p0[B_hstep * 3] + 127; + pp += 4; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = p0[0]; @@ -966,7 +2169,7 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; const int K = A.w; - NCNN_LOGE("compute_A_tile_int8_scales %d %d", max_ii, elempack); + // NCNN_LOGE("compute_A_tile_int8_scales %d %d", max_ii, elempack); const float v127_B_scale = 127.f * B_scale; @@ -1152,6 +2355,22 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_A_tile_fp32_to_int8_avx512vnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_A_tile_fp32_to_int8_avxvnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { @@ -1163,7 +2382,9 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i const int elempack = A.elempack; const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; - NCNN_LOGE("pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + // NCNN_LOGE("pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + + signed char* pp = (signed char*)AT; int ii = 0; #if __SSE2__ @@ -1171,8 +2392,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i #if __AVX512F__ for (; ii + 15 < max_ii; ii += 16) { - signed char* pp = (signed char*)AT + ii * max_kk; - const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; const float scale0 = scales[i + ii]; @@ -1195,6 +2414,180 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[16] * scale0); + pp[2] = float2int8(p0[32] * scale0); + pp[3] = float2int8(p0[48] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[17] * scale1); + pp[6] = float2int8(p0[33] * scale1); + pp[7] = float2int8(p0[49] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[18] * scale2); + pp[10] = float2int8(p0[34] * scale2); + pp[11] = float2int8(p0[50] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[19] * scale3); + pp[14] = float2int8(p0[35] * scale3); + pp[15] = float2int8(p0[51] * scale3); + pp[16] = float2int8(p0[4] * scale4); + pp[17] = float2int8(p0[20] * scale4); + pp[18] = float2int8(p0[36] * scale4); + pp[19] = float2int8(p0[52] * scale4); + pp[20] = float2int8(p0[5] * scale5); + pp[21] = float2int8(p0[21] * scale5); + pp[22] = float2int8(p0[37] * scale5); + pp[23] = float2int8(p0[53] * scale5); + pp[24] = float2int8(p0[6] * scale6); + pp[25] = float2int8(p0[22] * scale6); + pp[26] = float2int8(p0[38] * scale6); + pp[27] = float2int8(p0[54] * scale6); + pp[28] = float2int8(p0[7] * scale7); + pp[29] = float2int8(p0[23] * scale7); + pp[30] = float2int8(p0[39] * scale7); + pp[31] = float2int8(p0[55] * scale7); + pp[32] = float2int8(p0[8] * scale8); + pp[33] = float2int8(p0[24] * scale8); + pp[34] = float2int8(p0[40] * scale8); + pp[35] = float2int8(p0[56] * scale8); + pp[36] = float2int8(p0[9] * scale9); + pp[37] = float2int8(p0[25] * scale9); + pp[38] = float2int8(p0[41] * scale9); + pp[39] = float2int8(p0[57] * scale9); + pp[40] = float2int8(p0[10] * scalea); + pp[41] = float2int8(p0[26] * scalea); + pp[42] = float2int8(p0[42] * scalea); + pp[43] = float2int8(p0[58] * scalea); + pp[44] = float2int8(p0[11] * scaleb); + pp[45] = float2int8(p0[27] * scaleb); + pp[46] = float2int8(p0[43] * scaleb); + pp[47] = float2int8(p0[59] * scaleb); + pp[48] = float2int8(p0[12] * scalec); + pp[49] = float2int8(p0[28] * scalec); + pp[50] = float2int8(p0[44] * scalec); + pp[51] = float2int8(p0[60] * scalec); + pp[52] = float2int8(p0[13] * scaled); + pp[53] = float2int8(p0[29] * scaled); + pp[54] = float2int8(p0[45] * scaled); + pp[55] = float2int8(p0[61] * scaled); + pp[56] = float2int8(p0[14] * scalee); + pp[57] = float2int8(p0[30] * scalee); + pp[58] = float2int8(p0[46] * scalee); + pp[59] = float2int8(p0[62] * scalee); + pp[60] = float2int8(p0[15] * scalef); + pp[61] = float2int8(p0[31] * scalef); + pp[62] = float2int8(p0[47] * scalef); + pp[63] = float2int8(p0[63] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += 64; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1257,6 +2650,181 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[8] * scale0); + pp[2] = float2int8(p0[16] * scale0); + pp[3] = float2int8(p0[24] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[9] * scale1); + pp[6] = float2int8(p0[17] * scale1); + pp[7] = float2int8(p0[25] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[10] * scale2); + pp[10] = float2int8(p0[18] * scale2); + pp[11] = float2int8(p0[26] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[11] * scale3); + pp[14] = float2int8(p0[19] * scale3); + pp[15] = float2int8(p0[27] * scale3); + pp[16] = float2int8(p0[4] * scale4); + pp[17] = float2int8(p0[12] * scale4); + pp[18] = float2int8(p0[20] * scale4); + pp[19] = float2int8(p0[28] * scale4); + pp[20] = float2int8(p0[5] * scale5); + pp[21] = float2int8(p0[13] * scale5); + pp[22] = float2int8(p0[21] * scale5); + pp[23] = float2int8(p0[29] * scale5); + pp[24] = float2int8(p0[6] * scale6); + pp[25] = float2int8(p0[14] * scale6); + pp[26] = float2int8(p0[22] * scale6); + pp[27] = float2int8(p0[30] * scale6); + pp[28] = float2int8(p0[7] * scale7); + pp[29] = float2int8(p0[15] * scale7); + pp[30] = float2int8(p0[23] * scale7); + pp[31] = float2int8(p0[31] * scale7); + + pp[32 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); + pp[32 + 1] = float2int8(p0[A_hstep * 8 + 8] * scale8); + pp[32 + 2] = float2int8(p0[A_hstep * 8 + 16] * scale8); + pp[32 + 3] = float2int8(p0[A_hstep * 8 + 24] * scale8); + pp[32 + 4] = float2int8(p0[A_hstep * 8 + 1] * scale9); + pp[32 + 5] = float2int8(p0[A_hstep * 8 + 9] * scale9); + pp[32 + 6] = float2int8(p0[A_hstep * 8 + 17] * scale9); + pp[32 + 7] = float2int8(p0[A_hstep * 8 + 25] * scale9); + pp[32 + 8] = float2int8(p0[A_hstep * 8 + 2] * scalea); + pp[32 + 9] = float2int8(p0[A_hstep * 8 + 10] * scalea); + pp[32 + 10] = float2int8(p0[A_hstep * 8 + 18] * scalea); + pp[32 + 11] = float2int8(p0[A_hstep * 8 + 26] * scalea); + pp[32 + 12] = float2int8(p0[A_hstep * 8 + 3] * scaleb); + pp[32 + 13] = float2int8(p0[A_hstep * 8 + 11] * scaleb); + pp[32 + 14] = float2int8(p0[A_hstep * 8 + 19] * scaleb); + pp[32 + 15] = float2int8(p0[A_hstep * 8 + 27] * scaleb); + pp[32 + 16] = float2int8(p0[A_hstep * 8 + 4] * scalec); + pp[32 + 17] = float2int8(p0[A_hstep * 8 + 12] * scalec); + pp[32 + 18] = float2int8(p0[A_hstep * 8 + 20] * scalec); + pp[32 + 19] = float2int8(p0[A_hstep * 8 + 28] * scalec); + pp[32 + 20] = float2int8(p0[A_hstep * 8 + 5] * scaled); + pp[32 + 21] = float2int8(p0[A_hstep * 8 + 13] * scaled); + pp[32 + 22] = float2int8(p0[A_hstep * 8 + 21] * scaled); + pp[32 + 23] = float2int8(p0[A_hstep * 8 + 29] * scaled); + pp[32 + 24] = float2int8(p0[A_hstep * 8 + 6] * scalee); + pp[32 + 25] = float2int8(p0[A_hstep * 8 + 14] * scalee); + pp[32 + 26] = float2int8(p0[A_hstep * 8 + 22] * scalee); + pp[32 + 27] = float2int8(p0[A_hstep * 8 + 30] * scalee); + pp[32 + 28] = float2int8(p0[A_hstep * 8 + 7] * scalef); + pp[32 + 29] = float2int8(p0[A_hstep * 8 + 15] * scalef); + pp[32 + 30] = float2int8(p0[A_hstep * 8 + 23] * scalef); + pp[32 + 31] = float2int8(p0[A_hstep * 8 + 31] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += 32; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1320,6 +2888,182 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[4] * scale0); + pp[2] = float2int8(p0[8] * scale0); + pp[3] = float2int8(p0[12] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[9] * scale1); + pp[7] = float2int8(p0[13] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[6] * scale2); + pp[10] = float2int8(p0[10] * scale2); + pp[11] = float2int8(p0[14] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[7] * scale3); + pp[14] = float2int8(p0[11] * scale3); + pp[15] = float2int8(p0[15] * scale3); + pp[16 + 0] = float2int8(p0[A_hstep * 4 + 0] * scale4); + pp[16 + 1] = float2int8(p0[A_hstep * 4 + 4] * scale4); + pp[16 + 2] = float2int8(p0[A_hstep * 4 + 8] * scale4); + pp[16 + 3] = float2int8(p0[A_hstep * 4 + 12] * scale4); + pp[16 + 4] = float2int8(p0[A_hstep * 4 + 1] * scale5); + pp[16 + 5] = float2int8(p0[A_hstep * 4 + 5] * scale5); + pp[16 + 6] = float2int8(p0[A_hstep * 4 + 9] * scale5); + pp[16 + 7] = float2int8(p0[A_hstep * 4 + 13] * scale5); + pp[16 + 8] = float2int8(p0[A_hstep * 4 + 2] * scale6); + pp[16 + 9] = float2int8(p0[A_hstep * 4 + 6] * scale6); + pp[16 + 10] = float2int8(p0[A_hstep * 4 + 10] * scale6); + pp[16 + 11] = float2int8(p0[A_hstep * 4 + 14] * scale6); + pp[16 + 12] = float2int8(p0[A_hstep * 4 + 3] * scale7); + pp[16 + 13] = float2int8(p0[A_hstep * 4 + 7] * scale7); + pp[16 + 14] = float2int8(p0[A_hstep * 4 + 11] * scale7); + pp[16 + 15] = float2int8(p0[A_hstep * 4 + 15] * scale7); + + pp[32 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); + pp[32 + 1] = float2int8(p0[A_hstep * 8 + 4] * scale8); + pp[32 + 2] = float2int8(p0[A_hstep * 8 + 8] * scale8); + pp[32 + 3] = float2int8(p0[A_hstep * 8 + 12] * scale8); + pp[32 + 4] = float2int8(p0[A_hstep * 8 + 1] * scale9); + pp[32 + 5] = float2int8(p0[A_hstep * 8 + 5] * scale9); + pp[32 + 6] = float2int8(p0[A_hstep * 8 + 9] * scale9); + pp[32 + 7] = float2int8(p0[A_hstep * 8 + 13] * scale9); + pp[32 + 8] = float2int8(p0[A_hstep * 8 + 2] * scalea); + pp[32 + 9] = float2int8(p0[A_hstep * 8 + 6] * scalea); + pp[32 + 10] = float2int8(p0[A_hstep * 8 + 10] * scalea); + pp[32 + 11] = float2int8(p0[A_hstep * 8 + 14] * scalea); + pp[32 + 12] = float2int8(p0[A_hstep * 8 + 3] * scaleb); + pp[32 + 13] = float2int8(p0[A_hstep * 8 + 7] * scaleb); + pp[32 + 14] = float2int8(p0[A_hstep * 8 + 11] * scaleb); + pp[32 + 15] = float2int8(p0[A_hstep * 8 + 15] * scaleb); + + pp[48 + 0] = float2int8(p0[A_hstep * 12 + 0] * scalec); + pp[48 + 1] = float2int8(p0[A_hstep * 12 + 4] * scalec); + pp[48 + 2] = float2int8(p0[A_hstep * 12 + 8] * scalec); + pp[48 + 3] = float2int8(p0[A_hstep * 12 + 12] * scalec); + pp[48 + 4] = float2int8(p0[A_hstep * 12 + 1] * scaled); + pp[48 + 5] = float2int8(p0[A_hstep * 12 + 5] * scaled); + pp[48 + 6] = float2int8(p0[A_hstep * 12 + 9] * scaled); + pp[48 + 7] = float2int8(p0[A_hstep * 12 + 13] * scaled); + pp[48 + 8] = float2int8(p0[A_hstep * 12 + 2] * scalee); + pp[48 + 9] = float2int8(p0[A_hstep * 12 + 6] * scalee); + pp[48 + 10] = float2int8(p0[A_hstep * 12 + 10] * scalee); + pp[48 + 11] = float2int8(p0[A_hstep * 12 + 14] * scalee); + pp[48 + 12] = float2int8(p0[A_hstep * 12 + 3] * scalef); + pp[48 + 13] = float2int8(p0[A_hstep * 12 + 7] * scalef); + pp[48 + 14] = float2int8(p0[A_hstep * 12 + 11] * scalef); + pp[48 + 15] = float2int8(p0[A_hstep * 12 + 15] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += 16; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1385,6 +3129,181 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[A_hstep] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep + 2] * scale1); + pp[7] = float2int8(p0[A_hstep + 3] * scale1); + pp[8] = float2int8(p0[A_hstep * 2] * scale2); + pp[9] = float2int8(p0[A_hstep * 2 + 1] * scale2); + pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); + pp[11] = float2int8(p0[A_hstep * 2 + 3] * scale2); + pp[12] = float2int8(p0[A_hstep * 3] * scale3); + pp[13] = float2int8(p0[A_hstep * 3 + 1] * scale3); + pp[14] = float2int8(p0[A_hstep * 3 + 2] * scale3); + pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); + pp[16] = float2int8(p0[A_hstep * 4] * scale4); + pp[17] = float2int8(p0[A_hstep * 4 + 1] * scale4); + pp[18] = float2int8(p0[A_hstep * 4 + 2] * scale4); + pp[19] = float2int8(p0[A_hstep * 4 + 3] * scale4); + pp[20] = float2int8(p0[A_hstep * 5] * scale5); + pp[21] = float2int8(p0[A_hstep * 5 + 1] * scale5); + pp[22] = float2int8(p0[A_hstep * 5 + 2] * scale5); + pp[23] = float2int8(p0[A_hstep * 5 + 3] * scale5); + pp[24] = float2int8(p0[A_hstep * 6] * scale6); + pp[25] = float2int8(p0[A_hstep * 6 + 1] * scale6); + pp[26] = float2int8(p0[A_hstep * 6 + 2] * scale6); + pp[27] = float2int8(p0[A_hstep * 6 + 3] * scale6); + pp[28] = float2int8(p0[A_hstep * 7] * scale7); + pp[29] = float2int8(p0[A_hstep * 7 + 1] * scale7); + pp[30] = float2int8(p0[A_hstep * 7 + 2] * scale7); + pp[31] = float2int8(p0[A_hstep * 7 + 3] * scale7); + + pp[32 + 0] = float2int8(p0[A_hstep * 8] * scale8); + pp[32 + 1] = float2int8(p0[A_hstep * 8 + 1] * scale8); + pp[32 + 2] = float2int8(p0[A_hstep * 8 + 2] * scale8); + pp[32 + 3] = float2int8(p0[A_hstep * 8 + 3] * scale8); + pp[32 + 4] = float2int8(p0[A_hstep * 9] * scale9); + pp[32 + 5] = float2int8(p0[A_hstep * 9 + 1] * scale9); + pp[32 + 6] = float2int8(p0[A_hstep * 9 + 2] * scale9); + pp[32 + 7] = float2int8(p0[A_hstep * 9 + 3] * scale9); + pp[32 + 8] = float2int8(p0[A_hstep * 10] * scalea); + pp[32 + 9] = float2int8(p0[A_hstep * 10 + 1] * scalea); + pp[32 + 10] = float2int8(p0[A_hstep * 10 + 2] * scalea); + pp[32 + 11] = float2int8(p0[A_hstep * 10 + 3] * scalea); + pp[32 + 12] = float2int8(p0[A_hstep * 11] * scaleb); + pp[32 + 13] = float2int8(p0[A_hstep * 11 + 1] * scaleb); + pp[32 + 14] = float2int8(p0[A_hstep * 11 + 2] * scaleb); + pp[32 + 15] = float2int8(p0[A_hstep * 11 + 3] * scaleb); + pp[32 + 16] = float2int8(p0[A_hstep * 12] * scalec); + pp[32 + 17] = float2int8(p0[A_hstep * 12 + 1] * scalec); + pp[32 + 18] = float2int8(p0[A_hstep * 12 + 2] * scalec); + pp[32 + 19] = float2int8(p0[A_hstep * 12 + 3] * scalec); + pp[32 + 20] = float2int8(p0[A_hstep * 13] * scaled); + pp[32 + 21] = float2int8(p0[A_hstep * 13 + 1] * scaled); + pp[32 + 22] = float2int8(p0[A_hstep * 13 + 2] * scaled); + pp[32 + 23] = float2int8(p0[A_hstep * 13 + 3] * scaled); + pp[32 + 24] = float2int8(p0[A_hstep * 14] * scalee); + pp[32 + 25] = float2int8(p0[A_hstep * 14 + 1] * scalee); + pp[32 + 26] = float2int8(p0[A_hstep * 14 + 2] * scalee); + pp[32 + 27] = float2int8(p0[A_hstep * 14 + 3] * scalee); + pp[32 + 28] = float2int8(p0[A_hstep * 15] * scalef); + pp[32 + 29] = float2int8(p0[A_hstep * 15 + 1] * scalef); + pp[32 + 30] = float2int8(p0[A_hstep * 15 + 2] * scalef); + pp[32 + 31] = float2int8(p0[A_hstep * 15 + 3] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1447,17 +3366,11 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i } } #endif // __AVX512F__ +#if !__AVX2__ + signed char* pp1 = pp + max_kk * 4; +#endif for (; ii + 7 < max_ii; ii += 8) { -#if __AVX2__ - signed char* pp = (signed char*)AT + ii * max_kk; -#else - signed char* pp = (signed char*)AT + ii * max_kk; - signed char* pp1 = (signed char*)AT + (ii + 4) * max_kk; - // NCNN_LOGE("pp0 %p", pp); - // NCNN_LOGE("pp1 %p", pp1); -#endif - const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; const float scale0 = scales[i + ii]; @@ -1472,6 +3385,97 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[8] * scale0); + pp[2] = float2int8(p0[16] * scale0); + pp[3] = float2int8(p0[24] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[9] * scale1); + pp[6] = float2int8(p0[17] * scale1); + pp[7] = float2int8(p0[25] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[10] * scale2); + pp[10] = float2int8(p0[18] * scale2); + pp[11] = float2int8(p0[26] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[11] * scale3); + pp[14] = float2int8(p0[19] * scale3); + pp[15] = float2int8(p0[27] * scale3); + pp[16] = float2int8(p0[4] * scale4); + pp[17] = float2int8(p0[12] * scale4); + pp[18] = float2int8(p0[20] * scale4); + pp[19] = float2int8(p0[28] * scale4); + pp[20] = float2int8(p0[5] * scale5); + pp[21] = float2int8(p0[13] * scale5); + pp[22] = float2int8(p0[21] * scale5); + pp[23] = float2int8(p0[29] * scale5); + pp[24] = float2int8(p0[6] * scale6); + pp[25] = float2int8(p0[14] * scale6); + pp[26] = float2int8(p0[22] * scale6); + pp[27] = float2int8(p0[30] * scale6); + pp[28] = float2int8(p0[7] * scale7); + pp[29] = float2int8(p0[15] * scale7); + pp[30] = float2int8(p0[23] * scale7); + pp[31] = float2int8(p0[31] * scale7); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + pp += 32; + p0 += 32; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1501,8 +3505,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp1[5] = float2int8(p0[14] * scale6); pp1[6] = float2int8(p0[7] * scale7); pp1[7] = float2int8(p0[15] * scale7); - // NCNN_LOGE("%d %d", pp[0], pp[4]); - // NCNN_LOGE("%d %d", pp1[0], pp1[4]); pp += 8; pp1 += 8; #endif @@ -1534,6 +3536,97 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[4] * scale0); + pp[2] = float2int8(p0[8] * scale0); + pp[3] = float2int8(p0[12] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[9] * scale1); + pp[7] = float2int8(p0[13] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[6] * scale2); + pp[10] = float2int8(p0[10] * scale2); + pp[11] = float2int8(p0[14] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[7] * scale3); + pp[14] = float2int8(p0[11] * scale3); + pp[15] = float2int8(p0[15] * scale3); + pp[16] = float2int8(p0[A_hstep * 4 + 0] * scale4); + pp[17] = float2int8(p0[A_hstep * 4 + 4] * scale4); + pp[18] = float2int8(p0[A_hstep * 4 + 8] * scale4); + pp[19] = float2int8(p0[A_hstep * 4 + 12] * scale4); + pp[20] = float2int8(p0[A_hstep * 4 + 1] * scale5); + pp[21] = float2int8(p0[A_hstep * 4 + 5] * scale5); + pp[22] = float2int8(p0[A_hstep * 4 + 9] * scale5); + pp[23] = float2int8(p0[A_hstep * 4 + 13] * scale5); + pp[24] = float2int8(p0[A_hstep * 4 + 2] * scale6); + pp[25] = float2int8(p0[A_hstep * 4 + 6] * scale6); + pp[26] = float2int8(p0[A_hstep * 4 + 10] * scale6); + pp[27] = float2int8(p0[A_hstep * 4 + 14] * scale6); + pp[28] = float2int8(p0[A_hstep * 4 + 3] * scale7); + pp[29] = float2int8(p0[A_hstep * 4 + 7] * scale7); + pp[30] = float2int8(p0[A_hstep * 4 + 11] * scale7); + pp[31] = float2int8(p0[A_hstep * 4 + 15] * scale7); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + pp += 32; + p0 += 16; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1594,6 +3687,97 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[A_hstep] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep + 2] * scale1); + pp[7] = float2int8(p0[A_hstep + 3] * scale1); + pp[8] = float2int8(p0[A_hstep * 2] * scale2); + pp[9] = float2int8(p0[A_hstep * 2 + 1] * scale2); + pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); + pp[11] = float2int8(p0[A_hstep * 2 + 3] * scale2); + pp[12] = float2int8(p0[A_hstep * 3] * scale3); + pp[13] = float2int8(p0[A_hstep * 3 + 1] * scale3); + pp[14] = float2int8(p0[A_hstep * 3 + 2] * scale3); + pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); + pp[16] = float2int8(p0[A_hstep * 4] * scale4); + pp[17] = float2int8(p0[A_hstep * 4 + 1] * scale4); + pp[18] = float2int8(p0[A_hstep * 4 + 2] * scale4); + pp[19] = float2int8(p0[A_hstep * 4 + 3] * scale4); + pp[20] = float2int8(p0[A_hstep * 5] * scale5); + pp[21] = float2int8(p0[A_hstep * 5 + 1] * scale5); + pp[22] = float2int8(p0[A_hstep * 5 + 2] * scale5); + pp[23] = float2int8(p0[A_hstep * 5 + 3] * scale5); + pp[24] = float2int8(p0[A_hstep * 6] * scale6); + pp[25] = float2int8(p0[A_hstep * 6 + 1] * scale6); + pp[26] = float2int8(p0[A_hstep * 6 + 2] * scale6); + pp[27] = float2int8(p0[A_hstep * 6 + 3] * scale6); + pp[28] = float2int8(p0[A_hstep * 7] * scale7); + pp[29] = float2int8(p0[A_hstep * 7 + 1] * scale7); + pp[30] = float2int8(p0[A_hstep * 7 + 2] * scale7); + pp[31] = float2int8(p0[A_hstep * 7 + 3] * scale7); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + pp += 32; + p0 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1652,9 +3836,9 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i } } } - signed char* pp = (signed char*)AT + ii * max_kk; -#else - signed char* pp = (signed char*)AT; +#if !__AVX2__ + pp = pp1; +#endif #endif // __AVX__ for (; ii + 3 < max_ii; ii += 4) { @@ -1670,6 +3854,57 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[4] * scale0); + pp[2] = float2int8(p0[8] * scale0); + pp[3] = float2int8(p0[12] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[9] * scale1); + pp[7] = float2int8(p0[13] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[6] * scale2); + pp[10] = float2int8(p0[10] * scale2); + pp[11] = float2int8(p0[14] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[7] * scale3); + pp[14] = float2int8(p0[11] * scale3); + pp[15] = float2int8(p0[15] * scale3); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + pp += 16; + p0 += 16; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1698,6 +3933,57 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[A_hstep] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep + 2] * scale1); + pp[7] = float2int8(p0[A_hstep + 3] * scale1); + pp[8] = float2int8(p0[A_hstep * 2] * scale2); + pp[9] = float2int8(p0[A_hstep * 2 + 1] * scale2); + pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); + pp[11] = float2int8(p0[A_hstep * 2 + 3] * scale2); + pp[12] = float2int8(p0[A_hstep * 3] * scale3); + pp[13] = float2int8(p0[A_hstep * 3 + 1] * scale3); + pp[14] = float2int8(p0[A_hstep * 3 + 2] * scale3); + pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + pp += 16; + p0 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1724,8 +4010,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i } } } -#else - signed char* pp = (signed char*)AT; #endif // __SSE2__ for (; ii + 1 < max_ii; ii += 2) { @@ -1738,6 +4022,37 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[A_hstep] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep + 2] * scale1); + pp[7] = float2int8(p0[A_hstep + 3] * scale1); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + pp += 8; + p0 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -1766,6 +4081,27 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i // if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + pp += 4; + p0 += 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -1782,7 +4118,7 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; const int K = A.dims == 3 ? A.c : A.h; - NCNN_LOGE("transpose_compute_A_tile_int8_scales %d %d", max_ii, elempack); + // NCNN_LOGE("transpose_compute_A_tile_int8_scales %d %d", max_ii, elempack); const float v127_B_scale = 127.f * B_scale; @@ -1805,7 +4141,7 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, { __m512 _p = _mm512_loadu_ps(p0); _absmax0 = _mm512_max_ps(_absmax0, abs512_ps(_p)); - p0 += A_hstep * 8; + p0 += A_hstep * 16; } float absmax = _mm512_reduce_max_ps(_absmax0); @@ -1888,6 +4224,22 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_A_tile_fp32_to_int8_avx512vnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_A_tile_fp32_to_int8_avxvnni(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { @@ -1901,14 +4253,14 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int NCNN_LOGE("transpose_pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + signed char* pp = (signed char*)AT; + int ii = 0; #if __SSE2__ #if __AVX__ #if __AVX512F__ for (; ii + 15 < max_ii; ii += 16) { - signed char* pp = (signed char*)AT + ii * max_kk; - const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; const float scale0 = scales[i + ii]; @@ -1931,6 +4283,577 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2 + 0] * scale0); + pp[3] = float2int8(p0[2 + 1] * scale0); + pp[4] = float2int8(p0[16] * scale1); + pp[5] = float2int8(p0[17] * scale1); + pp[6] = float2int8(p0[2 + 16] * scale1); + pp[7] = float2int8(p0[2 + 17] * scale1); + pp[8] = float2int8(p0[32] * scale2); + pp[9] = float2int8(p0[33] * scale2); + pp[10] = float2int8(p0[2 + 32] * scale2); + pp[11] = float2int8(p0[2 + 33] * scale2); + pp[12] = float2int8(p0[48] * scale3); + pp[13] = float2int8(p0[49] * scale3); + pp[14] = float2int8(p0[2 + 48] * scale3); + pp[15] = float2int8(p0[2 + 49] * scale3); + pp[16] = float2int8(p0[64] * scale4); + pp[17] = float2int8(p0[65] * scale4); + pp[18] = float2int8(p0[2 + 64] * scale4); + pp[19] = float2int8(p0[2 + 65] * scale4); + pp[20] = float2int8(p0[80] * scale5); + pp[21] = float2int8(p0[81] * scale5); + pp[22] = float2int8(p0[2 + 80] * scale5); + pp[23] = float2int8(p0[2 + 81] * scale5); + pp[24] = float2int8(p0[96] * scale6); + pp[25] = float2int8(p0[97] * scale6); + pp[26] = float2int8(p0[2 + 96] * scale6); + pp[27] = float2int8(p0[2 + 97] * scale6); + pp[28] = float2int8(p0[112] * scale7); + pp[29] = float2int8(p0[113] * scale7); + pp[30] = float2int8(p0[2 + 112] * scale7); + pp[31] = float2int8(p0[2 + 113] * scale7); + + pp[32 + 0] = float2int8(p0[128 + 0] * scale8); + pp[32 + 1] = float2int8(p0[128 + 1] * scale8); + pp[32 + 2] = float2int8(p0[2 + 128 + 0] * scale8); + pp[32 + 3] = float2int8(p0[2 + 128 + 1] * scale8); + pp[32 + 4] = float2int8(p0[128 + 16] * scale9); + pp[32 + 5] = float2int8(p0[128 + 17] * scale9); + pp[32 + 6] = float2int8(p0[2 + 128 + 16] * scale9); + pp[32 + 7] = float2int8(p0[2 + 128 + 17] * scale9); + pp[32 + 8] = float2int8(p0[128 + 32] * scalea); + pp[32 + 9] = float2int8(p0[128 + 33] * scalea); + pp[32 + 10] = float2int8(p0[2 + 128 + 32] * scalea); + pp[32 + 11] = float2int8(p0[2 + 128 + 33] * scalea); + pp[32 + 12] = float2int8(p0[128 + 48] * scaleb); + pp[32 + 13] = float2int8(p0[128 + 49] * scaleb); + pp[32 + 14] = float2int8(p0[2 + 128 + 48] * scaleb); + pp[32 + 15] = float2int8(p0[2 + 128 + 49] * scaleb); + pp[32 + 16] = float2int8(p0[128 + 64] * scalec); + pp[32 + 17] = float2int8(p0[128 + 65] * scalec); + pp[32 + 18] = float2int8(p0[2 + 128 + 64] * scalec); + pp[32 + 19] = float2int8(p0[2 + 128 + 65] * scalec); + pp[32 + 20] = float2int8(p0[128 + 80] * scaled); + pp[32 + 21] = float2int8(p0[128 + 81] * scaled); + pp[32 + 22] = float2int8(p0[2 + 128 + 80] * scaled); + pp[32 + 23] = float2int8(p0[2 + 128 + 81] * scaled); + pp[32 + 24] = float2int8(p0[128 + 96] * scalee); + pp[32 + 25] = float2int8(p0[128 + 97] * scalee); + pp[32 + 26] = float2int8(p0[2 + 128 + 96] * scalee); + pp[32 + 27] = float2int8(p0[2 + 128 + 97] * scalee); + pp[32 + 28] = float2int8(p0[128 + 112] * scalef); + pp[32 + 29] = float2int8(p0[128 + 113] * scalef); + pp[32 + 30] = float2int8(p0[2 + 128 + 112] * scalef); + pp[32 + 31] = float2int8(p0[2 + 128 + 113] * scalef); + + pp[64 + 0] = float2int8(p0[4 + 0] * scale0); + pp[64 + 1] = float2int8(p0[4 + 1] * scale0); + pp[64 + 2] = float2int8(p0[6 + 0] * scale0); + pp[64 + 3] = float2int8(p0[6 + 1] * scale0); + pp[64 + 4] = float2int8(p0[4 + 16] * scale1); + pp[64 + 5] = float2int8(p0[4 + 17] * scale1); + pp[64 + 6] = float2int8(p0[6 + 16] * scale1); + pp[64 + 7] = float2int8(p0[6 + 17] * scale1); + pp[64 + 8] = float2int8(p0[4 + 32] * scale2); + pp[64 + 9] = float2int8(p0[4 + 33] * scale2); + pp[64 + 10] = float2int8(p0[6 + 32] * scale2); + pp[64 + 11] = float2int8(p0[6 + 33] * scale2); + pp[64 + 12] = float2int8(p0[4 + 48] * scale3); + pp[64 + 13] = float2int8(p0[4 + 49] * scale3); + pp[64 + 14] = float2int8(p0[6 + 48] * scale3); + pp[64 + 15] = float2int8(p0[6 + 49] * scale3); + pp[64 + 16] = float2int8(p0[4 + 64] * scale4); + pp[64 + 17] = float2int8(p0[4 + 65] * scale4); + pp[64 + 18] = float2int8(p0[6 + 64] * scale4); + pp[64 + 19] = float2int8(p0[6 + 65] * scale4); + pp[64 + 20] = float2int8(p0[4 + 80] * scale5); + pp[64 + 21] = float2int8(p0[4 + 81] * scale5); + pp[64 + 22] = float2int8(p0[6 + 80] * scale5); + pp[64 + 23] = float2int8(p0[6 + 81] * scale5); + pp[64 + 24] = float2int8(p0[4 + 96] * scale6); + pp[64 + 25] = float2int8(p0[4 + 97] * scale6); + pp[64 + 26] = float2int8(p0[6 + 96] * scale6); + pp[64 + 27] = float2int8(p0[6 + 97] * scale6); + pp[64 + 28] = float2int8(p0[4 + 112] * scale7); + pp[64 + 29] = float2int8(p0[4 + 113] * scale7); + pp[64 + 30] = float2int8(p0[6 + 112] * scale7); + pp[64 + 31] = float2int8(p0[6 + 113] * scale7); + + pp[96 + 0] = float2int8(p0[4 + 128 + 0] * scale8); + pp[96 + 1] = float2int8(p0[4 + 128 + 1] * scale8); + pp[96 + 2] = float2int8(p0[6 + 128 + 0] * scale8); + pp[96 + 3] = float2int8(p0[6 + 128 + 1] * scale8); + pp[96 + 4] = float2int8(p0[4 + 128 + 16] * scale9); + pp[96 + 5] = float2int8(p0[4 + 128 + 17] * scale9); + pp[96 + 6] = float2int8(p0[6 + 128 + 16] * scale9); + pp[96 + 7] = float2int8(p0[6 + 128 + 17] * scale9); + pp[96 + 8] = float2int8(p0[4 + 128 + 32] * scalea); + pp[96 + 9] = float2int8(p0[4 + 128 + 33] * scalea); + pp[96 + 10] = float2int8(p0[6 + 128 + 32] * scalea); + pp[96 + 11] = float2int8(p0[6 + 128 + 33] * scalea); + pp[96 + 12] = float2int8(p0[4 + 128 + 48] * scaleb); + pp[96 + 13] = float2int8(p0[4 + 128 + 49] * scaleb); + pp[96 + 14] = float2int8(p0[6 + 128 + 48] * scaleb); + pp[96 + 15] = float2int8(p0[6 + 128 + 49] * scaleb); + pp[96 + 16] = float2int8(p0[4 + 128 + 64] * scalec); + pp[96 + 17] = float2int8(p0[4 + 128 + 65] * scalec); + pp[96 + 18] = float2int8(p0[6 + 128 + 64] * scalec); + pp[96 + 19] = float2int8(p0[6 + 128 + 65] * scalec); + pp[96 + 20] = float2int8(p0[4 + 128 + 80] * scaled); + pp[96 + 21] = float2int8(p0[4 + 128 + 81] * scaled); + pp[96 + 22] = float2int8(p0[6 + 128 + 80] * scaled); + pp[96 + 23] = float2int8(p0[6 + 128 + 81] * scaled); + pp[96 + 24] = float2int8(p0[4 + 128 + 96] * scalee); + pp[96 + 25] = float2int8(p0[4 + 128 + 97] * scalee); + pp[96 + 26] = float2int8(p0[6 + 128 + 96] * scalee); + pp[96 + 27] = float2int8(p0[6 + 128 + 97] * scalee); + pp[96 + 28] = float2int8(p0[4 + 128 + 112] * scalef); + pp[96 + 29] = float2int8(p0[4 + 128 + 113] * scalef); + pp[96 + 30] = float2int8(p0[6 + 128 + 112] * scalef); + pp[96 + 31] = float2int8(p0[6 + 128 + 113] * scalef); + + pp[128 + 0] = float2int8(p0[8 + 0] * scale0); + pp[128 + 1] = float2int8(p0[8 + 1] * scale0); + pp[128 + 2] = float2int8(p0[10 + 0] * scale0); + pp[128 + 3] = float2int8(p0[10 + 1] * scale0); + pp[128 + 4] = float2int8(p0[8 + 16] * scale1); + pp[128 + 5] = float2int8(p0[8 + 17] * scale1); + pp[128 + 6] = float2int8(p0[10 + 16] * scale1); + pp[128 + 7] = float2int8(p0[10 + 17] * scale1); + pp[128 + 8] = float2int8(p0[8 + 32] * scale2); + pp[128 + 9] = float2int8(p0[8 + 33] * scale2); + pp[128 + 10] = float2int8(p0[10 + 32] * scale2); + pp[128 + 11] = float2int8(p0[10 + 33] * scale2); + pp[128 + 12] = float2int8(p0[8 + 48] * scale3); + pp[128 + 13] = float2int8(p0[8 + 49] * scale3); + pp[128 + 14] = float2int8(p0[10 + 48] * scale3); + pp[128 + 15] = float2int8(p0[10 + 49] * scale3); + pp[128 + 16] = float2int8(p0[8 + 64] * scale4); + pp[128 + 17] = float2int8(p0[8 + 65] * scale4); + pp[128 + 18] = float2int8(p0[10 + 64] * scale4); + pp[128 + 19] = float2int8(p0[10 + 65] * scale4); + pp[128 + 20] = float2int8(p0[8 + 80] * scale5); + pp[128 + 21] = float2int8(p0[8 + 81] * scale5); + pp[128 + 22] = float2int8(p0[10 + 80] * scale5); + pp[128 + 23] = float2int8(p0[10 + 81] * scale5); + pp[128 + 24] = float2int8(p0[8 + 96] * scale6); + pp[128 + 25] = float2int8(p0[8 + 97] * scale6); + pp[128 + 26] = float2int8(p0[10 + 96] * scale6); + pp[128 + 27] = float2int8(p0[10 + 97] * scale6); + pp[128 + 28] = float2int8(p0[8 + 112] * scale7); + pp[128 + 29] = float2int8(p0[8 + 113] * scale7); + pp[128 + 30] = float2int8(p0[10 + 112] * scale7); + pp[128 + 31] = float2int8(p0[10 + 113] * scale7); + + pp[160 + 0] = float2int8(p0[8 + 128 + 0] * scale8); + pp[160 + 1] = float2int8(p0[8 + 128 + 1] * scale8); + pp[160 + 2] = float2int8(p0[10 + 128 + 0] * scale8); + pp[160 + 3] = float2int8(p0[10 + 128 + 1] * scale8); + pp[160 + 4] = float2int8(p0[8 + 128 + 16] * scale9); + pp[160 + 5] = float2int8(p0[8 + 128 + 17] * scale9); + pp[160 + 6] = float2int8(p0[10 + 128 + 16] * scale9); + pp[160 + 7] = float2int8(p0[10 + 128 + 17] * scale9); + pp[160 + 8] = float2int8(p0[8 + 128 + 32] * scalea); + pp[160 + 9] = float2int8(p0[8 + 128 + 33] * scalea); + pp[160 + 10] = float2int8(p0[10 + 128 + 32] * scalea); + pp[160 + 11] = float2int8(p0[10 + 128 + 33] * scalea); + pp[160 + 12] = float2int8(p0[8 + 128 + 48] * scaleb); + pp[160 + 13] = float2int8(p0[8 + 128 + 49] * scaleb); + pp[160 + 14] = float2int8(p0[10 + 128 + 48] * scaleb); + pp[160 + 15] = float2int8(p0[10 + 128 + 49] * scaleb); + pp[160 + 16] = float2int8(p0[8 + 128 + 64] * scalec); + pp[160 + 17] = float2int8(p0[8 + 128 + 65] * scalec); + pp[160 + 18] = float2int8(p0[10 + 128 + 64] * scalec); + pp[160 + 19] = float2int8(p0[10 + 128 + 65] * scalec); + pp[160 + 20] = float2int8(p0[8 + 128 + 80] * scaled); + pp[160 + 21] = float2int8(p0[8 + 128 + 81] * scaled); + pp[160 + 22] = float2int8(p0[10 + 128 + 80] * scaled); + pp[160 + 23] = float2int8(p0[10 + 128 + 81] * scaled); + pp[160 + 24] = float2int8(p0[8 + 128 + 96] * scalee); + pp[160 + 25] = float2int8(p0[8 + 128 + 97] * scalee); + pp[160 + 26] = float2int8(p0[10 + 128 + 96] * scalee); + pp[160 + 27] = float2int8(p0[10 + 128 + 97] * scalee); + pp[160 + 28] = float2int8(p0[8 + 128 + 112] * scalef); + pp[160 + 29] = float2int8(p0[8 + 128 + 113] * scalef); + pp[160 + 30] = float2int8(p0[10 + 128 + 112] * scalef); + pp[160 + 31] = float2int8(p0[10 + 128 + 113] * scalef); + + pp[192 + 0] = float2int8(p0[12 + 0] * scale0); + pp[192 + 1] = float2int8(p0[12 + 1] * scale0); + pp[192 + 2] = float2int8(p0[14 + 0] * scale0); + pp[192 + 3] = float2int8(p0[14 + 1] * scale0); + pp[192 + 4] = float2int8(p0[12 + 16] * scale1); + pp[192 + 5] = float2int8(p0[12 + 17] * scale1); + pp[192 + 6] = float2int8(p0[14 + 16] * scale1); + pp[192 + 7] = float2int8(p0[14 + 17] * scale1); + pp[192 + 8] = float2int8(p0[12 + 32] * scale2); + pp[192 + 9] = float2int8(p0[12 + 33] * scale2); + pp[192 + 10] = float2int8(p0[14 + 32] * scale2); + pp[192 + 11] = float2int8(p0[14 + 33] * scale2); + pp[192 + 12] = float2int8(p0[12 + 48] * scale3); + pp[192 + 13] = float2int8(p0[12 + 49] * scale3); + pp[192 + 14] = float2int8(p0[14 + 48] * scale3); + pp[192 + 15] = float2int8(p0[14 + 49] * scale3); + pp[192 + 16] = float2int8(p0[12 + 64] * scale4); + pp[192 + 17] = float2int8(p0[12 + 65] * scale4); + pp[192 + 18] = float2int8(p0[14 + 64] * scale4); + pp[192 + 19] = float2int8(p0[14 + 65] * scale4); + pp[192 + 20] = float2int8(p0[12 + 80] * scale5); + pp[192 + 21] = float2int8(p0[12 + 81] * scale5); + pp[192 + 22] = float2int8(p0[14 + 80] * scale5); + pp[192 + 23] = float2int8(p0[14 + 81] * scale5); + pp[192 + 24] = float2int8(p0[12 + 96] * scale6); + pp[192 + 25] = float2int8(p0[12 + 97] * scale6); + pp[192 + 26] = float2int8(p0[14 + 96] * scale6); + pp[192 + 27] = float2int8(p0[14 + 97] * scale6); + pp[192 + 28] = float2int8(p0[12 + 112] * scale7); + pp[192 + 29] = float2int8(p0[12 + 113] * scale7); + pp[192 + 30] = float2int8(p0[14 + 112] * scale7); + pp[192 + 31] = float2int8(p0[14 + 113] * scale7); + + pp[224 + 0] = float2int8(p0[12 + 128 + 0] * scale8); + pp[224 + 1] = float2int8(p0[12 + 128 + 1] * scale8); + pp[224 + 2] = float2int8(p0[14 + 128 + 0] * scale8); + pp[224 + 3] = float2int8(p0[14 + 128 + 1] * scale8); + pp[224 + 4] = float2int8(p0[12 + 128 + 16] * scale9); + pp[224 + 5] = float2int8(p0[12 + 128 + 17] * scale9); + pp[224 + 6] = float2int8(p0[14 + 128 + 16] * scale9); + pp[224 + 7] = float2int8(p0[14 + 128 + 17] * scale9); + pp[224 + 8] = float2int8(p0[12 + 128 + 32] * scalea); + pp[224 + 9] = float2int8(p0[12 + 128 + 33] * scalea); + pp[224 + 10] = float2int8(p0[14 + 128 + 32] * scalea); + pp[224 + 11] = float2int8(p0[14 + 128 + 33] * scalea); + pp[224 + 12] = float2int8(p0[12 + 128 + 48] * scaleb); + pp[224 + 13] = float2int8(p0[12 + 128 + 49] * scaleb); + pp[224 + 14] = float2int8(p0[14 + 128 + 48] * scaleb); + pp[224 + 15] = float2int8(p0[14 + 128 + 49] * scaleb); + pp[224 + 16] = float2int8(p0[12 + 128 + 64] * scalec); + pp[224 + 17] = float2int8(p0[12 + 128 + 65] * scalec); + pp[224 + 18] = float2int8(p0[14 + 128 + 64] * scalec); + pp[224 + 19] = float2int8(p0[14 + 128 + 65] * scalec); + pp[224 + 20] = float2int8(p0[12 + 128 + 80] * scaled); + pp[224 + 21] = float2int8(p0[12 + 128 + 81] * scaled); + pp[224 + 22] = float2int8(p0[14 + 128 + 80] * scaled); + pp[224 + 23] = float2int8(p0[14 + 128 + 81] * scaled); + pp[224 + 24] = float2int8(p0[12 + 128 + 96] * scalee); + pp[224 + 25] = float2int8(p0[12 + 128 + 97] * scalee); + pp[224 + 26] = float2int8(p0[14 + 128 + 96] * scalee); + pp[224 + 27] = float2int8(p0[14 + 128 + 97] * scalee); + pp[224 + 28] = float2int8(p0[12 + 128 + 112] * scalef); + pp[224 + 29] = float2int8(p0[12 + 128 + 113] * scalef); + pp[224 + 30] = float2int8(p0[14 + 128 + 112] * scalef); + pp[224 + 31] = float2int8(p0[14 + 128 + 113] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + w_shift0 += pp[64 + 0]; + w_shift0 += pp[64 + 1]; + w_shift0 += pp[64 + 2]; + w_shift0 += pp[64 + 3]; + w_shift1 += pp[64 + 4]; + w_shift1 += pp[64 + 5]; + w_shift1 += pp[64 + 6]; + w_shift1 += pp[64 + 7]; + w_shift2 += pp[64 + 8]; + w_shift2 += pp[64 + 9]; + w_shift2 += pp[64 + 10]; + w_shift2 += pp[64 + 11]; + w_shift3 += pp[64 + 12]; + w_shift3 += pp[64 + 13]; + w_shift3 += pp[64 + 14]; + w_shift3 += pp[64 + 15]; + w_shift4 += pp[64 + 16]; + w_shift4 += pp[64 + 17]; + w_shift4 += pp[64 + 18]; + w_shift4 += pp[64 + 19]; + w_shift5 += pp[64 + 20]; + w_shift5 += pp[64 + 21]; + w_shift5 += pp[64 + 22]; + w_shift5 += pp[64 + 23]; + w_shift6 += pp[64 + 24]; + w_shift6 += pp[64 + 25]; + w_shift6 += pp[64 + 26]; + w_shift6 += pp[64 + 27]; + w_shift7 += pp[64 + 28]; + w_shift7 += pp[64 + 29]; + w_shift7 += pp[64 + 30]; + w_shift7 += pp[64 + 31]; + + w_shift8 += pp[96 + 0]; + w_shift8 += pp[96 + 1]; + w_shift8 += pp[96 + 2]; + w_shift8 += pp[96 + 3]; + w_shift9 += pp[96 + 4]; + w_shift9 += pp[96 + 5]; + w_shift9 += pp[96 + 6]; + w_shift9 += pp[96 + 7]; + w_shifta += pp[96 + 8]; + w_shifta += pp[96 + 9]; + w_shifta += pp[96 + 10]; + w_shifta += pp[96 + 11]; + w_shiftb += pp[96 + 12]; + w_shiftb += pp[96 + 13]; + w_shiftb += pp[96 + 14]; + w_shiftb += pp[96 + 15]; + w_shiftc += pp[96 + 16]; + w_shiftc += pp[96 + 17]; + w_shiftc += pp[96 + 18]; + w_shiftc += pp[96 + 19]; + w_shiftd += pp[96 + 20]; + w_shiftd += pp[96 + 21]; + w_shiftd += pp[96 + 22]; + w_shiftd += pp[96 + 23]; + w_shifte += pp[96 + 24]; + w_shifte += pp[96 + 25]; + w_shifte += pp[96 + 26]; + w_shifte += pp[96 + 27]; + w_shiftf += pp[96 + 28]; + w_shiftf += pp[96 + 29]; + w_shiftf += pp[96 + 30]; + w_shiftf += pp[96 + 31]; + + w_shift0 += pp[128 + 0]; + w_shift0 += pp[128 + 1]; + w_shift0 += pp[128 + 2]; + w_shift0 += pp[128 + 3]; + w_shift1 += pp[128 + 4]; + w_shift1 += pp[128 + 5]; + w_shift1 += pp[128 + 6]; + w_shift1 += pp[128 + 7]; + w_shift2 += pp[128 + 8]; + w_shift2 += pp[128 + 9]; + w_shift2 += pp[128 + 10]; + w_shift2 += pp[128 + 11]; + w_shift3 += pp[128 + 12]; + w_shift3 += pp[128 + 13]; + w_shift3 += pp[128 + 14]; + w_shift3 += pp[128 + 15]; + w_shift4 += pp[128 + 16]; + w_shift4 += pp[128 + 17]; + w_shift4 += pp[128 + 18]; + w_shift4 += pp[128 + 19]; + w_shift5 += pp[128 + 20]; + w_shift5 += pp[128 + 21]; + w_shift5 += pp[128 + 22]; + w_shift5 += pp[128 + 23]; + w_shift6 += pp[128 + 24]; + w_shift6 += pp[128 + 25]; + w_shift6 += pp[128 + 26]; + w_shift6 += pp[128 + 27]; + w_shift7 += pp[128 + 28]; + w_shift7 += pp[128 + 29]; + w_shift7 += pp[128 + 30]; + w_shift7 += pp[128 + 31]; + + w_shift8 += pp[160 + 0]; + w_shift8 += pp[160 + 1]; + w_shift8 += pp[160 + 2]; + w_shift8 += pp[160 + 3]; + w_shift9 += pp[160 + 4]; + w_shift9 += pp[160 + 5]; + w_shift9 += pp[160 + 6]; + w_shift9 += pp[160 + 7]; + w_shifta += pp[160 + 8]; + w_shifta += pp[160 + 9]; + w_shifta += pp[160 + 10]; + w_shifta += pp[160 + 11]; + w_shiftb += pp[160 + 12]; + w_shiftb += pp[160 + 13]; + w_shiftb += pp[160 + 14]; + w_shiftb += pp[160 + 15]; + w_shiftc += pp[160 + 16]; + w_shiftc += pp[160 + 17]; + w_shiftc += pp[160 + 18]; + w_shiftc += pp[160 + 19]; + w_shiftd += pp[160 + 20]; + w_shiftd += pp[160 + 21]; + w_shiftd += pp[160 + 22]; + w_shiftd += pp[160 + 23]; + w_shifte += pp[160 + 24]; + w_shifte += pp[160 + 25]; + w_shifte += pp[160 + 26]; + w_shifte += pp[160 + 27]; + w_shiftf += pp[160 + 28]; + w_shiftf += pp[160 + 29]; + w_shiftf += pp[160 + 30]; + w_shiftf += pp[160 + 31]; + + w_shift0 += pp[192 + 0]; + w_shift0 += pp[192 + 1]; + w_shift0 += pp[192 + 2]; + w_shift0 += pp[192 + 3]; + w_shift1 += pp[192 + 4]; + w_shift1 += pp[192 + 5]; + w_shift1 += pp[192 + 6]; + w_shift1 += pp[192 + 7]; + w_shift2 += pp[192 + 8]; + w_shift2 += pp[192 + 9]; + w_shift2 += pp[192 + 10]; + w_shift2 += pp[192 + 11]; + w_shift3 += pp[192 + 12]; + w_shift3 += pp[192 + 13]; + w_shift3 += pp[192 + 14]; + w_shift3 += pp[192 + 15]; + w_shift4 += pp[192 + 16]; + w_shift4 += pp[192 + 17]; + w_shift4 += pp[192 + 18]; + w_shift4 += pp[192 + 19]; + w_shift5 += pp[192 + 20]; + w_shift5 += pp[192 + 21]; + w_shift5 += pp[192 + 22]; + w_shift5 += pp[192 + 23]; + w_shift6 += pp[192 + 24]; + w_shift6 += pp[192 + 25]; + w_shift6 += pp[192 + 26]; + w_shift6 += pp[192 + 27]; + w_shift7 += pp[192 + 28]; + w_shift7 += pp[192 + 29]; + w_shift7 += pp[192 + 30]; + w_shift7 += pp[192 + 31]; + + w_shift8 += pp[224 + 0]; + w_shift8 += pp[224 + 1]; + w_shift8 += pp[224 + 2]; + w_shift8 += pp[224 + 3]; + w_shift9 += pp[224 + 4]; + w_shift9 += pp[224 + 5]; + w_shift9 += pp[224 + 6]; + w_shift9 += pp[224 + 7]; + w_shifta += pp[224 + 8]; + w_shifta += pp[224 + 9]; + w_shifta += pp[224 + 10]; + w_shifta += pp[224 + 11]; + w_shiftb += pp[224 + 12]; + w_shiftb += pp[224 + 13]; + w_shiftb += pp[224 + 14]; + w_shiftb += pp[224 + 15]; + w_shiftc += pp[224 + 16]; + w_shiftc += pp[224 + 17]; + w_shiftc += pp[224 + 18]; + w_shiftc += pp[224 + 19]; + w_shiftd += pp[224 + 20]; + w_shiftd += pp[224 + 21]; + w_shiftd += pp[224 + 22]; + w_shiftd += pp[224 + 23]; + w_shifte += pp[224 + 24]; + w_shifte += pp[224 + 25]; + w_shifte += pp[224 + 26]; + w_shifte += pp[224 + 27]; + w_shiftf += pp[224 + 28]; + w_shiftf += pp[224 + 29]; + w_shiftf += pp[224 + 30]; + w_shiftf += pp[224 + 31]; + + pp += 256; + p0 += A_hstep * 16; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { pp[0] = float2int8(p0[0] * scale0); @@ -2208,10 +5131,318 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 256; p0 += A_hstep * 16; } +#endif // __AVX512VNNI__ } if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[8] * scale1); + pp[5] = float2int8(p0[9] * scale1); + pp[6] = float2int8(p0[10] * scale1); + pp[7] = float2int8(p0[11] * scale1); + pp[8] = float2int8(p0[16] * scale2); + pp[9] = float2int8(p0[17] * scale2); + pp[10] = float2int8(p0[18] * scale2); + pp[11] = float2int8(p0[19] * scale2); + pp[12] = float2int8(p0[24] * scale3); + pp[13] = float2int8(p0[25] * scale3); + pp[14] = float2int8(p0[26] * scale3); + pp[15] = float2int8(p0[27] * scale3); + pp[16] = float2int8(p0[32] * scale4); + pp[17] = float2int8(p0[33] * scale4); + pp[18] = float2int8(p0[34] * scale4); + pp[19] = float2int8(p0[35] * scale4); + pp[20] = float2int8(p0[40] * scale5); + pp[21] = float2int8(p0[41] * scale5); + pp[22] = float2int8(p0[42] * scale5); + pp[23] = float2int8(p0[43] * scale5); + pp[24] = float2int8(p0[48] * scale6); + pp[25] = float2int8(p0[49] * scale6); + pp[26] = float2int8(p0[50] * scale6); + pp[27] = float2int8(p0[51] * scale6); + pp[28] = float2int8(p0[56] * scale7); + pp[29] = float2int8(p0[57] * scale7); + pp[30] = float2int8(p0[58] * scale7); + pp[31] = float2int8(p0[59] * scale7); + + pp[32 + 0] = float2int8(p0[64 + 0] * scale8); + pp[32 + 1] = float2int8(p0[64 + 1] * scale8); + pp[32 + 2] = float2int8(p0[64 + 2] * scale8); + pp[32 + 3] = float2int8(p0[64 + 3] * scale8); + pp[32 + 4] = float2int8(p0[64 + 8] * scale9); + pp[32 + 5] = float2int8(p0[64 + 9] * scale9); + pp[32 + 6] = float2int8(p0[64 + 10] * scale9); + pp[32 + 7] = float2int8(p0[64 + 11] * scale9); + pp[32 + 8] = float2int8(p0[64 + 16] * scalea); + pp[32 + 9] = float2int8(p0[64 + 17] * scalea); + pp[32 + 10] = float2int8(p0[64 + 18] * scalea); + pp[32 + 11] = float2int8(p0[64 + 19] * scalea); + pp[32 + 12] = float2int8(p0[64 + 24] * scaleb); + pp[32 + 13] = float2int8(p0[64 + 25] * scaleb); + pp[32 + 14] = float2int8(p0[64 + 26] * scaleb); + pp[32 + 15] = float2int8(p0[64 + 27] * scaleb); + pp[32 + 16] = float2int8(p0[64 + 32] * scalec); + pp[32 + 17] = float2int8(p0[64 + 33] * scalec); + pp[32 + 18] = float2int8(p0[64 + 34] * scalec); + pp[32 + 19] = float2int8(p0[64 + 35] * scalec); + pp[32 + 20] = float2int8(p0[64 + 40] * scaled); + pp[32 + 21] = float2int8(p0[64 + 41] * scaled); + pp[32 + 22] = float2int8(p0[64 + 42] * scaled); + pp[32 + 23] = float2int8(p0[64 + 43] * scaled); + pp[32 + 24] = float2int8(p0[64 + 48] * scalee); + pp[32 + 25] = float2int8(p0[64 + 49] * scalee); + pp[32 + 26] = float2int8(p0[64 + 50] * scalee); + pp[32 + 27] = float2int8(p0[64 + 51] * scalee); + pp[32 + 28] = float2int8(p0[64 + 56] * scalef); + pp[32 + 29] = float2int8(p0[64 + 57] * scalef); + pp[32 + 30] = float2int8(p0[64 + 58] * scalef); + pp[32 + 31] = float2int8(p0[64 + 59] * scalef); + + pp[64 + 0] = float2int8(p0[4] * scale0); + pp[64 + 1] = float2int8(p0[5] * scale0); + pp[64 + 2] = float2int8(p0[6] * scale0); + pp[64 + 3] = float2int8(p0[7] * scale0); + pp[64 + 4] = float2int8(p0[12] * scale1); + pp[64 + 5] = float2int8(p0[13] * scale1); + pp[64 + 6] = float2int8(p0[14] * scale1); + pp[64 + 7] = float2int8(p0[15] * scale1); + pp[64 + 8] = float2int8(p0[20] * scale2); + pp[64 + 9] = float2int8(p0[21] * scale2); + pp[64 + 10] = float2int8(p0[22] * scale2); + pp[64 + 11] = float2int8(p0[23] * scale2); + pp[64 + 12] = float2int8(p0[28] * scale3); + pp[64 + 13] = float2int8(p0[29] * scale3); + pp[64 + 14] = float2int8(p0[30] * scale3); + pp[64 + 15] = float2int8(p0[31] * scale3); + pp[64 + 16] = float2int8(p0[36] * scale4); + pp[64 + 17] = float2int8(p0[37] * scale4); + pp[64 + 18] = float2int8(p0[38] * scale4); + pp[64 + 19] = float2int8(p0[39] * scale4); + pp[64 + 20] = float2int8(p0[44] * scale5); + pp[64 + 21] = float2int8(p0[45] * scale5); + pp[64 + 22] = float2int8(p0[46] * scale5); + pp[64 + 23] = float2int8(p0[47] * scale5); + pp[64 + 24] = float2int8(p0[52] * scale6); + pp[64 + 25] = float2int8(p0[53] * scale6); + pp[64 + 26] = float2int8(p0[54] * scale6); + pp[64 + 27] = float2int8(p0[55] * scale6); + pp[64 + 28] = float2int8(p0[60] * scale7); + pp[64 + 29] = float2int8(p0[61] * scale7); + pp[64 + 30] = float2int8(p0[62] * scale7); + pp[64 + 31] = float2int8(p0[63] * scale7); + + pp[96 + 0] = float2int8(p0[64 + 4] * scale8); + pp[96 + 1] = float2int8(p0[64 + 5] * scale8); + pp[96 + 2] = float2int8(p0[64 + 6] * scale8); + pp[96 + 3] = float2int8(p0[64 + 7] * scale8); + pp[96 + 4] = float2int8(p0[64 + 12] * scale9); + pp[96 + 5] = float2int8(p0[64 + 13] * scale9); + pp[96 + 6] = float2int8(p0[64 + 14] * scale9); + pp[96 + 7] = float2int8(p0[64 + 15] * scale9); + pp[96 + 8] = float2int8(p0[64 + 20] * scalea); + pp[96 + 9] = float2int8(p0[64 + 21] * scalea); + pp[96 + 10] = float2int8(p0[64 + 22] * scalea); + pp[96 + 11] = float2int8(p0[64 + 23] * scalea); + pp[96 + 12] = float2int8(p0[64 + 28] * scaleb); + pp[96 + 13] = float2int8(p0[64 + 29] * scaleb); + pp[96 + 14] = float2int8(p0[64 + 30] * scaleb); + pp[96 + 15] = float2int8(p0[64 + 31] * scaleb); + pp[96 + 16] = float2int8(p0[64 + 36] * scalec); + pp[96 + 17] = float2int8(p0[64 + 37] * scalec); + pp[96 + 18] = float2int8(p0[64 + 38] * scalec); + pp[96 + 19] = float2int8(p0[64 + 39] * scalec); + pp[96 + 20] = float2int8(p0[64 + 44] * scaled); + pp[96 + 21] = float2int8(p0[64 + 45] * scaled); + pp[96 + 22] = float2int8(p0[64 + 46] * scaled); + pp[96 + 23] = float2int8(p0[64 + 47] * scaled); + pp[96 + 24] = float2int8(p0[64 + 52] * scalee); + pp[96 + 25] = float2int8(p0[64 + 53] * scalee); + pp[96 + 26] = float2int8(p0[64 + 54] * scalee); + pp[96 + 27] = float2int8(p0[64 + 55] * scalee); + pp[96 + 28] = float2int8(p0[64 + 60] * scalef); + pp[96 + 29] = float2int8(p0[64 + 61] * scalef); + pp[96 + 30] = float2int8(p0[64 + 62] * scalef); + pp[96 + 31] = float2int8(p0[64 + 63] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + w_shift0 += pp[64 + 0]; + w_shift0 += pp[64 + 1]; + w_shift0 += pp[64 + 2]; + w_shift0 += pp[64 + 3]; + w_shift1 += pp[64 + 4]; + w_shift1 += pp[64 + 5]; + w_shift1 += pp[64 + 6]; + w_shift1 += pp[64 + 7]; + w_shift2 += pp[64 + 8]; + w_shift2 += pp[64 + 9]; + w_shift2 += pp[64 + 10]; + w_shift2 += pp[64 + 11]; + w_shift3 += pp[64 + 12]; + w_shift3 += pp[64 + 13]; + w_shift3 += pp[64 + 14]; + w_shift3 += pp[64 + 15]; + w_shift4 += pp[64 + 16]; + w_shift4 += pp[64 + 17]; + w_shift4 += pp[64 + 18]; + w_shift4 += pp[64 + 19]; + w_shift5 += pp[64 + 20]; + w_shift5 += pp[64 + 21]; + w_shift5 += pp[64 + 22]; + w_shift5 += pp[64 + 23]; + w_shift6 += pp[64 + 24]; + w_shift6 += pp[64 + 25]; + w_shift6 += pp[64 + 26]; + w_shift6 += pp[64 + 27]; + w_shift7 += pp[64 + 28]; + w_shift7 += pp[64 + 29]; + w_shift7 += pp[64 + 30]; + w_shift7 += pp[64 + 31]; + + w_shift8 += pp[96 + 0]; + w_shift8 += pp[96 + 1]; + w_shift8 += pp[96 + 2]; + w_shift8 += pp[96 + 3]; + w_shift9 += pp[96 + 4]; + w_shift9 += pp[96 + 5]; + w_shift9 += pp[96 + 6]; + w_shift9 += pp[96 + 7]; + w_shifta += pp[96 + 8]; + w_shifta += pp[96 + 9]; + w_shifta += pp[96 + 10]; + w_shifta += pp[96 + 11]; + w_shiftb += pp[96 + 12]; + w_shiftb += pp[96 + 13]; + w_shiftb += pp[96 + 14]; + w_shiftb += pp[96 + 15]; + w_shiftc += pp[96 + 16]; + w_shiftc += pp[96 + 17]; + w_shiftc += pp[96 + 18]; + w_shiftc += pp[96 + 19]; + w_shiftd += pp[96 + 20]; + w_shiftd += pp[96 + 21]; + w_shiftd += pp[96 + 22]; + w_shiftd += pp[96 + 23]; + w_shifte += pp[96 + 24]; + w_shifte += pp[96 + 25]; + w_shifte += pp[96 + 26]; + w_shifte += pp[96 + 27]; + w_shiftf += pp[96 + 28]; + w_shiftf += pp[96 + 29]; + w_shiftf += pp[96 + 30]; + w_shiftf += pp[96 + 31]; + + pp += 128; + p0 += A_hstep * 8; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { pp[0] = float2int8(p0[0] * scale0); @@ -2353,10 +5584,186 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 128; p0 += A_hstep * 8; } +#endif // __AVX512VNNI__ } if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[4] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[6] * scale1); + pp[7] = float2int8(p0[7] * scale1); + pp[8] = float2int8(p0[8] * scale2); + pp[9] = float2int8(p0[9] * scale2); + pp[10] = float2int8(p0[10] * scale2); + pp[11] = float2int8(p0[11] * scale2); + pp[12] = float2int8(p0[12] * scale3); + pp[13] = float2int8(p0[13] * scale3); + pp[14] = float2int8(p0[14] * scale3); + pp[15] = float2int8(p0[15] * scale3); + pp[16] = float2int8(p0[16] * scale4); + pp[17] = float2int8(p0[17] * scale4); + pp[18] = float2int8(p0[18] * scale4); + pp[19] = float2int8(p0[19] * scale4); + pp[20] = float2int8(p0[20] * scale5); + pp[21] = float2int8(p0[21] * scale5); + pp[22] = float2int8(p0[22] * scale5); + pp[23] = float2int8(p0[23] * scale5); + pp[24] = float2int8(p0[24] * scale6); + pp[25] = float2int8(p0[25] * scale6); + pp[26] = float2int8(p0[26] * scale6); + pp[27] = float2int8(p0[27] * scale6); + pp[28] = float2int8(p0[28] * scale7); + pp[29] = float2int8(p0[29] * scale7); + pp[30] = float2int8(p0[30] * scale7); + pp[31] = float2int8(p0[31] * scale7); + + pp[32 + 0] = float2int8(p0[32 + 0] * scale8); + pp[32 + 1] = float2int8(p0[32 + 1] * scale8); + pp[32 + 2] = float2int8(p0[32 + 2] * scale8); + pp[32 + 3] = float2int8(p0[32 + 3] * scale8); + pp[32 + 4] = float2int8(p0[32 + 4] * scale9); + pp[32 + 5] = float2int8(p0[32 + 5] * scale9); + pp[32 + 6] = float2int8(p0[32 + 6] * scale9); + pp[32 + 7] = float2int8(p0[32 + 7] * scale9); + pp[32 + 8] = float2int8(p0[32 + 8] * scalea); + pp[32 + 9] = float2int8(p0[32 + 9] * scalea); + pp[32 + 10] = float2int8(p0[32 + 10] * scalea); + pp[32 + 11] = float2int8(p0[32 + 11] * scalea); + pp[32 + 12] = float2int8(p0[32 + 12] * scaleb); + pp[32 + 13] = float2int8(p0[32 + 13] * scaleb); + pp[32 + 14] = float2int8(p0[32 + 14] * scaleb); + pp[32 + 15] = float2int8(p0[32 + 15] * scaleb); + pp[32 + 16] = float2int8(p0[32 + 16] * scalec); + pp[32 + 17] = float2int8(p0[32 + 17] * scalec); + pp[32 + 18] = float2int8(p0[32 + 18] * scalec); + pp[32 + 19] = float2int8(p0[32 + 19] * scalec); + pp[32 + 20] = float2int8(p0[32 + 20] * scaled); + pp[32 + 21] = float2int8(p0[32 + 21] * scaled); + pp[32 + 22] = float2int8(p0[32 + 22] * scaled); + pp[32 + 23] = float2int8(p0[32 + 23] * scaled); + pp[32 + 24] = float2int8(p0[32 + 24] * scalee); + pp[32 + 25] = float2int8(p0[32 + 25] * scalee); + pp[32 + 26] = float2int8(p0[32 + 26] * scalee); + pp[32 + 27] = float2int8(p0[32 + 27] * scalee); + pp[32 + 28] = float2int8(p0[32 + 28] * scalef); + pp[32 + 29] = float2int8(p0[32 + 29] * scalef); + pp[32 + 30] = float2int8(p0[32 + 30] * scalef); + pp[32 + 31] = float2int8(p0[32 + 31] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); @@ -2430,10 +5837,186 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 64; p0 += A_hstep * 4; } +#endif // __AVX512VNNI__ } if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + int w_shift8 = 0; + int w_shift9 = 0; + int w_shifta = 0; + int w_shiftb = 0; + int w_shiftc = 0; + int w_shiftd = 0; + int w_shifte = 0; + int w_shiftf = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale0); + pp[2] = float2int8(p0[A_hstep * 2] * scale0); + pp[3] = float2int8(p0[A_hstep * 3] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep * 2 + 1] * scale1); + pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[A_hstep + 2] * scale2); + pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); + pp[11] = float2int8(p0[A_hstep * 3 + 2] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[A_hstep + 3] * scale3); + pp[14] = float2int8(p0[A_hstep * 2 + 3] * scale3); + pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); + pp[16] = float2int8(p0[4] * scale4); + pp[17] = float2int8(p0[A_hstep + 4] * scale4); + pp[18] = float2int8(p0[A_hstep * 2 + 4] * scale4); + pp[19] = float2int8(p0[A_hstep * 3 + 4] * scale4); + pp[20] = float2int8(p0[5] * scale5); + pp[21] = float2int8(p0[A_hstep + 5] * scale5); + pp[22] = float2int8(p0[A_hstep * 2 + 5] * scale5); + pp[23] = float2int8(p0[A_hstep * 3 + 5] * scale5); + pp[24] = float2int8(p0[6] * scale6); + pp[25] = float2int8(p0[A_hstep + 6] * scale6); + pp[26] = float2int8(p0[A_hstep * 2 + 6] * scale6); + pp[27] = float2int8(p0[A_hstep * 3 + 6] * scale6); + pp[28] = float2int8(p0[7] * scale7); + pp[29] = float2int8(p0[A_hstep + 7] * scale7); + pp[30] = float2int8(p0[A_hstep * 2 + 7] * scale7); + pp[31] = float2int8(p0[A_hstep * 3 + 7] * scale7); + + pp[32 + 0] = float2int8(p0[8] * scale8); + pp[32 + 1] = float2int8(p0[A_hstep + 8] * scale8); + pp[32 + 2] = float2int8(p0[A_hstep * 2 + 8] * scale8); + pp[32 + 3] = float2int8(p0[A_hstep * 3 + 8] * scale8); + pp[32 + 4] = float2int8(p0[9] * scale9); + pp[32 + 5] = float2int8(p0[A_hstep + 9] * scale9); + pp[32 + 6] = float2int8(p0[A_hstep * 2 + 9] * scale9); + pp[32 + 7] = float2int8(p0[A_hstep * 3 + 9] * scale9); + pp[32 + 8] = float2int8(p0[10] * scalea); + pp[32 + 9] = float2int8(p0[A_hstep + 10] * scalea); + pp[32 + 10] = float2int8(p0[A_hstep * 2 + 10] * scalea); + pp[32 + 11] = float2int8(p0[A_hstep * 3 + 10] * scalea); + pp[32 + 12] = float2int8(p0[11] * scaleb); + pp[32 + 13] = float2int8(p0[A_hstep + 11] * scaleb); + pp[32 + 14] = float2int8(p0[A_hstep * 2 + 11] * scaleb); + pp[32 + 15] = float2int8(p0[A_hstep * 3 + 11] * scaleb); + pp[32 + 16] = float2int8(p0[12] * scalec); + pp[32 + 17] = float2int8(p0[A_hstep + 12] * scalec); + pp[32 + 18] = float2int8(p0[A_hstep * 2 + 12] * scalec); + pp[32 + 19] = float2int8(p0[A_hstep * 3 + 12] * scalec); + pp[32 + 20] = float2int8(p0[13] * scaled); + pp[32 + 21] = float2int8(p0[A_hstep + 13] * scaled); + pp[32 + 22] = float2int8(p0[A_hstep * 2 + 13] * scaled); + pp[32 + 23] = float2int8(p0[A_hstep * 3 + 13] * scaled); + pp[32 + 24] = float2int8(p0[14] * scalee); + pp[32 + 25] = float2int8(p0[A_hstep + 14] * scalee); + pp[32 + 26] = float2int8(p0[A_hstep * 2 + 14] * scalee); + pp[32 + 27] = float2int8(p0[A_hstep * 3 + 14] * scalee); + pp[32 + 28] = float2int8(p0[15] * scalef); + pp[32 + 29] = float2int8(p0[A_hstep + 15] * scalef); + pp[32 + 30] = float2int8(p0[A_hstep * 2 + 15] * scalef); + pp[32 + 31] = float2int8(p0[A_hstep * 3 + 15] * scalef); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift8 += pp[32 + 0]; + w_shift8 += pp[32 + 1]; + w_shift8 += pp[32 + 2]; + w_shift8 += pp[32 + 3]; + w_shift9 += pp[32 + 4]; + w_shift9 += pp[32 + 5]; + w_shift9 += pp[32 + 6]; + w_shift9 += pp[32 + 7]; + w_shifta += pp[32 + 8]; + w_shifta += pp[32 + 9]; + w_shifta += pp[32 + 10]; + w_shifta += pp[32 + 11]; + w_shiftb += pp[32 + 12]; + w_shiftb += pp[32 + 13]; + w_shiftb += pp[32 + 14]; + w_shiftb += pp[32 + 15]; + w_shiftc += pp[32 + 16]; + w_shiftc += pp[32 + 17]; + w_shiftc += pp[32 + 18]; + w_shiftc += pp[32 + 19]; + w_shiftd += pp[32 + 20]; + w_shiftd += pp[32 + 21]; + w_shiftd += pp[32 + 22]; + w_shiftd += pp[32 + 23]; + w_shifte += pp[32 + 24]; + w_shifte += pp[32 + 25]; + w_shifte += pp[32 + 26]; + w_shifte += pp[32 + 27]; + w_shiftf += pp[32 + 28]; + w_shiftf += pp[32 + 29]; + w_shiftf += pp[32 + 30]; + w_shiftf += pp[32 + 31]; + + pp += 64; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + ((int*)pp)[8] = w_shift8 * 127; + ((int*)pp)[9] = w_shift9 * 127; + ((int*)pp)[10] = w_shifta * 127; + ((int*)pp)[11] = w_shiftb * 127; + ((int*)pp)[12] = w_shiftc * 127; + ((int*)pp)[13] = w_shiftd * 127; + ((int*)pp)[14] = w_shifte * 127; + ((int*)pp)[15] = w_shiftf * 127; + pp += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -2496,15 +6079,11 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int } } #endif // __AVX512F__ +#if !__AVX2__ + signed char* pp1 = pp + max_kk * 4; +#endif for (; ii + 7 < max_ii; ii += 8) { -#if __AVX2__ - signed char* pp = (signed char*)AT + ii * max_kk; -#else - signed char* pp = (signed char*)AT + ii * max_kk; - signed char* pp1 = (signed char*)AT + (ii + 4) * max_kk; -#endif - const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; const float scale0 = scales[i + ii]; @@ -2520,6 +6099,296 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2 + 0] * scale0); + pp[3] = float2int8(p0[2 + 1] * scale0); + pp[4] = float2int8(p0[16] * scale1); + pp[5] = float2int8(p0[17] * scale1); + pp[6] = float2int8(p0[2 + 16] * scale1); + pp[7] = float2int8(p0[2 + 17] * scale1); + pp[8] = float2int8(p0[32] * scale2); + pp[9] = float2int8(p0[33] * scale2); + pp[10] = float2int8(p0[2 + 32] * scale2); + pp[11] = float2int8(p0[2 + 33] * scale2); + pp[12] = float2int8(p0[48] * scale3); + pp[13] = float2int8(p0[49] * scale3); + pp[14] = float2int8(p0[2 + 48] * scale3); + pp[15] = float2int8(p0[2 + 49] * scale3); + pp[16] = float2int8(p0[64] * scale4); + pp[17] = float2int8(p0[65] * scale4); + pp[18] = float2int8(p0[2 + 64] * scale4); + pp[19] = float2int8(p0[2 + 65] * scale4); + pp[20] = float2int8(p0[80] * scale5); + pp[21] = float2int8(p0[81] * scale5); + pp[22] = float2int8(p0[2 + 80] * scale5); + pp[23] = float2int8(p0[2 + 81] * scale5); + pp[24] = float2int8(p0[96] * scale6); + pp[25] = float2int8(p0[97] * scale6); + pp[26] = float2int8(p0[2 + 96] * scale6); + pp[27] = float2int8(p0[2 + 97] * scale6); + pp[28] = float2int8(p0[112] * scale7); + pp[29] = float2int8(p0[113] * scale7); + pp[30] = float2int8(p0[2 + 112] * scale7); + pp[31] = float2int8(p0[2 + 113] * scale7); + + pp[32 + 0] = float2int8(p0[4 + 0] * scale0); + pp[32 + 1] = float2int8(p0[4 + 1] * scale0); + pp[32 + 2] = float2int8(p0[6 + 0] * scale0); + pp[32 + 3] = float2int8(p0[6 + 1] * scale0); + pp[32 + 4] = float2int8(p0[4 + 16] * scale1); + pp[32 + 5] = float2int8(p0[4 + 17] * scale1); + pp[32 + 6] = float2int8(p0[6 + 16] * scale1); + pp[32 + 7] = float2int8(p0[6 + 17] * scale1); + pp[32 + 8] = float2int8(p0[4 + 32] * scale2); + pp[32 + 9] = float2int8(p0[4 + 33] * scale2); + pp[32 + 10] = float2int8(p0[6 + 32] * scale2); + pp[32 + 11] = float2int8(p0[6 + 33] * scale2); + pp[32 + 12] = float2int8(p0[4 + 48] * scale3); + pp[32 + 13] = float2int8(p0[4 + 49] * scale3); + pp[32 + 14] = float2int8(p0[6 + 48] * scale3); + pp[32 + 15] = float2int8(p0[6 + 49] * scale3); + pp[32 + 16] = float2int8(p0[4 + 64] * scale4); + pp[32 + 17] = float2int8(p0[4 + 65] * scale4); + pp[32 + 18] = float2int8(p0[6 + 64] * scale4); + pp[32 + 19] = float2int8(p0[6 + 65] * scale4); + pp[32 + 20] = float2int8(p0[4 + 80] * scale5); + pp[32 + 21] = float2int8(p0[4 + 81] * scale5); + pp[32 + 22] = float2int8(p0[6 + 80] * scale5); + pp[32 + 23] = float2int8(p0[6 + 81] * scale5); + pp[32 + 24] = float2int8(p0[4 + 96] * scale6); + pp[32 + 25] = float2int8(p0[4 + 97] * scale6); + pp[32 + 26] = float2int8(p0[6 + 96] * scale6); + pp[32 + 27] = float2int8(p0[6 + 97] * scale6); + pp[32 + 28] = float2int8(p0[4 + 112] * scale7); + pp[32 + 29] = float2int8(p0[4 + 113] * scale7); + pp[32 + 30] = float2int8(p0[6 + 112] * scale7); + pp[32 + 31] = float2int8(p0[6 + 113] * scale7); + + pp[64 + 0] = float2int8(p0[8 + 0] * scale0); + pp[64 + 1] = float2int8(p0[8 + 1] * scale0); + pp[64 + 2] = float2int8(p0[10 + 0] * scale0); + pp[64 + 3] = float2int8(p0[10 + 1] * scale0); + pp[64 + 4] = float2int8(p0[8 + 16] * scale1); + pp[64 + 5] = float2int8(p0[8 + 17] * scale1); + pp[64 + 6] = float2int8(p0[10 + 16] * scale1); + pp[64 + 7] = float2int8(p0[10 + 17] * scale1); + pp[64 + 8] = float2int8(p0[8 + 32] * scale2); + pp[64 + 9] = float2int8(p0[8 + 33] * scale2); + pp[64 + 10] = float2int8(p0[10 + 32] * scale2); + pp[64 + 11] = float2int8(p0[10 + 33] * scale2); + pp[64 + 12] = float2int8(p0[8 + 48] * scale3); + pp[64 + 13] = float2int8(p0[8 + 49] * scale3); + pp[64 + 14] = float2int8(p0[10 + 48] * scale3); + pp[64 + 15] = float2int8(p0[10 + 49] * scale3); + pp[64 + 16] = float2int8(p0[8 + 64] * scale4); + pp[64 + 17] = float2int8(p0[8 + 65] * scale4); + pp[64 + 18] = float2int8(p0[10 + 64] * scale4); + pp[64 + 19] = float2int8(p0[10 + 65] * scale4); + pp[64 + 20] = float2int8(p0[8 + 80] * scale5); + pp[64 + 21] = float2int8(p0[8 + 81] * scale5); + pp[64 + 22] = float2int8(p0[10 + 80] * scale5); + pp[64 + 23] = float2int8(p0[10 + 81] * scale5); + pp[64 + 24] = float2int8(p0[8 + 96] * scale6); + pp[64 + 25] = float2int8(p0[8 + 97] * scale6); + pp[64 + 26] = float2int8(p0[10 + 96] * scale6); + pp[64 + 27] = float2int8(p0[10 + 97] * scale6); + pp[64 + 28] = float2int8(p0[8 + 112] * scale7); + pp[64 + 29] = float2int8(p0[8 + 113] * scale7); + pp[64 + 30] = float2int8(p0[10 + 112] * scale7); + pp[64 + 31] = float2int8(p0[10 + 113] * scale7); + + pp[96 + 0] = float2int8(p0[12 + 0] * scale0); + pp[96 + 1] = float2int8(p0[12 + 1] * scale0); + pp[96 + 2] = float2int8(p0[14 + 0] * scale0); + pp[96 + 3] = float2int8(p0[14 + 1] * scale0); + pp[96 + 4] = float2int8(p0[12 + 16] * scale1); + pp[96 + 5] = float2int8(p0[12 + 17] * scale1); + pp[96 + 6] = float2int8(p0[14 + 16] * scale1); + pp[96 + 7] = float2int8(p0[14 + 17] * scale1); + pp[96 + 8] = float2int8(p0[12 + 32] * scale2); + pp[96 + 9] = float2int8(p0[12 + 33] * scale2); + pp[96 + 10] = float2int8(p0[14 + 32] * scale2); + pp[96 + 11] = float2int8(p0[14 + 33] * scale2); + pp[96 + 12] = float2int8(p0[12 + 48] * scale3); + pp[96 + 13] = float2int8(p0[12 + 49] * scale3); + pp[96 + 14] = float2int8(p0[14 + 48] * scale3); + pp[96 + 15] = float2int8(p0[14 + 49] * scale3); + pp[96 + 16] = float2int8(p0[12 + 64] * scale4); + pp[96 + 17] = float2int8(p0[12 + 65] * scale4); + pp[96 + 18] = float2int8(p0[14 + 64] * scale4); + pp[96 + 19] = float2int8(p0[14 + 65] * scale4); + pp[96 + 20] = float2int8(p0[12 + 80] * scale5); + pp[96 + 21] = float2int8(p0[12 + 81] * scale5); + pp[96 + 22] = float2int8(p0[14 + 80] * scale5); + pp[96 + 23] = float2int8(p0[14 + 81] * scale5); + pp[96 + 24] = float2int8(p0[12 + 96] * scale6); + pp[96 + 25] = float2int8(p0[12 + 97] * scale6); + pp[96 + 26] = float2int8(p0[14 + 96] * scale6); + pp[96 + 27] = float2int8(p0[14 + 97] * scale6); + pp[96 + 28] = float2int8(p0[12 + 112] * scale7); + pp[96 + 29] = float2int8(p0[12 + 113] * scale7); + pp[96 + 30] = float2int8(p0[14 + 112] * scale7); + pp[96 + 31] = float2int8(p0[14 + 113] * scale7); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift0 += pp[32 + 0]; + w_shift0 += pp[32 + 1]; + w_shift0 += pp[32 + 2]; + w_shift0 += pp[32 + 3]; + w_shift1 += pp[32 + 4]; + w_shift1 += pp[32 + 5]; + w_shift1 += pp[32 + 6]; + w_shift1 += pp[32 + 7]; + w_shift2 += pp[32 + 8]; + w_shift2 += pp[32 + 9]; + w_shift2 += pp[32 + 10]; + w_shift2 += pp[32 + 11]; + w_shift3 += pp[32 + 12]; + w_shift3 += pp[32 + 13]; + w_shift3 += pp[32 + 14]; + w_shift3 += pp[32 + 15]; + w_shift4 += pp[32 + 16]; + w_shift4 += pp[32 + 17]; + w_shift4 += pp[32 + 18]; + w_shift4 += pp[32 + 19]; + w_shift5 += pp[32 + 20]; + w_shift5 += pp[32 + 21]; + w_shift5 += pp[32 + 22]; + w_shift5 += pp[32 + 23]; + w_shift6 += pp[32 + 24]; + w_shift6 += pp[32 + 25]; + w_shift6 += pp[32 + 26]; + w_shift6 += pp[32 + 27]; + w_shift7 += pp[32 + 28]; + w_shift7 += pp[32 + 29]; + w_shift7 += pp[32 + 30]; + w_shift7 += pp[32 + 31]; + + w_shift0 += pp[64 + 0]; + w_shift0 += pp[64 + 1]; + w_shift0 += pp[64 + 2]; + w_shift0 += pp[64 + 3]; + w_shift1 += pp[64 + 4]; + w_shift1 += pp[64 + 5]; + w_shift1 += pp[64 + 6]; + w_shift1 += pp[64 + 7]; + w_shift2 += pp[64 + 8]; + w_shift2 += pp[64 + 9]; + w_shift2 += pp[64 + 10]; + w_shift2 += pp[64 + 11]; + w_shift3 += pp[64 + 12]; + w_shift3 += pp[64 + 13]; + w_shift3 += pp[64 + 14]; + w_shift3 += pp[64 + 15]; + w_shift4 += pp[64 + 16]; + w_shift4 += pp[64 + 17]; + w_shift4 += pp[64 + 18]; + w_shift4 += pp[64 + 19]; + w_shift5 += pp[64 + 20]; + w_shift5 += pp[64 + 21]; + w_shift5 += pp[64 + 22]; + w_shift5 += pp[64 + 23]; + w_shift6 += pp[64 + 24]; + w_shift6 += pp[64 + 25]; + w_shift6 += pp[64 + 26]; + w_shift6 += pp[64 + 27]; + w_shift7 += pp[64 + 28]; + w_shift7 += pp[64 + 29]; + w_shift7 += pp[64 + 30]; + w_shift7 += pp[64 + 31]; + + w_shift0 += pp[96 + 0]; + w_shift0 += pp[96 + 1]; + w_shift0 += pp[96 + 2]; + w_shift0 += pp[96 + 3]; + w_shift1 += pp[96 + 4]; + w_shift1 += pp[96 + 5]; + w_shift1 += pp[96 + 6]; + w_shift1 += pp[96 + 7]; + w_shift2 += pp[96 + 8]; + w_shift2 += pp[96 + 9]; + w_shift2 += pp[96 + 10]; + w_shift2 += pp[96 + 11]; + w_shift3 += pp[96 + 12]; + w_shift3 += pp[96 + 13]; + w_shift3 += pp[96 + 14]; + w_shift3 += pp[96 + 15]; + w_shift4 += pp[96 + 16]; + w_shift4 += pp[96 + 17]; + w_shift4 += pp[96 + 18]; + w_shift4 += pp[96 + 19]; + w_shift5 += pp[96 + 20]; + w_shift5 += pp[96 + 21]; + w_shift5 += pp[96 + 22]; + w_shift5 += pp[96 + 23]; + w_shift6 += pp[96 + 24]; + w_shift6 += pp[96 + 25]; + w_shift6 += pp[96 + 26]; + w_shift6 += pp[96 + 27]; + w_shift7 += pp[96 + 28]; + w_shift7 += pp[96 + 29]; + w_shift7 += pp[96 + 30]; + w_shift7 += pp[96 + 31]; + pp += 128; + p0 += A_hstep * 16; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { pp[0] = float2int8(p0[0] * scale0); @@ -2661,11 +6530,170 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 128; p0 += A_hstep * 16; } +#endif // __AVX512VNNI__ } #endif // __AVX512F__ if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[8] * scale1); + pp[5] = float2int8(p0[9] * scale1); + pp[6] = float2int8(p0[10] * scale1); + pp[7] = float2int8(p0[11] * scale1); + pp[8] = float2int8(p0[16] * scale2); + pp[9] = float2int8(p0[17] * scale2); + pp[10] = float2int8(p0[18] * scale2); + pp[11] = float2int8(p0[19] * scale2); + pp[12] = float2int8(p0[24] * scale3); + pp[13] = float2int8(p0[25] * scale3); + pp[14] = float2int8(p0[26] * scale3); + pp[15] = float2int8(p0[27] * scale3); + pp[16] = float2int8(p0[32] * scale4); + pp[17] = float2int8(p0[33] * scale4); + pp[18] = float2int8(p0[34] * scale4); + pp[19] = float2int8(p0[35] * scale4); + pp[20] = float2int8(p0[40] * scale5); + pp[21] = float2int8(p0[41] * scale5); + pp[22] = float2int8(p0[42] * scale5); + pp[23] = float2int8(p0[43] * scale5); + pp[24] = float2int8(p0[48] * scale6); + pp[25] = float2int8(p0[49] * scale6); + pp[26] = float2int8(p0[50] * scale6); + pp[27] = float2int8(p0[51] * scale6); + pp[28] = float2int8(p0[56] * scale7); + pp[29] = float2int8(p0[57] * scale7); + pp[30] = float2int8(p0[58] * scale7); + pp[31] = float2int8(p0[59] * scale7); + + pp[32 + 0] = float2int8(p0[4] * scale0); + pp[32 + 1] = float2int8(p0[5] * scale0); + pp[32 + 2] = float2int8(p0[6] * scale0); + pp[32 + 3] = float2int8(p0[7] * scale0); + pp[32 + 4] = float2int8(p0[12] * scale1); + pp[32 + 5] = float2int8(p0[13] * scale1); + pp[32 + 6] = float2int8(p0[14] * scale1); + pp[32 + 7] = float2int8(p0[15] * scale1); + pp[32 + 8] = float2int8(p0[20] * scale2); + pp[32 + 9] = float2int8(p0[21] * scale2); + pp[32 + 10] = float2int8(p0[22] * scale2); + pp[32 + 11] = float2int8(p0[23] * scale2); + pp[32 + 12] = float2int8(p0[28] * scale3); + pp[32 + 13] = float2int8(p0[29] * scale3); + pp[32 + 14] = float2int8(p0[30] * scale3); + pp[32 + 15] = float2int8(p0[31] * scale3); + pp[32 + 16] = float2int8(p0[36] * scale4); + pp[32 + 17] = float2int8(p0[37] * scale4); + pp[32 + 18] = float2int8(p0[38] * scale4); + pp[32 + 19] = float2int8(p0[39] * scale4); + pp[32 + 20] = float2int8(p0[44] * scale5); + pp[32 + 21] = float2int8(p0[45] * scale5); + pp[32 + 22] = float2int8(p0[46] * scale5); + pp[32 + 23] = float2int8(p0[47] * scale5); + pp[32 + 24] = float2int8(p0[52] * scale6); + pp[32 + 25] = float2int8(p0[53] * scale6); + pp[32 + 26] = float2int8(p0[54] * scale6); + pp[32 + 27] = float2int8(p0[55] * scale6); + pp[32 + 28] = float2int8(p0[60] * scale7); + pp[32 + 29] = float2int8(p0[61] * scale7); + pp[32 + 30] = float2int8(p0[62] * scale7); + pp[32 + 31] = float2int8(p0[63] * scale7); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + + w_shift0 += pp[32 + 0]; + w_shift0 += pp[32 + 1]; + w_shift0 += pp[32 + 2]; + w_shift0 += pp[32 + 3]; + w_shift1 += pp[32 + 4]; + w_shift1 += pp[32 + 5]; + w_shift1 += pp[32 + 6]; + w_shift1 += pp[32 + 7]; + w_shift2 += pp[32 + 8]; + w_shift2 += pp[32 + 9]; + w_shift2 += pp[32 + 10]; + w_shift2 += pp[32 + 11]; + w_shift3 += pp[32 + 12]; + w_shift3 += pp[32 + 13]; + w_shift3 += pp[32 + 14]; + w_shift3 += pp[32 + 15]; + w_shift4 += pp[32 + 16]; + w_shift4 += pp[32 + 17]; + w_shift4 += pp[32 + 18]; + w_shift4 += pp[32 + 19]; + w_shift5 += pp[32 + 20]; + w_shift5 += pp[32 + 21]; + w_shift5 += pp[32 + 22]; + w_shift5 += pp[32 + 23]; + w_shift6 += pp[32 + 24]; + w_shift6 += pp[32 + 25]; + w_shift6 += pp[32 + 26]; + w_shift6 += pp[32 + 27]; + w_shift7 += pp[32 + 28]; + w_shift7 += pp[32 + 29]; + w_shift7 += pp[32 + 30]; + w_shift7 += pp[32 + 31]; + pp += 64; + p0 += A_hstep * 8; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { pp[0] = float2int8(p0[0] * scale0); @@ -2794,10 +6822,102 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } +#endif // __AVX512VNNI__ } if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[4] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[6] * scale1); + pp[7] = float2int8(p0[7] * scale1); + pp[8] = float2int8(p0[8] * scale2); + pp[9] = float2int8(p0[9] * scale2); + pp[10] = float2int8(p0[10] * scale2); + pp[11] = float2int8(p0[11] * scale2); + pp[12] = float2int8(p0[12] * scale3); + pp[13] = float2int8(p0[13] * scale3); + pp[14] = float2int8(p0[14] * scale3); + pp[15] = float2int8(p0[15] * scale3); + pp[16] = float2int8(p0[16] * scale4); + pp[17] = float2int8(p0[17] * scale4); + pp[18] = float2int8(p0[18] * scale4); + pp[19] = float2int8(p0[19] * scale4); + pp[20] = float2int8(p0[20] * scale5); + pp[21] = float2int8(p0[21] * scale5); + pp[22] = float2int8(p0[22] * scale5); + pp[23] = float2int8(p0[23] * scale5); + pp[24] = float2int8(p0[24] * scale6); + pp[25] = float2int8(p0[25] * scale6); + pp[26] = float2int8(p0[26] * scale6); + pp[27] = float2int8(p0[27] * scale6); + pp[28] = float2int8(p0[28] * scale7); + pp[29] = float2int8(p0[29] * scale7); + pp[30] = float2int8(p0[30] * scale7); + pp[31] = float2int8(p0[31] * scale7); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + pp += 32; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); @@ -2863,10 +6983,102 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int #endif p0 += A_hstep * 4; } +#endif // __AVX512VNNI__ } if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + int w_shift4 = 0; + int w_shift5 = 0; + int w_shift6 = 0; + int w_shift7 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale0); + pp[2] = float2int8(p0[A_hstep * 2] * scale0); + pp[3] = float2int8(p0[A_hstep * 3] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep * 2 + 1] * scale1); + pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[A_hstep + 2] * scale2); + pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); + pp[11] = float2int8(p0[A_hstep * 3 + 2] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[A_hstep + 3] * scale3); + pp[14] = float2int8(p0[A_hstep * 2 + 3] * scale3); + pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); + pp[16] = float2int8(p0[4] * scale4); + pp[17] = float2int8(p0[A_hstep + 4] * scale4); + pp[18] = float2int8(p0[A_hstep * 2 + 4] * scale4); + pp[19] = float2int8(p0[A_hstep * 3 + 4] * scale4); + pp[20] = float2int8(p0[5] * scale5); + pp[21] = float2int8(p0[A_hstep + 5] * scale5); + pp[22] = float2int8(p0[A_hstep * 2 + 5] * scale5); + pp[23] = float2int8(p0[A_hstep * 3 + 5] * scale5); + pp[24] = float2int8(p0[6] * scale6); + pp[25] = float2int8(p0[A_hstep + 6] * scale6); + pp[26] = float2int8(p0[A_hstep * 2 + 6] * scale6); + pp[27] = float2int8(p0[A_hstep * 3 + 6] * scale6); + pp[28] = float2int8(p0[7] * scale7); + pp[29] = float2int8(p0[A_hstep + 7] * scale7); + pp[30] = float2int8(p0[A_hstep * 2 + 7] * scale7); + pp[31] = float2int8(p0[A_hstep * 3 + 7] * scale7); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + w_shift4 += pp[16]; + w_shift4 += pp[17]; + w_shift4 += pp[18]; + w_shift4 += pp[19]; + w_shift5 += pp[20]; + w_shift5 += pp[21]; + w_shift5 += pp[22]; + w_shift5 += pp[23]; + w_shift6 += pp[24]; + w_shift6 += pp[25]; + w_shift6 += pp[26]; + w_shift6 += pp[27]; + w_shift7 += pp[28]; + w_shift7 += pp[29]; + w_shift7 += pp[30]; + w_shift7 += pp[31]; + pp += 32; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + ((int*)pp)[4] = w_shift4 * 127; + ((int*)pp)[5] = w_shift5 * 127; + ((int*)pp)[6] = w_shift6 * 127; + ((int*)pp)[7] = w_shift7 * 127; + pp += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -2925,9 +7137,9 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int } } } - signed char* pp = (signed char*)AT + ii * max_kk; -#else - signed char* pp = (signed char*)AT; +#if !__AVX2__ + pp = pp1; +#endif #endif // __AVX__ for (; ii + 3 < max_ii; ii += 4) { @@ -2943,6 +7155,161 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2 + 0] * scale0); + pp[3] = float2int8(p0[2 + 1] * scale0); + pp[4] = float2int8(p0[16] * scale1); + pp[5] = float2int8(p0[17] * scale1); + pp[6] = float2int8(p0[2 + 16] * scale1); + pp[7] = float2int8(p0[2 + 17] * scale1); + pp[8] = float2int8(p0[32] * scale2); + pp[9] = float2int8(p0[33] * scale2); + pp[10] = float2int8(p0[2 + 32] * scale2); + pp[11] = float2int8(p0[2 + 33] * scale2); + pp[12] = float2int8(p0[48] * scale3); + pp[13] = float2int8(p0[49] * scale3); + pp[14] = float2int8(p0[2 + 48] * scale3); + pp[15] = float2int8(p0[2 + 49] * scale3); + + pp[16 + 0] = float2int8(p0[4 + 0] * scale0); + pp[16 + 1] = float2int8(p0[4 + 1] * scale0); + pp[16 + 2] = float2int8(p0[6 + 0] * scale0); + pp[16 + 3] = float2int8(p0[6 + 1] * scale0); + pp[16 + 4] = float2int8(p0[4 + 16] * scale1); + pp[16 + 5] = float2int8(p0[4 + 17] * scale1); + pp[16 + 6] = float2int8(p0[6 + 16] * scale1); + pp[16 + 7] = float2int8(p0[6 + 17] * scale1); + pp[16 + 8] = float2int8(p0[4 + 32] * scale2); + pp[16 + 9] = float2int8(p0[4 + 33] * scale2); + pp[16 + 10] = float2int8(p0[6 + 32] * scale2); + pp[16 + 11] = float2int8(p0[6 + 33] * scale2); + pp[16 + 12] = float2int8(p0[4 + 48] * scale3); + pp[16 + 13] = float2int8(p0[4 + 49] * scale3); + pp[16 + 14] = float2int8(p0[6 + 48] * scale3); + pp[16 + 15] = float2int8(p0[6 + 49] * scale3); + + pp[32 + 0] = float2int8(p0[8 + 0] * scale0); + pp[32 + 1] = float2int8(p0[8 + 1] * scale0); + pp[32 + 2] = float2int8(p0[10 + 0] * scale0); + pp[32 + 3] = float2int8(p0[10 + 1] * scale0); + pp[32 + 4] = float2int8(p0[8 + 16] * scale1); + pp[32 + 5] = float2int8(p0[8 + 17] * scale1); + pp[32 + 6] = float2int8(p0[10 + 16] * scale1); + pp[32 + 7] = float2int8(p0[10 + 17] * scale1); + pp[32 + 8] = float2int8(p0[8 + 32] * scale2); + pp[32 + 9] = float2int8(p0[8 + 33] * scale2); + pp[32 + 10] = float2int8(p0[10 + 32] * scale2); + pp[32 + 11] = float2int8(p0[10 + 33] * scale2); + pp[32 + 12] = float2int8(p0[8 + 48] * scale3); + pp[32 + 13] = float2int8(p0[8 + 49] * scale3); + pp[32 + 14] = float2int8(p0[10 + 48] * scale3); + pp[32 + 15] = float2int8(p0[10 + 49] * scale3); + + pp[48 + 0] = float2int8(p0[12 + 0] * scale0); + pp[48 + 1] = float2int8(p0[12 + 1] * scale0); + pp[48 + 2] = float2int8(p0[14 + 0] * scale0); + pp[48 + 3] = float2int8(p0[14 + 1] * scale0); + pp[48 + 4] = float2int8(p0[12 + 16] * scale1); + pp[48 + 5] = float2int8(p0[12 + 17] * scale1); + pp[48 + 6] = float2int8(p0[14 + 16] * scale1); + pp[48 + 7] = float2int8(p0[14 + 17] * scale1); + pp[48 + 8] = float2int8(p0[12 + 32] * scale2); + pp[48 + 9] = float2int8(p0[12 + 33] * scale2); + pp[48 + 10] = float2int8(p0[14 + 32] * scale2); + pp[48 + 11] = float2int8(p0[14 + 33] * scale2); + pp[48 + 12] = float2int8(p0[12 + 48] * scale3); + pp[48 + 13] = float2int8(p0[12 + 49] * scale3); + pp[48 + 14] = float2int8(p0[14 + 48] * scale3); + pp[48 + 15] = float2int8(p0[14 + 49] * scale3); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + + w_shift0 += pp[16 + 0]; + w_shift0 += pp[16 + 1]; + w_shift0 += pp[16 + 2]; + w_shift0 += pp[16 + 3]; + w_shift1 += pp[16 + 4]; + w_shift1 += pp[16 + 5]; + w_shift1 += pp[16 + 6]; + w_shift1 += pp[16 + 7]; + w_shift2 += pp[16 + 8]; + w_shift2 += pp[16 + 9]; + w_shift2 += pp[16 + 10]; + w_shift2 += pp[16 + 11]; + w_shift3 += pp[16 + 12]; + w_shift3 += pp[16 + 13]; + w_shift3 += pp[16 + 14]; + w_shift3 += pp[16 + 15]; + + w_shift0 += pp[32 + 0]; + w_shift0 += pp[32 + 1]; + w_shift0 += pp[32 + 2]; + w_shift0 += pp[32 + 3]; + w_shift1 += pp[32 + 4]; + w_shift1 += pp[32 + 5]; + w_shift1 += pp[32 + 6]; + w_shift1 += pp[32 + 7]; + w_shift2 += pp[32 + 8]; + w_shift2 += pp[32 + 9]; + w_shift2 += pp[32 + 10]; + w_shift2 += pp[32 + 11]; + w_shift3 += pp[32 + 12]; + w_shift3 += pp[32 + 13]; + w_shift3 += pp[32 + 14]; + w_shift3 += pp[32 + 15]; + + w_shift0 += pp[48 + 0]; + w_shift0 += pp[48 + 1]; + w_shift0 += pp[48 + 2]; + w_shift0 += pp[48 + 3]; + w_shift1 += pp[48 + 4]; + w_shift1 += pp[48 + 5]; + w_shift1 += pp[48 + 6]; + w_shift1 += pp[48 + 7]; + w_shift2 += pp[48 + 8]; + w_shift2 += pp[48 + 9]; + w_shift2 += pp[48 + 10]; + w_shift2 += pp[48 + 11]; + w_shift3 += pp[48 + 12]; + w_shift3 += pp[48 + 13]; + w_shift3 += pp[48 + 14]; + w_shift3 += pp[48 + 15]; + + pp += 64; + p0 += A_hstep * 16; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { pp[0] = float2int8(p0[0] * scale0); @@ -3020,11 +7387,98 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 64; p0 += A_hstep * 16; } - } -#endif // __AVX512F__ - if (elempack == 8) - { - int kk = 0; +#endif // __AVX512VNNI__ + } +#endif // __AVX512F__ + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[8] * scale1); + pp[5] = float2int8(p0[9] * scale1); + pp[6] = float2int8(p0[10] * scale1); + pp[7] = float2int8(p0[11] * scale1); + pp[8] = float2int8(p0[16] * scale2); + pp[9] = float2int8(p0[17] * scale2); + pp[10] = float2int8(p0[18] * scale2); + pp[11] = float2int8(p0[19] * scale2); + pp[12] = float2int8(p0[24] * scale3); + pp[13] = float2int8(p0[25] * scale3); + pp[14] = float2int8(p0[26] * scale3); + pp[15] = float2int8(p0[27] * scale3); + + pp[16 + 0] = float2int8(p0[4] * scale0); + pp[16 + 1] = float2int8(p0[5] * scale0); + pp[16 + 2] = float2int8(p0[6] * scale0); + pp[16 + 3] = float2int8(p0[7] * scale0); + pp[16 + 4] = float2int8(p0[12] * scale1); + pp[16 + 5] = float2int8(p0[13] * scale1); + pp[16 + 6] = float2int8(p0[14] * scale1); + pp[16 + 7] = float2int8(p0[15] * scale1); + pp[16 + 8] = float2int8(p0[20] * scale2); + pp[16 + 9] = float2int8(p0[21] * scale2); + pp[16 + 10] = float2int8(p0[22] * scale2); + pp[16 + 11] = float2int8(p0[23] * scale2); + pp[16 + 12] = float2int8(p0[28] * scale3); + pp[16 + 13] = float2int8(p0[29] * scale3); + pp[16 + 14] = float2int8(p0[30] * scale3); + pp[16 + 15] = float2int8(p0[31] * scale3); + + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + + w_shift0 += pp[16 + 0]; + w_shift0 += pp[16 + 1]; + w_shift0 += pp[16 + 2]; + w_shift0 += pp[16 + 3]; + w_shift1 += pp[16 + 4]; + w_shift1 += pp[16 + 5]; + w_shift1 += pp[16 + 6]; + w_shift1 += pp[16 + 7]; + w_shift2 += pp[16 + 8]; + w_shift2 += pp[16 + 9]; + w_shift2 += pp[16 + 10]; + w_shift2 += pp[16 + 11]; + w_shift3 += pp[16 + 12]; + w_shift3 += pp[16 + 13]; + w_shift3 += pp[16 + 14]; + w_shift3 += pp[16 + 15]; + pp += 32; + p0 += A_hstep * 8; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { pp[0] = float2int8(p0[0] * scale0); @@ -3066,11 +7520,63 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 32; p0 += A_hstep * 8; } +#endif // __AVX512VNNI__ } #endif // __AVX__ if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[4] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[6] * scale1); + pp[7] = float2int8(p0[7] * scale1); + pp[8] = float2int8(p0[8] * scale2); + pp[9] = float2int8(p0[9] * scale2); + pp[10] = float2int8(p0[10] * scale2); + pp[11] = float2int8(p0[11] * scale2); + pp[12] = float2int8(p0[12] * scale3); + pp[13] = float2int8(p0[13] * scale3); + pp[14] = float2int8(p0[14] * scale3); + pp[15] = float2int8(p0[15] * scale3); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + pp += 16; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); @@ -3093,10 +7599,62 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 16; p0 += A_hstep * 4; } +#endif // __AVX512VNNI__ } if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + int w_shift2 = 0; + int w_shift3 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale0); + pp[2] = float2int8(p0[A_hstep * 2] * scale0); + pp[3] = float2int8(p0[A_hstep * 3] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep * 2 + 1] * scale1); + pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale1); + pp[8] = float2int8(p0[2] * scale2); + pp[9] = float2int8(p0[A_hstep + 2] * scale2); + pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); + pp[11] = float2int8(p0[A_hstep * 3 + 2] * scale2); + pp[12] = float2int8(p0[3] * scale3); + pp[13] = float2int8(p0[A_hstep + 3] * scale3); + pp[14] = float2int8(p0[A_hstep * 2 + 3] * scale3); + pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift2 += pp[8]; + w_shift2 += pp[9]; + w_shift2 += pp[10]; + w_shift2 += pp[11]; + w_shift3 += pp[12]; + w_shift3 += pp[13]; + w_shift3 += pp[14]; + w_shift3 += pp[15]; + pp += 16; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + ((int*)pp)[2] = w_shift2 * 127; + ((int*)pp)[3] = w_shift3 * 127; + pp += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -3123,8 +7681,6 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int } } } -#else - signed char* pp = (signed char*)AT; #endif // __SSE2__ for (; ii + 1 < max_ii; ii += 2) { @@ -3139,8 +7695,81 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; +#endif // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { +#if __AVX512VNNI__ + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[16] * scale1); + pp[5] = float2int8(p0[17] * scale1); + pp[6] = float2int8(p0[18] * scale1); + pp[7] = float2int8(p0[19] * scale1); + + pp[8] = float2int8(p0[4] * scale0); + pp[9] = float2int8(p0[5] * scale0); + pp[10] = float2int8(p0[6] * scale0); + pp[11] = float2int8(p0[7] * scale0); + pp[12] = float2int8(p0[20] * scale1); + pp[13] = float2int8(p0[21] * scale1); + pp[14] = float2int8(p0[22] * scale1); + pp[15] = float2int8(p0[23] * scale1); + + pp[16 + 0] = float2int8(p0[8] * scale0); + pp[16 + 1] = float2int8(p0[9] * scale0); + pp[16 + 2] = float2int8(p0[10] * scale0); + pp[16 + 3] = float2int8(p0[11] * scale0); + pp[16 + 4] = float2int8(p0[24] * scale1); + pp[16 + 5] = float2int8(p0[25] * scale1); + pp[16 + 6] = float2int8(p0[26] * scale1); + pp[16 + 7] = float2int8(p0[27] * scale1); + + pp[16 + 8] = float2int8(p0[12] * scale0); + pp[16 + 9] = float2int8(p0[13] * scale0); + pp[16 + 10] = float2int8(p0[14] * scale0); + pp[16 + 11] = float2int8(p0[15] * scale0); + pp[16 + 12] = float2int8(p0[28] * scale1); + pp[16 + 13] = float2int8(p0[29] * scale1); + pp[16 + 14] = float2int8(p0[30] * scale1); + pp[16 + 15] = float2int8(p0[31] * scale1); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift0 += pp[8]; + w_shift0 += pp[9]; + w_shift0 += pp[10]; + w_shift0 += pp[11]; + w_shift1 += pp[12]; + w_shift1 += pp[13]; + w_shift1 += pp[14]; + w_shift1 += pp[15]; + w_shift0 += pp[16 + 0]; + w_shift0 += pp[16 + 1]; + w_shift0 += pp[16 + 2]; + w_shift0 += pp[16 + 3]; + w_shift1 += pp[16 + 4]; + w_shift1 += pp[16 + 5]; + w_shift1 += pp[16 + 6]; + w_shift1 += pp[16 + 7]; + w_shift0 += pp[16 + 8]; + w_shift0 += pp[16 + 9]; + w_shift0 += pp[16 + 10]; + w_shift0 += pp[16 + 11]; + w_shift1 += pp[16 + 12]; + w_shift1 += pp[16 + 13]; + w_shift1 += pp[16 + 14]; + w_shift1 += pp[16 + 15]; +#else // __AVX512VNNI__ pp[0] = float2int8(p0[0] * scale0); pp[1] = float2int8(p0[1] * scale0); pp[2] = float2int8(p0[16] * scale1); @@ -3180,17 +7809,63 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp[16 + 13] = float2int8(p0[15] * scale0); pp[16 + 14] = float2int8(p0[30] * scale1); pp[16 + 15] = float2int8(p0[31] * scale1); - +#endif // __AVX512VNNI__ pp += 32; p0 += A_hstep * 16; } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ } #endif // __AVX512F__ if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; +#endif // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { +#if __AVX512VNNI__ + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[8] * scale1); + pp[5] = float2int8(p0[9] * scale1); + pp[6] = float2int8(p0[10] * scale1); + pp[7] = float2int8(p0[11] * scale1); + pp[8] = float2int8(p0[4] * scale0); + pp[9] = float2int8(p0[5] * scale0); + pp[10] = float2int8(p0[6] * scale0); + pp[11] = float2int8(p0[7] * scale0); + pp[12] = float2int8(p0[12] * scale1); + pp[13] = float2int8(p0[13] * scale1); + pp[14] = float2int8(p0[14] * scale1); + pp[15] = float2int8(p0[15] * scale1); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + w_shift0 += pp[8]; + w_shift0 += pp[9]; + w_shift0 += pp[10]; + w_shift0 += pp[11]; + w_shift1 += pp[12]; + w_shift1 += pp[13]; + w_shift1 += pp[14]; + w_shift1 += pp[15]; +#else // __AVX512VNNI__ pp[0] = float2int8(p0[0] * scale0); pp[1] = float2int8(p0[1] * scale0); pp[2] = float2int8(p0[8] * scale1); @@ -3207,17 +7882,47 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp[13] = float2int8(p0[7] * scale0); pp[14] = float2int8(p0[14] * scale1); pp[15] = float2int8(p0[15] * scale1); - +#endif // __AVX512VNNI__ pp += 16; p0 += A_hstep * 8; } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ } #endif // __AVX__ if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; +#endif // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { +#if __AVX512VNNI__ + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[2] * scale0); + pp[3] = float2int8(p0[3] * scale0); + pp[4] = float2int8(p0[4] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[6] * scale1); + pp[7] = float2int8(p0[7] * scale1); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; +#else // __AVX512VNNI__ pp[0] = float2int8(p0[0] * scale0); pp[1] = float2int8(p0[1] * scale0); pp[2] = float2int8(p0[4] * scale1); @@ -3226,10 +7931,18 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp[5] = float2int8(p0[3] * scale0); pp[6] = float2int8(p0[6] * scale1); pp[7] = float2int8(p0[7] * scale1); - +#endif // __AVX512VNNI__ pp += 8; p0 += A_hstep * 4; } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ } #endif // __SSE2__ if (elempack == 1) @@ -3238,6 +7951,37 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + int w_shift0 = 0; + int w_shift1 = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep + 0] * scale0); + pp[2] = float2int8(p0[A_hstep * 2 + 0] * scale0); + pp[3] = float2int8(p0[A_hstep * 3 + 0] * scale0); + pp[4] = float2int8(p0[1] * scale1); + pp[5] = float2int8(p0[A_hstep + 1] * scale1); + pp[6] = float2int8(p0[A_hstep * 2 + 1] * scale1); + pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale1); + w_shift0 += pp[0]; + w_shift0 += pp[1]; + w_shift0 += pp[2]; + w_shift0 += pp[3]; + w_shift1 += pp[4]; + w_shift1 += pp[5]; + w_shift1 += pp[6]; + w_shift1 += pp[7]; + pp += 8; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift0 * 127; + ((int*)pp)[1] = w_shift1 * 127; + pp += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -3263,12 +8007,20 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int const float scale = scales[i + ii]; + // if (max_kk == 32) + // { + // NCNN_LOGE("===== %p %d %f", p0, p0[0], scale); + // } + #if __SSE2__ #if __AVX__ #if __AVX512F__ if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + int w_shift = 0; +#endif // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { pp[0] = float2int8(p0[0] * scale); @@ -3287,14 +8039,43 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp[13] = float2int8(p0[13] * scale); pp[14] = float2int8(p0[14] * scale); pp[15] = float2int8(p0[15] * scale); + +#if __AVX512VNNI__ + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + w_shift += pp[4]; + w_shift += pp[5]; + w_shift += pp[6]; + w_shift += pp[7]; + w_shift += pp[8]; + w_shift += pp[9]; + w_shift += pp[10]; + w_shift += pp[11]; + w_shift += pp[12]; + w_shift += pp[13]; + w_shift += pp[14]; + w_shift += pp[15]; +#endif // __AVX512VNNI__ pp += 16; p0 += A_hstep * 16; } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ } #endif // __AVX512F__ if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + int w_shift = 0; +#endif // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { pp[0] = float2int8(p0[0] * scale); @@ -3305,28 +8086,82 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp[5] = float2int8(p0[5] * scale); pp[6] = float2int8(p0[6] * scale); pp[7] = float2int8(p0[7] * scale); +#if __AVX512VNNI__ + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + w_shift += pp[4]; + w_shift += pp[5]; + w_shift += pp[6]; + w_shift += pp[7]; +#endif // __AVX512VNNI__ pp += 8; p0 += A_hstep * 8; } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ } #endif // __AVX__ if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + int w_shift = 0; +#endif // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale); pp[1] = float2int8(p0[1] * scale); pp[2] = float2int8(p0[2] * scale); pp[3] = float2int8(p0[3] * scale); +#if __AVX512VNNI__ + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; +#endif // __AVX512VNNI__ pp += 4; p0 += A_hstep * 4; } +#if __AVX512VNNI__ + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ } #endif // __SSE2__ if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + int w_shift = 0; + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[A_hstep] * scale); + pp[2] = float2int8(p0[A_hstep * 2] * scale); + pp[3] = float2int8(p0[A_hstep * 3] * scale); + w_shift += pp[0]; + w_shift += pp[1]; + w_shift += pp[2]; + w_shift += pp[3]; + pp += 4; + p0 += A_hstep * 4; + } + if (max_kk >= 4) + { + ((int*)pp)[0] = w_shift * 127; + pp += 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -3339,7 +8174,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int static void compute_B_fp32_int8_scale(const Mat& B, float& scale) { - NCNN_LOGE("compute_B_fp32_int8_scale"); + // NCNN_LOGE("compute_B_fp32_int8_scale"); float absmax = 0.f; #if __SSE2__ @@ -3404,10 +8239,26 @@ static void compute_B_fp32_int8_scale(const Mat& B, float& scale) static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + pack_B_tile_fp32_to_int8_avx512vnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + pack_B_tile_fp32_to_int8_avxvnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + const int elempack = B.elempack; const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; - NCNN_LOGE("pack_B_tile_fp32_to_int8 %d %d %d", max_jj, max_kk, elempack); + // NCNN_LOGE("pack_B_tile_fp32_to_int8 %d %d %d", max_jj, max_kk, elempack); signed char* pp = BT; @@ -3422,6 +8273,77 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[16] * scale) + 127; + pp[2] = float2int8(p0[32] * scale) + 127; + pp[3] = float2int8(p0[48] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[17] * scale) + 127; + pp[6] = float2int8(p0[33] * scale) + 127; + pp[7] = float2int8(p0[49] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[18] * scale) + 127; + pp[10] = float2int8(p0[34] * scale) + 127; + pp[11] = float2int8(p0[50] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[19] * scale) + 127; + pp[14] = float2int8(p0[35] * scale) + 127; + pp[15] = float2int8(p0[51] * scale) + 127; + pp[16] = float2int8(p0[4] * scale) + 127; + pp[17] = float2int8(p0[20] * scale) + 127; + pp[18] = float2int8(p0[36] * scale) + 127; + pp[19] = float2int8(p0[52] * scale) + 127; + pp[20] = float2int8(p0[5] * scale) + 127; + pp[21] = float2int8(p0[21] * scale) + 127; + pp[22] = float2int8(p0[37] * scale) + 127; + pp[23] = float2int8(p0[53] * scale) + 127; + pp[24] = float2int8(p0[6] * scale) + 127; + pp[25] = float2int8(p0[22] * scale) + 127; + pp[26] = float2int8(p0[38] * scale) + 127; + pp[27] = float2int8(p0[54] * scale) + 127; + pp[28] = float2int8(p0[7] * scale) + 127; + pp[29] = float2int8(p0[23] * scale) + 127; + pp[30] = float2int8(p0[39] * scale) + 127; + pp[31] = float2int8(p0[55] * scale) + 127; + pp[32] = float2int8(p0[8] * scale) + 127; + pp[33] = float2int8(p0[24] * scale) + 127; + pp[34] = float2int8(p0[40] * scale) + 127; + pp[35] = float2int8(p0[56] * scale) + 127; + pp[36] = float2int8(p0[9] * scale) + 127; + pp[37] = float2int8(p0[25] * scale) + 127; + pp[38] = float2int8(p0[41] * scale) + 127; + pp[39] = float2int8(p0[57] * scale) + 127; + pp[40] = float2int8(p0[10] * scale) + 127; + pp[41] = float2int8(p0[26] * scale) + 127; + pp[42] = float2int8(p0[42] * scale) + 127; + pp[43] = float2int8(p0[58] * scale) + 127; + pp[44] = float2int8(p0[11] * scale) + 127; + pp[45] = float2int8(p0[27] * scale) + 127; + pp[46] = float2int8(p0[43] * scale) + 127; + pp[47] = float2int8(p0[59] * scale) + 127; + pp[48] = float2int8(p0[12] * scale) + 127; + pp[49] = float2int8(p0[28] * scale) + 127; + pp[50] = float2int8(p0[44] * scale) + 127; + pp[51] = float2int8(p0[60] * scale) + 127; + pp[52] = float2int8(p0[13] * scale) + 127; + pp[53] = float2int8(p0[29] * scale) + 127; + pp[54] = float2int8(p0[45] * scale) + 127; + pp[55] = float2int8(p0[61] * scale) + 127; + pp[56] = float2int8(p0[14] * scale) + 127; + pp[57] = float2int8(p0[30] * scale) + 127; + pp[58] = float2int8(p0[46] * scale) + 127; + pp[59] = float2int8(p0[62] * scale) + 127; + pp[60] = float2int8(p0[15] * scale) + 127; + pp[61] = float2int8(p0[31] * scale) + 127; + pp[62] = float2int8(p0[47] * scale) + 127; + pp[63] = float2int8(p0[63] * scale) + 127; + pp += 64; + p0 += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3484,6 +8406,78 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[8] * scale) + 127; + pp[2] = float2int8(p0[16] * scale) + 127; + pp[3] = float2int8(p0[24] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[9] * scale) + 127; + pp[6] = float2int8(p0[17] * scale) + 127; + pp[7] = float2int8(p0[25] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[10] * scale) + 127; + pp[10] = float2int8(p0[18] * scale) + 127; + pp[11] = float2int8(p0[26] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[11] * scale) + 127; + pp[14] = float2int8(p0[19] * scale) + 127; + pp[15] = float2int8(p0[27] * scale) + 127; + pp[16] = float2int8(p0[4] * scale) + 127; + pp[17] = float2int8(p0[12] * scale) + 127; + pp[18] = float2int8(p0[20] * scale) + 127; + pp[19] = float2int8(p0[28] * scale) + 127; + pp[20] = float2int8(p0[5] * scale) + 127; + pp[21] = float2int8(p0[13] * scale) + 127; + pp[22] = float2int8(p0[21] * scale) + 127; + pp[23] = float2int8(p0[29] * scale) + 127; + pp[24] = float2int8(p0[6] * scale) + 127; + pp[25] = float2int8(p0[14] * scale) + 127; + pp[26] = float2int8(p0[22] * scale) + 127; + pp[27] = float2int8(p0[30] * scale) + 127; + pp[28] = float2int8(p0[7] * scale) + 127; + pp[29] = float2int8(p0[15] * scale) + 127; + pp[30] = float2int8(p0[23] * scale) + 127; + pp[31] = float2int8(p0[31] * scale) + 127; + + pp[32 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale) + 127; + pp[32 + 1] = float2int8(p0[B_hstep * 8 + 8] * scale) + 127; + pp[32 + 2] = float2int8(p0[B_hstep * 8 + 16] * scale) + 127; + pp[32 + 3] = float2int8(p0[B_hstep * 8 + 24] * scale) + 127; + pp[32 + 4] = float2int8(p0[B_hstep * 8 + 1] * scale) + 127; + pp[32 + 5] = float2int8(p0[B_hstep * 8 + 9] * scale) + 127; + pp[32 + 6] = float2int8(p0[B_hstep * 8 + 17] * scale) + 127; + pp[32 + 7] = float2int8(p0[B_hstep * 8 + 25] * scale) + 127; + pp[32 + 8] = float2int8(p0[B_hstep * 8 + 2] * scale) + 127; + pp[32 + 9] = float2int8(p0[B_hstep * 8 + 10] * scale) + 127; + pp[32 + 10] = float2int8(p0[B_hstep * 8 + 18] * scale) + 127; + pp[32 + 11] = float2int8(p0[B_hstep * 8 + 26] * scale) + 127; + pp[32 + 12] = float2int8(p0[B_hstep * 8 + 3] * scale) + 127; + pp[32 + 13] = float2int8(p0[B_hstep * 8 + 11] * scale) + 127; + pp[32 + 14] = float2int8(p0[B_hstep * 8 + 19] * scale) + 127; + pp[32 + 15] = float2int8(p0[B_hstep * 8 + 27] * scale) + 127; + pp[32 + 16] = float2int8(p0[B_hstep * 8 + 4] * scale) + 127; + pp[32 + 17] = float2int8(p0[B_hstep * 8 + 12] * scale) + 127; + pp[32 + 18] = float2int8(p0[B_hstep * 8 + 20] * scale) + 127; + pp[32 + 19] = float2int8(p0[B_hstep * 8 + 28] * scale) + 127; + pp[32 + 20] = float2int8(p0[B_hstep * 8 + 5] * scale) + 127; + pp[32 + 21] = float2int8(p0[B_hstep * 8 + 13] * scale) + 127; + pp[32 + 22] = float2int8(p0[B_hstep * 8 + 21] * scale) + 127; + pp[32 + 23] = float2int8(p0[B_hstep * 8 + 29] * scale) + 127; + pp[32 + 24] = float2int8(p0[B_hstep * 8 + 6] * scale) + 127; + pp[32 + 25] = float2int8(p0[B_hstep * 8 + 14] * scale) + 127; + pp[32 + 26] = float2int8(p0[B_hstep * 8 + 22] * scale) + 127; + pp[32 + 27] = float2int8(p0[B_hstep * 8 + 30] * scale) + 127; + pp[32 + 28] = float2int8(p0[B_hstep * 8 + 7] * scale) + 127; + pp[32 + 29] = float2int8(p0[B_hstep * 8 + 15] * scale) + 127; + pp[32 + 30] = float2int8(p0[B_hstep * 8 + 23] * scale) + 127; + pp[32 + 31] = float2int8(p0[B_hstep * 8 + 31] * scale) + 127; + pp += 64; + p0 += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3547,6 +8541,80 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[4] * scale) + 127; + pp[2] = float2int8(p0[8] * scale) + 127; + pp[3] = float2int8(p0[12] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[5] * scale) + 127; + pp[6] = float2int8(p0[9] * scale) + 127; + pp[7] = float2int8(p0[13] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[6] * scale) + 127; + pp[10] = float2int8(p0[10] * scale) + 127; + pp[11] = float2int8(p0[14] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[7] * scale) + 127; + pp[14] = float2int8(p0[11] * scale) + 127; + pp[15] = float2int8(p0[15] * scale) + 127; + pp[16 + 0] = float2int8(p0[B_hstep * 4 + 0] * scale) + 127; + pp[16 + 1] = float2int8(p0[B_hstep * 4 + 4] * scale) + 127; + pp[16 + 2] = float2int8(p0[B_hstep * 4 + 8] * scale) + 127; + pp[16 + 3] = float2int8(p0[B_hstep * 4 + 12] * scale) + 127; + pp[16 + 4] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; + pp[16 + 5] = float2int8(p0[B_hstep * 4 + 5] * scale) + 127; + pp[16 + 6] = float2int8(p0[B_hstep * 4 + 9] * scale) + 127; + pp[16 + 7] = float2int8(p0[B_hstep * 4 + 13] * scale) + 127; + pp[16 + 8] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; + pp[16 + 9] = float2int8(p0[B_hstep * 4 + 6] * scale) + 127; + pp[16 + 10] = float2int8(p0[B_hstep * 4 + 10] * scale) + 127; + pp[16 + 11] = float2int8(p0[B_hstep * 4 + 14] * scale) + 127; + pp[16 + 12] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; + pp[16 + 13] = float2int8(p0[B_hstep * 4 + 7] * scale) + 127; + pp[16 + 14] = float2int8(p0[B_hstep * 4 + 11] * scale) + 127; + pp[16 + 15] = float2int8(p0[B_hstep * 4 + 15] * scale) + 127; + + pp[32 + 0] = float2int8(p0[B_hstep * 8 + 0] * scale) + 127; + pp[32 + 1] = float2int8(p0[B_hstep * 8 + 4] * scale) + 127; + pp[32 + 2] = float2int8(p0[B_hstep * 8 + 8] * scale) + 127; + pp[32 + 3] = float2int8(p0[B_hstep * 8 + 12] * scale) + 127; + pp[32 + 4] = float2int8(p0[B_hstep * 8 + 1] * scale) + 127; + pp[32 + 5] = float2int8(p0[B_hstep * 8 + 5] * scale) + 127; + pp[32 + 6] = float2int8(p0[B_hstep * 8 + 9] * scale) + 127; + pp[32 + 7] = float2int8(p0[B_hstep * 8 + 13] * scale) + 127; + pp[32 + 8] = float2int8(p0[B_hstep * 8 + 2] * scale) + 127; + pp[32 + 9] = float2int8(p0[B_hstep * 8 + 6] * scale) + 127; + pp[32 + 10] = float2int8(p0[B_hstep * 8 + 10] * scale) + 127; + pp[32 + 11] = float2int8(p0[B_hstep * 8 + 14] * scale) + 127; + pp[32 + 12] = float2int8(p0[B_hstep * 8 + 3] * scale) + 127; + pp[32 + 13] = float2int8(p0[B_hstep * 8 + 7] * scale) + 127; + pp[32 + 14] = float2int8(p0[B_hstep * 8 + 11] * scale) + 127; + pp[32 + 15] = float2int8(p0[B_hstep * 8 + 15] * scale) + 127; + + pp[48 + 0] = float2int8(p0[B_hstep * 12 + 0] * scale) + 127; + pp[48 + 1] = float2int8(p0[B_hstep * 12 + 4] * scale) + 127; + pp[48 + 2] = float2int8(p0[B_hstep * 12 + 8] * scale) + 127; + pp[48 + 3] = float2int8(p0[B_hstep * 12 + 12] * scale) + 127; + pp[48 + 4] = float2int8(p0[B_hstep * 12 + 1] * scale) + 127; + pp[48 + 5] = float2int8(p0[B_hstep * 12 + 5] * scale) + 127; + pp[48 + 6] = float2int8(p0[B_hstep * 12 + 9] * scale) + 127; + pp[48 + 7] = float2int8(p0[B_hstep * 12 + 13] * scale) + 127; + pp[48 + 8] = float2int8(p0[B_hstep * 12 + 2] * scale) + 127; + pp[48 + 9] = float2int8(p0[B_hstep * 12 + 6] * scale) + 127; + pp[48 + 10] = float2int8(p0[B_hstep * 12 + 10] * scale) + 127; + pp[48 + 11] = float2int8(p0[B_hstep * 12 + 14] * scale) + 127; + pp[48 + 12] = float2int8(p0[B_hstep * 12 + 3] * scale) + 127; + pp[48 + 13] = float2int8(p0[B_hstep * 12 + 7] * scale) + 127; + pp[48 + 14] = float2int8(p0[B_hstep * 12 + 11] * scale) + 127; + pp[48 + 15] = float2int8(p0[B_hstep * 12 + 15] * scale) + 127; + + pp += 64; + p0 += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3612,6 +8680,79 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[B_hstep] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep + 2] * scale) + 127; + pp[7] = float2int8(p0[B_hstep + 3] * scale) + 127; + pp[8] = float2int8(p0[B_hstep * 2] * scale) + 127; + pp[9] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; + pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; + pp[11] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; + pp[12] = float2int8(p0[B_hstep * 3] * scale) + 127; + pp[13] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; + pp[14] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; + pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; + pp[16] = float2int8(p0[B_hstep * 4] * scale) + 127; + pp[17] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; + pp[18] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; + pp[19] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; + pp[20] = float2int8(p0[B_hstep * 5] * scale) + 127; + pp[21] = float2int8(p0[B_hstep * 5 + 1] * scale) + 127; + pp[22] = float2int8(p0[B_hstep * 5 + 2] * scale) + 127; + pp[23] = float2int8(p0[B_hstep * 5 + 3] * scale) + 127; + pp[24] = float2int8(p0[B_hstep * 6] * scale) + 127; + pp[25] = float2int8(p0[B_hstep * 6 + 1] * scale) + 127; + pp[26] = float2int8(p0[B_hstep * 6 + 2] * scale) + 127; + pp[27] = float2int8(p0[B_hstep * 6 + 3] * scale) + 127; + pp[28] = float2int8(p0[B_hstep * 7] * scale) + 127; + pp[29] = float2int8(p0[B_hstep * 7 + 1] * scale) + 127; + pp[30] = float2int8(p0[B_hstep * 7 + 2] * scale) + 127; + pp[31] = float2int8(p0[B_hstep * 7 + 3] * scale) + 127; + + pp[32 + 0] = float2int8(p0[B_hstep * 8] * scale) + 127; + pp[32 + 1] = float2int8(p0[B_hstep * 8 + 1] * scale) + 127; + pp[32 + 2] = float2int8(p0[B_hstep * 8 + 2] * scale) + 127; + pp[32 + 3] = float2int8(p0[B_hstep * 8 + 3] * scale) + 127; + pp[32 + 4] = float2int8(p0[B_hstep * 9] * scale) + 127; + pp[32 + 5] = float2int8(p0[B_hstep * 9 + 1] * scale) + 127; + pp[32 + 6] = float2int8(p0[B_hstep * 9 + 2] * scale) + 127; + pp[32 + 7] = float2int8(p0[B_hstep * 9 + 3] * scale) + 127; + pp[32 + 8] = float2int8(p0[B_hstep * 10] * scale) + 127; + pp[32 + 9] = float2int8(p0[B_hstep * 10 + 1] * scale) + 127; + pp[32 + 10] = float2int8(p0[B_hstep * 10 + 2] * scale) + 127; + pp[32 + 11] = float2int8(p0[B_hstep * 10 + 3] * scale) + 127; + pp[32 + 12] = float2int8(p0[B_hstep * 11] * scale) + 127; + pp[32 + 13] = float2int8(p0[B_hstep * 11 + 1] * scale) + 127; + pp[32 + 14] = float2int8(p0[B_hstep * 11 + 2] * scale) + 127; + pp[32 + 15] = float2int8(p0[B_hstep * 11 + 3] * scale) + 127; + pp[32 + 16] = float2int8(p0[B_hstep * 12] * scale) + 127; + pp[32 + 17] = float2int8(p0[B_hstep * 12 + 1] * scale) + 127; + pp[32 + 18] = float2int8(p0[B_hstep * 12 + 2] * scale) + 127; + pp[32 + 19] = float2int8(p0[B_hstep * 12 + 3] * scale) + 127; + pp[32 + 20] = float2int8(p0[B_hstep * 13] * scale) + 127; + pp[32 + 21] = float2int8(p0[B_hstep * 13 + 1] * scale) + 127; + pp[32 + 22] = float2int8(p0[B_hstep * 13 + 2] * scale) + 127; + pp[32 + 23] = float2int8(p0[B_hstep * 13 + 3] * scale) + 127; + pp[32 + 24] = float2int8(p0[B_hstep * 14] * scale) + 127; + pp[32 + 25] = float2int8(p0[B_hstep * 14 + 1] * scale) + 127; + pp[32 + 26] = float2int8(p0[B_hstep * 14 + 2] * scale) + 127; + pp[32 + 27] = float2int8(p0[B_hstep * 14 + 3] * scale) + 127; + pp[32 + 28] = float2int8(p0[B_hstep * 15] * scale) + 127; + pp[32 + 29] = float2int8(p0[B_hstep * 15 + 1] * scale) + 127; + pp[32 + 30] = float2int8(p0[B_hstep * 15 + 2] * scale) + 127; + pp[32 + 31] = float2int8(p0[B_hstep * 15 + 3] * scale) + 127; + + pp += 64; + p0 += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3682,6 +8823,45 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[8] * scale) + 127; + pp[2] = float2int8(p0[16] * scale) + 127; + pp[3] = float2int8(p0[24] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[9] * scale) + 127; + pp[6] = float2int8(p0[17] * scale) + 127; + pp[7] = float2int8(p0[25] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[10] * scale) + 127; + pp[10] = float2int8(p0[18] * scale) + 127; + pp[11] = float2int8(p0[26] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[11] * scale) + 127; + pp[14] = float2int8(p0[19] * scale) + 127; + pp[15] = float2int8(p0[27] * scale) + 127; + pp[16 + 0] = float2int8(p0[4] * scale) + 127; + pp[16 + 1] = float2int8(p0[12] * scale) + 127; + pp[16 + 2] = float2int8(p0[20] * scale) + 127; + pp[16 + 3] = float2int8(p0[28] * scale) + 127; + pp[16 + 4] = float2int8(p0[5] * scale) + 127; + pp[16 + 5] = float2int8(p0[13] * scale) + 127; + pp[16 + 6] = float2int8(p0[21] * scale) + 127; + pp[16 + 7] = float2int8(p0[29] * scale) + 127; + pp[16 + 8] = float2int8(p0[6] * scale) + 127; + pp[16 + 9] = float2int8(p0[14] * scale) + 127; + pp[16 + 10] = float2int8(p0[22] * scale) + 127; + pp[16 + 11] = float2int8(p0[30] * scale) + 127; + pp[16 + 12] = float2int8(p0[7] * scale) + 127; + pp[16 + 13] = float2int8(p0[15] * scale) + 127; + pp[16 + 14] = float2int8(p0[23] * scale) + 127; + pp[16 + 15] = float2int8(p0[31] * scale) + 127; + pp += 32; + p0 += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3723,6 +8903,45 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[4] * scale) + 127; + pp[2] = float2int8(p0[8] * scale) + 127; + pp[3] = float2int8(p0[12] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[5] * scale) + 127; + pp[6] = float2int8(p0[9] * scale) + 127; + pp[7] = float2int8(p0[13] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[6] * scale) + 127; + pp[10] = float2int8(p0[10] * scale) + 127; + pp[11] = float2int8(p0[14] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[7] * scale) + 127; + pp[14] = float2int8(p0[11] * scale) + 127; + pp[15] = float2int8(p0[15] * scale) + 127; + pp[16] = float2int8(p0[B_hstep * 4 + 0] * scale) + 127; + pp[17] = float2int8(p0[B_hstep * 4 + 4] * scale) + 127; + pp[18] = float2int8(p0[B_hstep * 4 + 8] * scale) + 127; + pp[19] = float2int8(p0[B_hstep * 4 + 12] * scale) + 127; + pp[20] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; + pp[21] = float2int8(p0[B_hstep * 4 + 5] * scale) + 127; + pp[22] = float2int8(p0[B_hstep * 4 + 9] * scale) + 127; + pp[23] = float2int8(p0[B_hstep * 4 + 13] * scale) + 127; + pp[24] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; + pp[25] = float2int8(p0[B_hstep * 4 + 6] * scale) + 127; + pp[26] = float2int8(p0[B_hstep * 4 + 10] * scale) + 127; + pp[27] = float2int8(p0[B_hstep * 4 + 14] * scale) + 127; + pp[28] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; + pp[29] = float2int8(p0[B_hstep * 4 + 7] * scale) + 127; + pp[30] = float2int8(p0[B_hstep * 4 + 11] * scale) + 127; + pp[31] = float2int8(p0[B_hstep * 4 + 15] * scale) + 127; + pp += 32; + p0 += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3763,6 +8982,45 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[B_hstep] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep + 2] * scale) + 127; + pp[7] = float2int8(p0[B_hstep + 3] * scale) + 127; + pp[8] = float2int8(p0[B_hstep * 2] * scale) + 127; + pp[9] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; + pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; + pp[11] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; + pp[12] = float2int8(p0[B_hstep * 3] * scale) + 127; + pp[13] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; + pp[14] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; + pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; + pp[16] = float2int8(p0[B_hstep * 4] * scale) + 127; + pp[17] = float2int8(p0[B_hstep * 4 + 1] * scale) + 127; + pp[18] = float2int8(p0[B_hstep * 4 + 2] * scale) + 127; + pp[19] = float2int8(p0[B_hstep * 4 + 3] * scale) + 127; + pp[20] = float2int8(p0[B_hstep * 5] * scale) + 127; + pp[21] = float2int8(p0[B_hstep * 5 + 1] * scale) + 127; + pp[22] = float2int8(p0[B_hstep * 5 + 2] * scale) + 127; + pp[23] = float2int8(p0[B_hstep * 5 + 3] * scale) + 127; + pp[24] = float2int8(p0[B_hstep * 6] * scale) + 127; + pp[25] = float2int8(p0[B_hstep * 6 + 1] * scale) + 127; + pp[26] = float2int8(p0[B_hstep * 6 + 2] * scale) + 127; + pp[27] = float2int8(p0[B_hstep * 6 + 3] * scale) + 127; + pp[28] = float2int8(p0[B_hstep * 7] * scale) + 127; + pp[29] = float2int8(p0[B_hstep * 7 + 1] * scale) + 127; + pp[30] = float2int8(p0[B_hstep * 7 + 2] * scale) + 127; + pp[31] = float2int8(p0[B_hstep * 7 + 3] * scale) + 127; + pp += 32; + p0 += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3809,6 +9067,30 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[4] * scale) + 127; + pp[2] = float2int8(p0[8] * scale) + 127; + pp[3] = float2int8(p0[12] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[5] * scale) + 127; + pp[6] = float2int8(p0[9] * scale) + 127; + pp[7] = float2int8(p0[13] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[6] * scale) + 127; + pp[10] = float2int8(p0[10] * scale) + 127; + pp[11] = float2int8(p0[14] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[7] * scale) + 127; + pp[14] = float2int8(p0[11] * scale) + 127; + pp[15] = float2int8(p0[15] * scale) + 127; + + pp += 16; + p0 += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3837,6 +9119,30 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[B_hstep] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep + 2] * scale) + 127; + pp[7] = float2int8(p0[B_hstep + 3] * scale) + 127; + pp[8] = float2int8(p0[B_hstep * 2] * scale) + 127; + pp[9] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; + pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; + pp[11] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; + pp[12] = float2int8(p0[B_hstep * 3] * scale) + 127; + pp[13] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; + pp[14] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; + pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; + + pp += 16; + p0 += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3872,6 +9178,21 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i { int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[B_hstep] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep + 2] * scale) + 127; + pp[7] = float2int8(p0[B_hstep + 3] * scale) + 127; + pp += 8; + p0 += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -3898,6 +9219,17 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i // if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp += 4; + p0 += 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -3910,10 +9242,26 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + transpose_pack_B_tile_fp32_to_int8_avx512vnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + transpose_pack_B_tile_fp32_to_int8_avxvnni(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + const int elempack = B.elempack; const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; - NCNN_LOGE("transpose_pack_B_tile_fp32_to_int8 %d %d", max_jj, elempack); + // NCNN_LOGE("transpose_pack_B_tile_fp32_to_int8 %d %d", max_jj, elempack); signed char* pp = BT; @@ -3928,6 +9276,277 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2 + 0] * scale) + 127; + pp[3] = float2int8(p0[2 + 1] * scale) + 127; + pp[4] = float2int8(p0[16] * scale) + 127; + pp[5] = float2int8(p0[17] * scale) + 127; + pp[6] = float2int8(p0[2 + 16] * scale) + 127; + pp[7] = float2int8(p0[2 + 17] * scale) + 127; + pp[8] = float2int8(p0[32] * scale) + 127; + pp[9] = float2int8(p0[33] * scale) + 127; + pp[10] = float2int8(p0[2 + 32] * scale) + 127; + pp[11] = float2int8(p0[2 + 33] * scale) + 127; + pp[12] = float2int8(p0[48] * scale) + 127; + pp[13] = float2int8(p0[49] * scale) + 127; + pp[14] = float2int8(p0[2 + 48] * scale) + 127; + pp[15] = float2int8(p0[2 + 49] * scale) + 127; + pp[16] = float2int8(p0[64] * scale) + 127; + pp[17] = float2int8(p0[65] * scale) + 127; + pp[18] = float2int8(p0[2 + 64] * scale) + 127; + pp[19] = float2int8(p0[2 + 65] * scale) + 127; + pp[20] = float2int8(p0[80] * scale) + 127; + pp[21] = float2int8(p0[81] * scale) + 127; + pp[22] = float2int8(p0[2 + 80] * scale) + 127; + pp[23] = float2int8(p0[2 + 81] * scale) + 127; + pp[24] = float2int8(p0[96] * scale) + 127; + pp[25] = float2int8(p0[97] * scale) + 127; + pp[26] = float2int8(p0[2 + 96] * scale) + 127; + pp[27] = float2int8(p0[2 + 97] * scale) + 127; + pp[28] = float2int8(p0[112] * scale) + 127; + pp[29] = float2int8(p0[113] * scale) + 127; + pp[30] = float2int8(p0[2 + 112] * scale) + 127; + pp[31] = float2int8(p0[2 + 113] * scale) + 127; + + pp[32 + 0] = float2int8(p0[128 + 0] * scale) + 127; + pp[32 + 1] = float2int8(p0[128 + 1] * scale) + 127; + pp[32 + 2] = float2int8(p0[2 + 128 + 0] * scale) + 127; + pp[32 + 3] = float2int8(p0[2 + 128 + 1] * scale) + 127; + pp[32 + 4] = float2int8(p0[128 + 16] * scale) + 127; + pp[32 + 5] = float2int8(p0[128 + 17] * scale) + 127; + pp[32 + 6] = float2int8(p0[2 + 128 + 16] * scale) + 127; + pp[32 + 7] = float2int8(p0[2 + 128 + 17] * scale) + 127; + pp[32 + 8] = float2int8(p0[128 + 32] * scale) + 127; + pp[32 + 9] = float2int8(p0[128 + 33] * scale) + 127; + pp[32 + 10] = float2int8(p0[2 + 128 + 32] * scale) + 127; + pp[32 + 11] = float2int8(p0[2 + 128 + 33] * scale) + 127; + pp[32 + 12] = float2int8(p0[128 + 48] * scale) + 127; + pp[32 + 13] = float2int8(p0[128 + 49] * scale) + 127; + pp[32 + 14] = float2int8(p0[2 + 128 + 48] * scale) + 127; + pp[32 + 15] = float2int8(p0[2 + 128 + 49] * scale) + 127; + pp[32 + 16] = float2int8(p0[128 + 64] * scale) + 127; + pp[32 + 17] = float2int8(p0[128 + 65] * scale) + 127; + pp[32 + 18] = float2int8(p0[2 + 128 + 64] * scale) + 127; + pp[32 + 19] = float2int8(p0[2 + 128 + 65] * scale) + 127; + pp[32 + 20] = float2int8(p0[128 + 80] * scale) + 127; + pp[32 + 21] = float2int8(p0[128 + 81] * scale) + 127; + pp[32 + 22] = float2int8(p0[2 + 128 + 80] * scale) + 127; + pp[32 + 23] = float2int8(p0[2 + 128 + 81] * scale) + 127; + pp[32 + 24] = float2int8(p0[128 + 96] * scale) + 127; + pp[32 + 25] = float2int8(p0[128 + 97] * scale) + 127; + pp[32 + 26] = float2int8(p0[2 + 128 + 96] * scale) + 127; + pp[32 + 27] = float2int8(p0[2 + 128 + 97] * scale) + 127; + pp[32 + 28] = float2int8(p0[128 + 112] * scale) + 127; + pp[32 + 29] = float2int8(p0[128 + 113] * scale) + 127; + pp[32 + 30] = float2int8(p0[2 + 128 + 112] * scale) + 127; + pp[32 + 31] = float2int8(p0[2 + 128 + 113] * scale) + 127; + + pp[64 + 0] = float2int8(p0[4 + 0] * scale) + 127; + pp[64 + 1] = float2int8(p0[4 + 1] * scale) + 127; + pp[64 + 2] = float2int8(p0[6 + 0] * scale) + 127; + pp[64 + 3] = float2int8(p0[6 + 1] * scale) + 127; + pp[64 + 4] = float2int8(p0[4 + 16] * scale) + 127; + pp[64 + 5] = float2int8(p0[4 + 17] * scale) + 127; + pp[64 + 6] = float2int8(p0[6 + 16] * scale) + 127; + pp[64 + 7] = float2int8(p0[6 + 17] * scale) + 127; + pp[64 + 8] = float2int8(p0[4 + 32] * scale) + 127; + pp[64 + 9] = float2int8(p0[4 + 33] * scale) + 127; + pp[64 + 10] = float2int8(p0[6 + 32] * scale) + 127; + pp[64 + 11] = float2int8(p0[6 + 33] * scale) + 127; + pp[64 + 12] = float2int8(p0[4 + 48] * scale) + 127; + pp[64 + 13] = float2int8(p0[4 + 49] * scale) + 127; + pp[64 + 14] = float2int8(p0[6 + 48] * scale) + 127; + pp[64 + 15] = float2int8(p0[6 + 49] * scale) + 127; + pp[64 + 16] = float2int8(p0[4 + 64] * scale) + 127; + pp[64 + 17] = float2int8(p0[4 + 65] * scale) + 127; + pp[64 + 18] = float2int8(p0[6 + 64] * scale) + 127; + pp[64 + 19] = float2int8(p0[6 + 65] * scale) + 127; + pp[64 + 20] = float2int8(p0[4 + 80] * scale) + 127; + pp[64 + 21] = float2int8(p0[4 + 81] * scale) + 127; + pp[64 + 22] = float2int8(p0[6 + 80] * scale) + 127; + pp[64 + 23] = float2int8(p0[6 + 81] * scale) + 127; + pp[64 + 24] = float2int8(p0[4 + 96] * scale) + 127; + pp[64 + 25] = float2int8(p0[4 + 97] * scale) + 127; + pp[64 + 26] = float2int8(p0[6 + 96] * scale) + 127; + pp[64 + 27] = float2int8(p0[6 + 97] * scale) + 127; + pp[64 + 28] = float2int8(p0[4 + 112] * scale) + 127; + pp[64 + 29] = float2int8(p0[4 + 113] * scale) + 127; + pp[64 + 30] = float2int8(p0[6 + 112] * scale) + 127; + pp[64 + 31] = float2int8(p0[6 + 113] * scale) + 127; + + pp[96 + 0] = float2int8(p0[4 + 128 + 0] * scale) + 127; + pp[96 + 1] = float2int8(p0[4 + 128 + 1] * scale) + 127; + pp[96 + 2] = float2int8(p0[6 + 128 + 0] * scale) + 127; + pp[96 + 3] = float2int8(p0[6 + 128 + 1] * scale) + 127; + pp[96 + 4] = float2int8(p0[4 + 128 + 16] * scale) + 127; + pp[96 + 5] = float2int8(p0[4 + 128 + 17] * scale) + 127; + pp[96 + 6] = float2int8(p0[6 + 128 + 16] * scale) + 127; + pp[96 + 7] = float2int8(p0[6 + 128 + 17] * scale) + 127; + pp[96 + 8] = float2int8(p0[4 + 128 + 32] * scale) + 127; + pp[96 + 9] = float2int8(p0[4 + 128 + 33] * scale) + 127; + pp[96 + 10] = float2int8(p0[6 + 128 + 32] * scale) + 127; + pp[96 + 11] = float2int8(p0[6 + 128 + 33] * scale) + 127; + pp[96 + 12] = float2int8(p0[4 + 128 + 48] * scale) + 127; + pp[96 + 13] = float2int8(p0[4 + 128 + 49] * scale) + 127; + pp[96 + 14] = float2int8(p0[6 + 128 + 48] * scale) + 127; + pp[96 + 15] = float2int8(p0[6 + 128 + 49] * scale) + 127; + pp[96 + 16] = float2int8(p0[4 + 128 + 64] * scale) + 127; + pp[96 + 17] = float2int8(p0[4 + 128 + 65] * scale) + 127; + pp[96 + 18] = float2int8(p0[6 + 128 + 64] * scale) + 127; + pp[96 + 19] = float2int8(p0[6 + 128 + 65] * scale) + 127; + pp[96 + 20] = float2int8(p0[4 + 128 + 80] * scale) + 127; + pp[96 + 21] = float2int8(p0[4 + 128 + 81] * scale) + 127; + pp[96 + 22] = float2int8(p0[6 + 128 + 80] * scale) + 127; + pp[96 + 23] = float2int8(p0[6 + 128 + 81] * scale) + 127; + pp[96 + 24] = float2int8(p0[4 + 128 + 96] * scale) + 127; + pp[96 + 25] = float2int8(p0[4 + 128 + 97] * scale) + 127; + pp[96 + 26] = float2int8(p0[6 + 128 + 96] * scale) + 127; + pp[96 + 27] = float2int8(p0[6 + 128 + 97] * scale) + 127; + pp[96 + 28] = float2int8(p0[4 + 128 + 112] * scale) + 127; + pp[96 + 29] = float2int8(p0[4 + 128 + 113] * scale) + 127; + pp[96 + 30] = float2int8(p0[6 + 128 + 112] * scale) + 127; + pp[96 + 31] = float2int8(p0[6 + 128 + 113] * scale) + 127; + + pp[128 + 0] = float2int8(p0[8 + 0] * scale) + 127; + pp[128 + 1] = float2int8(p0[8 + 1] * scale) + 127; + pp[128 + 2] = float2int8(p0[10 + 0] * scale) + 127; + pp[128 + 3] = float2int8(p0[10 + 1] * scale) + 127; + pp[128 + 4] = float2int8(p0[8 + 16] * scale) + 127; + pp[128 + 5] = float2int8(p0[8 + 17] * scale) + 127; + pp[128 + 6] = float2int8(p0[10 + 16] * scale) + 127; + pp[128 + 7] = float2int8(p0[10 + 17] * scale) + 127; + pp[128 + 8] = float2int8(p0[8 + 32] * scale) + 127; + pp[128 + 9] = float2int8(p0[8 + 33] * scale) + 127; + pp[128 + 10] = float2int8(p0[10 + 32] * scale) + 127; + pp[128 + 11] = float2int8(p0[10 + 33] * scale) + 127; + pp[128 + 12] = float2int8(p0[8 + 48] * scale) + 127; + pp[128 + 13] = float2int8(p0[8 + 49] * scale) + 127; + pp[128 + 14] = float2int8(p0[10 + 48] * scale) + 127; + pp[128 + 15] = float2int8(p0[10 + 49] * scale) + 127; + pp[128 + 16] = float2int8(p0[8 + 64] * scale) + 127; + pp[128 + 17] = float2int8(p0[8 + 65] * scale) + 127; + pp[128 + 18] = float2int8(p0[10 + 64] * scale) + 127; + pp[128 + 19] = float2int8(p0[10 + 65] * scale) + 127; + pp[128 + 20] = float2int8(p0[8 + 80] * scale) + 127; + pp[128 + 21] = float2int8(p0[8 + 81] * scale) + 127; + pp[128 + 22] = float2int8(p0[10 + 80] * scale) + 127; + pp[128 + 23] = float2int8(p0[10 + 81] * scale) + 127; + pp[128 + 24] = float2int8(p0[8 + 96] * scale) + 127; + pp[128 + 25] = float2int8(p0[8 + 97] * scale) + 127; + pp[128 + 26] = float2int8(p0[10 + 96] * scale) + 127; + pp[128 + 27] = float2int8(p0[10 + 97] * scale) + 127; + pp[128 + 28] = float2int8(p0[8 + 112] * scale) + 127; + pp[128 + 29] = float2int8(p0[8 + 113] * scale) + 127; + pp[128 + 30] = float2int8(p0[10 + 112] * scale) + 127; + pp[128 + 31] = float2int8(p0[10 + 113] * scale) + 127; + + pp[160 + 0] = float2int8(p0[8 + 128 + 0] * scale) + 127; + pp[160 + 1] = float2int8(p0[8 + 128 + 1] * scale) + 127; + pp[160 + 2] = float2int8(p0[10 + 128 + 0] * scale) + 127; + pp[160 + 3] = float2int8(p0[10 + 128 + 1] * scale) + 127; + pp[160 + 4] = float2int8(p0[8 + 128 + 16] * scale) + 127; + pp[160 + 5] = float2int8(p0[8 + 128 + 17] * scale) + 127; + pp[160 + 6] = float2int8(p0[10 + 128 + 16] * scale) + 127; + pp[160 + 7] = float2int8(p0[10 + 128 + 17] * scale) + 127; + pp[160 + 8] = float2int8(p0[8 + 128 + 32] * scale) + 127; + pp[160 + 9] = float2int8(p0[8 + 128 + 33] * scale) + 127; + pp[160 + 10] = float2int8(p0[10 + 128 + 32] * scale) + 127; + pp[160 + 11] = float2int8(p0[10 + 128 + 33] * scale) + 127; + pp[160 + 12] = float2int8(p0[8 + 128 + 48] * scale) + 127; + pp[160 + 13] = float2int8(p0[8 + 128 + 49] * scale) + 127; + pp[160 + 14] = float2int8(p0[10 + 128 + 48] * scale) + 127; + pp[160 + 15] = float2int8(p0[10 + 128 + 49] * scale) + 127; + pp[160 + 16] = float2int8(p0[8 + 128 + 64] * scale) + 127; + pp[160 + 17] = float2int8(p0[8 + 128 + 65] * scale) + 127; + pp[160 + 18] = float2int8(p0[10 + 128 + 64] * scale) + 127; + pp[160 + 19] = float2int8(p0[10 + 128 + 65] * scale) + 127; + pp[160 + 20] = float2int8(p0[8 + 128 + 80] * scale) + 127; + pp[160 + 21] = float2int8(p0[8 + 128 + 81] * scale) + 127; + pp[160 + 22] = float2int8(p0[10 + 128 + 80] * scale) + 127; + pp[160 + 23] = float2int8(p0[10 + 128 + 81] * scale) + 127; + pp[160 + 24] = float2int8(p0[8 + 128 + 96] * scale) + 127; + pp[160 + 25] = float2int8(p0[8 + 128 + 97] * scale) + 127; + pp[160 + 26] = float2int8(p0[10 + 128 + 96] * scale) + 127; + pp[160 + 27] = float2int8(p0[10 + 128 + 97] * scale) + 127; + pp[160 + 28] = float2int8(p0[8 + 128 + 112] * scale) + 127; + pp[160 + 29] = float2int8(p0[8 + 128 + 113] * scale) + 127; + pp[160 + 30] = float2int8(p0[10 + 128 + 112] * scale) + 127; + pp[160 + 31] = float2int8(p0[10 + 128 + 113] * scale) + 127; + + pp[192 + 0] = float2int8(p0[12 + 0] * scale) + 127; + pp[192 + 1] = float2int8(p0[12 + 1] * scale) + 127; + pp[192 + 2] = float2int8(p0[14 + 0] * scale) + 127; + pp[192 + 3] = float2int8(p0[14 + 1] * scale) + 127; + pp[192 + 4] = float2int8(p0[12 + 16] * scale) + 127; + pp[192 + 5] = float2int8(p0[12 + 17] * scale) + 127; + pp[192 + 6] = float2int8(p0[14 + 16] * scale) + 127; + pp[192 + 7] = float2int8(p0[14 + 17] * scale) + 127; + pp[192 + 8] = float2int8(p0[12 + 32] * scale) + 127; + pp[192 + 9] = float2int8(p0[12 + 33] * scale) + 127; + pp[192 + 10] = float2int8(p0[14 + 32] * scale) + 127; + pp[192 + 11] = float2int8(p0[14 + 33] * scale) + 127; + pp[192 + 12] = float2int8(p0[12 + 48] * scale) + 127; + pp[192 + 13] = float2int8(p0[12 + 49] * scale) + 127; + pp[192 + 14] = float2int8(p0[14 + 48] * scale) + 127; + pp[192 + 15] = float2int8(p0[14 + 49] * scale) + 127; + pp[192 + 16] = float2int8(p0[12 + 64] * scale) + 127; + pp[192 + 17] = float2int8(p0[12 + 65] * scale) + 127; + pp[192 + 18] = float2int8(p0[14 + 64] * scale) + 127; + pp[192 + 19] = float2int8(p0[14 + 65] * scale) + 127; + pp[192 + 20] = float2int8(p0[12 + 80] * scale) + 127; + pp[192 + 21] = float2int8(p0[12 + 81] * scale) + 127; + pp[192 + 22] = float2int8(p0[14 + 80] * scale) + 127; + pp[192 + 23] = float2int8(p0[14 + 81] * scale) + 127; + pp[192 + 24] = float2int8(p0[12 + 96] * scale) + 127; + pp[192 + 25] = float2int8(p0[12 + 97] * scale) + 127; + pp[192 + 26] = float2int8(p0[14 + 96] * scale) + 127; + pp[192 + 27] = float2int8(p0[14 + 97] * scale) + 127; + pp[192 + 28] = float2int8(p0[12 + 112] * scale) + 127; + pp[192 + 29] = float2int8(p0[12 + 113] * scale) + 127; + pp[192 + 30] = float2int8(p0[14 + 112] * scale) + 127; + pp[192 + 31] = float2int8(p0[14 + 113] * scale) + 127; + + pp[224 + 0] = float2int8(p0[12 + 128 + 0] * scale) + 127; + pp[224 + 1] = float2int8(p0[12 + 128 + 1] * scale) + 127; + pp[224 + 2] = float2int8(p0[14 + 128 + 0] * scale) + 127; + pp[224 + 3] = float2int8(p0[14 + 128 + 1] * scale) + 127; + pp[224 + 4] = float2int8(p0[12 + 128 + 16] * scale) + 127; + pp[224 + 5] = float2int8(p0[12 + 128 + 17] * scale) + 127; + pp[224 + 6] = float2int8(p0[14 + 128 + 16] * scale) + 127; + pp[224 + 7] = float2int8(p0[14 + 128 + 17] * scale) + 127; + pp[224 + 8] = float2int8(p0[12 + 128 + 32] * scale) + 127; + pp[224 + 9] = float2int8(p0[12 + 128 + 33] * scale) + 127; + pp[224 + 10] = float2int8(p0[14 + 128 + 32] * scale) + 127; + pp[224 + 11] = float2int8(p0[14 + 128 + 33] * scale) + 127; + pp[224 + 12] = float2int8(p0[12 + 128 + 48] * scale) + 127; + pp[224 + 13] = float2int8(p0[12 + 128 + 49] * scale) + 127; + pp[224 + 14] = float2int8(p0[14 + 128 + 48] * scale) + 127; + pp[224 + 15] = float2int8(p0[14 + 128 + 49] * scale) + 127; + pp[224 + 16] = float2int8(p0[12 + 128 + 64] * scale) + 127; + pp[224 + 17] = float2int8(p0[12 + 128 + 65] * scale) + 127; + pp[224 + 18] = float2int8(p0[14 + 128 + 64] * scale) + 127; + pp[224 + 19] = float2int8(p0[14 + 128 + 65] * scale) + 127; + pp[224 + 20] = float2int8(p0[12 + 128 + 80] * scale) + 127; + pp[224 + 21] = float2int8(p0[12 + 128 + 81] * scale) + 127; + pp[224 + 22] = float2int8(p0[14 + 128 + 80] * scale) + 127; + pp[224 + 23] = float2int8(p0[14 + 128 + 81] * scale) + 127; + pp[224 + 24] = float2int8(p0[12 + 128 + 96] * scale) + 127; + pp[224 + 25] = float2int8(p0[12 + 128 + 97] * scale) + 127; + pp[224 + 26] = float2int8(p0[14 + 128 + 96] * scale) + 127; + pp[224 + 27] = float2int8(p0[14 + 128 + 97] * scale) + 127; + pp[224 + 28] = float2int8(p0[12 + 128 + 112] * scale) + 127; + pp[224 + 29] = float2int8(p0[12 + 128 + 113] * scale) + 127; + pp[224 + 30] = float2int8(p0[14 + 128 + 112] * scale) + 127; + pp[224 + 31] = float2int8(p0[14 + 128 + 113] * scale) + 127; + + pp += 256; + p0 += B_hstep * 16; + } +#else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { pp[0] = float2int8(p0[0] * scale); @@ -4205,10 +9824,150 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 256; p0 += B_hstep * 16; } +#endif // __AVX512VNNI__ } if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[8] * scale) + 127; + pp[5] = float2int8(p0[9] * scale) + 127; + pp[6] = float2int8(p0[10] * scale) + 127; + pp[7] = float2int8(p0[11] * scale) + 127; + pp[8] = float2int8(p0[16] * scale) + 127; + pp[9] = float2int8(p0[17] * scale) + 127; + pp[10] = float2int8(p0[18] * scale) + 127; + pp[11] = float2int8(p0[19] * scale) + 127; + pp[12] = float2int8(p0[24] * scale) + 127; + pp[13] = float2int8(p0[25] * scale) + 127; + pp[14] = float2int8(p0[26] * scale) + 127; + pp[15] = float2int8(p0[27] * scale) + 127; + pp[16] = float2int8(p0[32] * scale) + 127; + pp[17] = float2int8(p0[33] * scale) + 127; + pp[18] = float2int8(p0[34] * scale) + 127; + pp[19] = float2int8(p0[35] * scale) + 127; + pp[20] = float2int8(p0[40] * scale) + 127; + pp[21] = float2int8(p0[41] * scale) + 127; + pp[22] = float2int8(p0[42] * scale) + 127; + pp[23] = float2int8(p0[43] * scale) + 127; + pp[24] = float2int8(p0[48] * scale) + 127; + pp[25] = float2int8(p0[49] * scale) + 127; + pp[26] = float2int8(p0[50] * scale) + 127; + pp[27] = float2int8(p0[51] * scale) + 127; + pp[28] = float2int8(p0[56] * scale) + 127; + pp[29] = float2int8(p0[57] * scale) + 127; + pp[30] = float2int8(p0[58] * scale) + 127; + pp[31] = float2int8(p0[59] * scale) + 127; + + pp[32 + 0] = float2int8(p0[64 + 0] * scale) + 127; + pp[32 + 1] = float2int8(p0[64 + 1] * scale) + 127; + pp[32 + 2] = float2int8(p0[64 + 2] * scale) + 127; + pp[32 + 3] = float2int8(p0[64 + 3] * scale) + 127; + pp[32 + 4] = float2int8(p0[64 + 8] * scale) + 127; + pp[32 + 5] = float2int8(p0[64 + 9] * scale) + 127; + pp[32 + 6] = float2int8(p0[64 + 10] * scale) + 127; + pp[32 + 7] = float2int8(p0[64 + 11] * scale) + 127; + pp[32 + 8] = float2int8(p0[64 + 16] * scale) + 127; + pp[32 + 9] = float2int8(p0[64 + 17] * scale) + 127; + pp[32 + 10] = float2int8(p0[64 + 18] * scale) + 127; + pp[32 + 11] = float2int8(p0[64 + 19] * scale) + 127; + pp[32 + 12] = float2int8(p0[64 + 24] * scale) + 127; + pp[32 + 13] = float2int8(p0[64 + 25] * scale) + 127; + pp[32 + 14] = float2int8(p0[64 + 26] * scale) + 127; + pp[32 + 15] = float2int8(p0[64 + 27] * scale) + 127; + pp[32 + 16] = float2int8(p0[64 + 32] * scale) + 127; + pp[32 + 17] = float2int8(p0[64 + 33] * scale) + 127; + pp[32 + 18] = float2int8(p0[64 + 34] * scale) + 127; + pp[32 + 19] = float2int8(p0[64 + 35] * scale) + 127; + pp[32 + 20] = float2int8(p0[64 + 40] * scale) + 127; + pp[32 + 21] = float2int8(p0[64 + 41] * scale) + 127; + pp[32 + 22] = float2int8(p0[64 + 42] * scale) + 127; + pp[32 + 23] = float2int8(p0[64 + 43] * scale) + 127; + pp[32 + 24] = float2int8(p0[64 + 48] * scale) + 127; + pp[32 + 25] = float2int8(p0[64 + 49] * scale) + 127; + pp[32 + 26] = float2int8(p0[64 + 50] * scale) + 127; + pp[32 + 27] = float2int8(p0[64 + 51] * scale) + 127; + pp[32 + 28] = float2int8(p0[64 + 56] * scale) + 127; + pp[32 + 29] = float2int8(p0[64 + 57] * scale) + 127; + pp[32 + 30] = float2int8(p0[64 + 58] * scale) + 127; + pp[32 + 31] = float2int8(p0[64 + 59] * scale) + 127; + + pp[64 + 0] = float2int8(p0[4] * scale) + 127; + pp[64 + 1] = float2int8(p0[5] * scale) + 127; + pp[64 + 2] = float2int8(p0[6] * scale) + 127; + pp[64 + 3] = float2int8(p0[7] * scale) + 127; + pp[64 + 4] = float2int8(p0[12] * scale) + 127; + pp[64 + 5] = float2int8(p0[13] * scale) + 127; + pp[64 + 6] = float2int8(p0[14] * scale) + 127; + pp[64 + 7] = float2int8(p0[15] * scale) + 127; + pp[64 + 8] = float2int8(p0[20] * scale) + 127; + pp[64 + 9] = float2int8(p0[21] * scale) + 127; + pp[64 + 10] = float2int8(p0[22] * scale) + 127; + pp[64 + 11] = float2int8(p0[23] * scale) + 127; + pp[64 + 12] = float2int8(p0[28] * scale) + 127; + pp[64 + 13] = float2int8(p0[29] * scale) + 127; + pp[64 + 14] = float2int8(p0[30] * scale) + 127; + pp[64 + 15] = float2int8(p0[31] * scale) + 127; + pp[64 + 16] = float2int8(p0[36] * scale) + 127; + pp[64 + 17] = float2int8(p0[37] * scale) + 127; + pp[64 + 18] = float2int8(p0[38] * scale) + 127; + pp[64 + 19] = float2int8(p0[39] * scale) + 127; + pp[64 + 20] = float2int8(p0[44] * scale) + 127; + pp[64 + 21] = float2int8(p0[45] * scale) + 127; + pp[64 + 22] = float2int8(p0[46] * scale) + 127; + pp[64 + 23] = float2int8(p0[47] * scale) + 127; + pp[64 + 24] = float2int8(p0[52] * scale) + 127; + pp[64 + 25] = float2int8(p0[53] * scale) + 127; + pp[64 + 26] = float2int8(p0[54] * scale) + 127; + pp[64 + 27] = float2int8(p0[55] * scale) + 127; + pp[64 + 28] = float2int8(p0[60] * scale) + 127; + pp[64 + 29] = float2int8(p0[61] * scale) + 127; + pp[64 + 30] = float2int8(p0[62] * scale) + 127; + pp[64 + 31] = float2int8(p0[63] * scale) + 127; + + pp[96 + 0] = float2int8(p0[64 + 4] * scale) + 127; + pp[96 + 1] = float2int8(p0[64 + 5] * scale) + 127; + pp[96 + 2] = float2int8(p0[64 + 6] * scale) + 127; + pp[96 + 3] = float2int8(p0[64 + 7] * scale) + 127; + pp[96 + 4] = float2int8(p0[64 + 12] * scale) + 127; + pp[96 + 5] = float2int8(p0[64 + 13] * scale) + 127; + pp[96 + 6] = float2int8(p0[64 + 14] * scale) + 127; + pp[96 + 7] = float2int8(p0[64 + 15] * scale) + 127; + pp[96 + 8] = float2int8(p0[64 + 20] * scale) + 127; + pp[96 + 9] = float2int8(p0[64 + 21] * scale) + 127; + pp[96 + 10] = float2int8(p0[64 + 22] * scale) + 127; + pp[96 + 11] = float2int8(p0[64 + 23] * scale) + 127; + pp[96 + 12] = float2int8(p0[64 + 28] * scale) + 127; + pp[96 + 13] = float2int8(p0[64 + 29] * scale) + 127; + pp[96 + 14] = float2int8(p0[64 + 30] * scale) + 127; + pp[96 + 15] = float2int8(p0[64 + 31] * scale) + 127; + pp[96 + 16] = float2int8(p0[64 + 36] * scale) + 127; + pp[96 + 17] = float2int8(p0[64 + 37] * scale) + 127; + pp[96 + 18] = float2int8(p0[64 + 38] * scale) + 127; + pp[96 + 19] = float2int8(p0[64 + 39] * scale) + 127; + pp[96 + 20] = float2int8(p0[64 + 44] * scale) + 127; + pp[96 + 21] = float2int8(p0[64 + 45] * scale) + 127; + pp[96 + 22] = float2int8(p0[64 + 46] * scale) + 127; + pp[96 + 23] = float2int8(p0[64 + 47] * scale) + 127; + pp[96 + 24] = float2int8(p0[64 + 52] * scale) + 127; + pp[96 + 25] = float2int8(p0[64 + 53] * scale) + 127; + pp[96 + 26] = float2int8(p0[64 + 54] * scale) + 127; + pp[96 + 27] = float2int8(p0[64 + 55] * scale) + 127; + pp[96 + 28] = float2int8(p0[64 + 60] * scale) + 127; + pp[96 + 29] = float2int8(p0[64 + 61] * scale) + 127; + pp[96 + 30] = float2int8(p0[64 + 62] * scale) + 127; + pp[96 + 31] = float2int8(p0[64 + 63] * scale) + 127; + + pp += 128; + p0 += B_hstep * 8; + } +#else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { pp[0] = float2int8(p0[0] * scale); @@ -4350,10 +10109,84 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 128; p0 += B_hstep * 8; } +#endif // __AVX512VNNI__ } if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[4] * scale) + 127; + pp[5] = float2int8(p0[5] * scale) + 127; + pp[6] = float2int8(p0[6] * scale) + 127; + pp[7] = float2int8(p0[7] * scale) + 127; + pp[8] = float2int8(p0[8] * scale) + 127; + pp[9] = float2int8(p0[9] * scale) + 127; + pp[10] = float2int8(p0[10] * scale) + 127; + pp[11] = float2int8(p0[11] * scale) + 127; + pp[12] = float2int8(p0[12] * scale) + 127; + pp[13] = float2int8(p0[13] * scale) + 127; + pp[14] = float2int8(p0[14] * scale) + 127; + pp[15] = float2int8(p0[15] * scale) + 127; + pp[16] = float2int8(p0[16] * scale) + 127; + pp[17] = float2int8(p0[17] * scale) + 127; + pp[18] = float2int8(p0[18] * scale) + 127; + pp[19] = float2int8(p0[19] * scale) + 127; + pp[20] = float2int8(p0[20] * scale) + 127; + pp[21] = float2int8(p0[21] * scale) + 127; + pp[22] = float2int8(p0[22] * scale) + 127; + pp[23] = float2int8(p0[23] * scale) + 127; + pp[24] = float2int8(p0[24] * scale) + 127; + pp[25] = float2int8(p0[25] * scale) + 127; + pp[26] = float2int8(p0[26] * scale) + 127; + pp[27] = float2int8(p0[27] * scale) + 127; + pp[28] = float2int8(p0[28] * scale) + 127; + pp[29] = float2int8(p0[29] * scale) + 127; + pp[30] = float2int8(p0[30] * scale) + 127; + pp[31] = float2int8(p0[31] * scale) + 127; + + pp[32 + 0] = float2int8(p0[32 + 0] * scale) + 127; + pp[32 + 1] = float2int8(p0[32 + 1] * scale) + 127; + pp[32 + 2] = float2int8(p0[32 + 2] * scale) + 127; + pp[32 + 3] = float2int8(p0[32 + 3] * scale) + 127; + pp[32 + 4] = float2int8(p0[32 + 4] * scale) + 127; + pp[32 + 5] = float2int8(p0[32 + 5] * scale) + 127; + pp[32 + 6] = float2int8(p0[32 + 6] * scale) + 127; + pp[32 + 7] = float2int8(p0[32 + 7] * scale) + 127; + pp[32 + 8] = float2int8(p0[32 + 8] * scale) + 127; + pp[32 + 9] = float2int8(p0[32 + 9] * scale) + 127; + pp[32 + 10] = float2int8(p0[32 + 10] * scale) + 127; + pp[32 + 11] = float2int8(p0[32 + 11] * scale) + 127; + pp[32 + 12] = float2int8(p0[32 + 12] * scale) + 127; + pp[32 + 13] = float2int8(p0[32 + 13] * scale) + 127; + pp[32 + 14] = float2int8(p0[32 + 14] * scale) + 127; + pp[32 + 15] = float2int8(p0[32 + 15] * scale) + 127; + pp[32 + 16] = float2int8(p0[32 + 16] * scale) + 127; + pp[32 + 17] = float2int8(p0[32 + 17] * scale) + 127; + pp[32 + 18] = float2int8(p0[32 + 18] * scale) + 127; + pp[32 + 19] = float2int8(p0[32 + 19] * scale) + 127; + pp[32 + 20] = float2int8(p0[32 + 20] * scale) + 127; + pp[32 + 21] = float2int8(p0[32 + 21] * scale) + 127; + pp[32 + 22] = float2int8(p0[32 + 22] * scale) + 127; + pp[32 + 23] = float2int8(p0[32 + 23] * scale) + 127; + pp[32 + 24] = float2int8(p0[32 + 24] * scale) + 127; + pp[32 + 25] = float2int8(p0[32 + 25] * scale) + 127; + pp[32 + 26] = float2int8(p0[32 + 26] * scale) + 127; + pp[32 + 27] = float2int8(p0[32 + 27] * scale) + 127; + pp[32 + 28] = float2int8(p0[32 + 28] * scale) + 127; + pp[32 + 29] = float2int8(p0[32 + 29] * scale) + 127; + pp[32 + 30] = float2int8(p0[32 + 30] * scale) + 127; + pp[32 + 31] = float2int8(p0[32 + 31] * scale) + 127; + + pp += 64; + p0 += B_hstep * 4; + } +#else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale); @@ -4427,10 +10260,83 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 64; p0 += B_hstep * 4; } +#endif // __AVX512VNNI__ } if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[B_hstep] * scale) + 127; + pp[2] = float2int8(p0[B_hstep * 2] * scale) + 127; + pp[3] = float2int8(p0[B_hstep * 3] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; + pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[B_hstep + 2] * scale) + 127; + pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; + pp[11] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[B_hstep + 3] * scale) + 127; + pp[14] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; + pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; + pp[16] = float2int8(p0[4] * scale) + 127; + pp[17] = float2int8(p0[B_hstep + 4] * scale) + 127; + pp[18] = float2int8(p0[B_hstep * 2 + 4] * scale) + 127; + pp[19] = float2int8(p0[B_hstep * 3 + 4] * scale) + 127; + pp[20] = float2int8(p0[5] * scale) + 127; + pp[21] = float2int8(p0[B_hstep + 5] * scale) + 127; + pp[22] = float2int8(p0[B_hstep * 2 + 5] * scale) + 127; + pp[23] = float2int8(p0[B_hstep * 3 + 5] * scale) + 127; + pp[24] = float2int8(p0[6] * scale) + 127; + pp[25] = float2int8(p0[B_hstep + 6] * scale) + 127; + pp[26] = float2int8(p0[B_hstep * 2 + 6] * scale) + 127; + pp[27] = float2int8(p0[B_hstep * 3 + 6] * scale) + 127; + pp[28] = float2int8(p0[7] * scale) + 127; + pp[29] = float2int8(p0[B_hstep + 7] * scale) + 127; + pp[30] = float2int8(p0[B_hstep * 2 + 7] * scale) + 127; + pp[31] = float2int8(p0[B_hstep * 3 + 7] * scale) + 127; + + pp[32 + 0] = float2int8(p0[8] * scale) + 127; + pp[32 + 1] = float2int8(p0[B_hstep + 8] * scale) + 127; + pp[32 + 2] = float2int8(p0[B_hstep * 2 + 8] * scale) + 127; + pp[32 + 3] = float2int8(p0[B_hstep * 3 + 8] * scale) + 127; + pp[32 + 4] = float2int8(p0[9] * scale) + 127; + pp[32 + 5] = float2int8(p0[B_hstep + 9] * scale) + 127; + pp[32 + 6] = float2int8(p0[B_hstep * 2 + 9] * scale) + 127; + pp[32 + 7] = float2int8(p0[B_hstep * 3 + 9] * scale) + 127; + pp[32 + 8] = float2int8(p0[10] * scale) + 127; + pp[32 + 9] = float2int8(p0[B_hstep + 10] * scale) + 127; + pp[32 + 10] = float2int8(p0[B_hstep * 2 + 10] * scale) + 127; + pp[32 + 11] = float2int8(p0[B_hstep * 3 + 10] * scale) + 127; + pp[32 + 12] = float2int8(p0[11] * scale) + 127; + pp[32 + 13] = float2int8(p0[B_hstep + 11] * scale) + 127; + pp[32 + 14] = float2int8(p0[B_hstep * 2 + 11] * scale) + 127; + pp[32 + 15] = float2int8(p0[B_hstep * 3 + 11] * scale) + 127; + pp[32 + 16] = float2int8(p0[12] * scale) + 127; + pp[32 + 17] = float2int8(p0[B_hstep + 12] * scale) + 127; + pp[32 + 18] = float2int8(p0[B_hstep * 2 + 12] * scale) + 127; + pp[32 + 19] = float2int8(p0[B_hstep * 3 + 12] * scale) + 127; + pp[32 + 20] = float2int8(p0[13] * scale) + 127; + pp[32 + 21] = float2int8(p0[B_hstep + 13] * scale) + 127; + pp[32 + 22] = float2int8(p0[B_hstep * 2 + 13] * scale) + 127; + pp[32 + 23] = float2int8(p0[B_hstep * 3 + 13] * scale) + 127; + pp[32 + 24] = float2int8(p0[14] * scale) + 127; + pp[32 + 25] = float2int8(p0[B_hstep + 14] * scale) + 127; + pp[32 + 26] = float2int8(p0[B_hstep * 2 + 14] * scale) + 127; + pp[32 + 27] = float2int8(p0[B_hstep * 3 + 14] * scale) + 127; + pp[32 + 28] = float2int8(p0[15] * scale) + 127; + pp[32 + 29] = float2int8(p0[B_hstep + 15] * scale) + 127; + pp[32 + 30] = float2int8(p0[B_hstep * 2 + 15] * scale) + 127; + pp[32 + 31] = float2int8(p0[B_hstep * 3 + 15] * scale) + 127; + pp += 64; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -4502,6 +10408,145 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2 + 0] * scale) + 127; + pp[3] = float2int8(p0[2 + 1] * scale) + 127; + pp[4] = float2int8(p0[16] * scale) + 127; + pp[5] = float2int8(p0[17] * scale) + 127; + pp[6] = float2int8(p0[2 + 16] * scale) + 127; + pp[7] = float2int8(p0[2 + 17] * scale) + 127; + pp[8] = float2int8(p0[32] * scale) + 127; + pp[9] = float2int8(p0[33] * scale) + 127; + pp[10] = float2int8(p0[2 + 32] * scale) + 127; + pp[11] = float2int8(p0[2 + 33] * scale) + 127; + pp[12] = float2int8(p0[48] * scale) + 127; + pp[13] = float2int8(p0[49] * scale) + 127; + pp[14] = float2int8(p0[2 + 48] * scale) + 127; + pp[15] = float2int8(p0[2 + 49] * scale) + 127; + pp[16] = float2int8(p0[64] * scale) + 127; + pp[17] = float2int8(p0[65] * scale) + 127; + pp[18] = float2int8(p0[2 + 64] * scale) + 127; + pp[19] = float2int8(p0[2 + 65] * scale) + 127; + pp[20] = float2int8(p0[80] * scale) + 127; + pp[21] = float2int8(p0[81] * scale) + 127; + pp[22] = float2int8(p0[2 + 80] * scale) + 127; + pp[23] = float2int8(p0[2 + 81] * scale) + 127; + pp[24] = float2int8(p0[96] * scale) + 127; + pp[25] = float2int8(p0[97] * scale) + 127; + pp[26] = float2int8(p0[2 + 96] * scale) + 127; + pp[27] = float2int8(p0[2 + 97] * scale) + 127; + pp[28] = float2int8(p0[112] * scale) + 127; + pp[29] = float2int8(p0[113] * scale) + 127; + pp[30] = float2int8(p0[2 + 112] * scale) + 127; + pp[31] = float2int8(p0[2 + 113] * scale) + 127; + + pp[32 + 0] = float2int8(p0[4 + 0] * scale) + 127; + pp[32 + 1] = float2int8(p0[4 + 1] * scale) + 127; + pp[32 + 2] = float2int8(p0[6 + 0] * scale) + 127; + pp[32 + 3] = float2int8(p0[6 + 1] * scale) + 127; + pp[32 + 4] = float2int8(p0[4 + 16] * scale) + 127; + pp[32 + 5] = float2int8(p0[4 + 17] * scale) + 127; + pp[32 + 6] = float2int8(p0[6 + 16] * scale) + 127; + pp[32 + 7] = float2int8(p0[6 + 17] * scale) + 127; + pp[32 + 8] = float2int8(p0[4 + 32] * scale) + 127; + pp[32 + 9] = float2int8(p0[4 + 33] * scale) + 127; + pp[32 + 10] = float2int8(p0[6 + 32] * scale) + 127; + pp[32 + 11] = float2int8(p0[6 + 33] * scale) + 127; + pp[32 + 12] = float2int8(p0[4 + 48] * scale) + 127; + pp[32 + 13] = float2int8(p0[4 + 49] * scale) + 127; + pp[32 + 14] = float2int8(p0[6 + 48] * scale) + 127; + pp[32 + 15] = float2int8(p0[6 + 49] * scale) + 127; + pp[32 + 16] = float2int8(p0[4 + 64] * scale) + 127; + pp[32 + 17] = float2int8(p0[4 + 65] * scale) + 127; + pp[32 + 18] = float2int8(p0[6 + 64] * scale) + 127; + pp[32 + 19] = float2int8(p0[6 + 65] * scale) + 127; + pp[32 + 20] = float2int8(p0[4 + 80] * scale) + 127; + pp[32 + 21] = float2int8(p0[4 + 81] * scale) + 127; + pp[32 + 22] = float2int8(p0[6 + 80] * scale) + 127; + pp[32 + 23] = float2int8(p0[6 + 81] * scale) + 127; + pp[32 + 24] = float2int8(p0[4 + 96] * scale) + 127; + pp[32 + 25] = float2int8(p0[4 + 97] * scale) + 127; + pp[32 + 26] = float2int8(p0[6 + 96] * scale) + 127; + pp[32 + 27] = float2int8(p0[6 + 97] * scale) + 127; + pp[32 + 28] = float2int8(p0[4 + 112] * scale) + 127; + pp[32 + 29] = float2int8(p0[4 + 113] * scale) + 127; + pp[32 + 30] = float2int8(p0[6 + 112] * scale) + 127; + pp[32 + 31] = float2int8(p0[6 + 113] * scale) + 127; + + pp[64 + 0] = float2int8(p0[8 + 0] * scale) + 127; + pp[64 + 1] = float2int8(p0[8 + 1] * scale) + 127; + pp[64 + 2] = float2int8(p0[10 + 0] * scale) + 127; + pp[64 + 3] = float2int8(p0[10 + 1] * scale) + 127; + pp[64 + 4] = float2int8(p0[8 + 16] * scale) + 127; + pp[64 + 5] = float2int8(p0[8 + 17] * scale) + 127; + pp[64 + 6] = float2int8(p0[10 + 16] * scale) + 127; + pp[64 + 7] = float2int8(p0[10 + 17] * scale) + 127; + pp[64 + 8] = float2int8(p0[8 + 32] * scale) + 127; + pp[64 + 9] = float2int8(p0[8 + 33] * scale) + 127; + pp[64 + 10] = float2int8(p0[10 + 32] * scale) + 127; + pp[64 + 11] = float2int8(p0[10 + 33] * scale) + 127; + pp[64 + 12] = float2int8(p0[8 + 48] * scale) + 127; + pp[64 + 13] = float2int8(p0[8 + 49] * scale) + 127; + pp[64 + 14] = float2int8(p0[10 + 48] * scale) + 127; + pp[64 + 15] = float2int8(p0[10 + 49] * scale) + 127; + pp[64 + 16] = float2int8(p0[8 + 64] * scale) + 127; + pp[64 + 17] = float2int8(p0[8 + 65] * scale) + 127; + pp[64 + 18] = float2int8(p0[10 + 64] * scale) + 127; + pp[64 + 19] = float2int8(p0[10 + 65] * scale) + 127; + pp[64 + 20] = float2int8(p0[8 + 80] * scale) + 127; + pp[64 + 21] = float2int8(p0[8 + 81] * scale) + 127; + pp[64 + 22] = float2int8(p0[10 + 80] * scale) + 127; + pp[64 + 23] = float2int8(p0[10 + 81] * scale) + 127; + pp[64 + 24] = float2int8(p0[8 + 96] * scale) + 127; + pp[64 + 25] = float2int8(p0[8 + 97] * scale) + 127; + pp[64 + 26] = float2int8(p0[10 + 96] * scale) + 127; + pp[64 + 27] = float2int8(p0[10 + 97] * scale) + 127; + pp[64 + 28] = float2int8(p0[8 + 112] * scale) + 127; + pp[64 + 29] = float2int8(p0[8 + 113] * scale) + 127; + pp[64 + 30] = float2int8(p0[10 + 112] * scale) + 127; + pp[64 + 31] = float2int8(p0[10 + 113] * scale) + 127; + + pp[96 + 0] = float2int8(p0[12 + 0] * scale) + 127; + pp[96 + 1] = float2int8(p0[12 + 1] * scale) + 127; + pp[96 + 2] = float2int8(p0[14 + 0] * scale) + 127; + pp[96 + 3] = float2int8(p0[14 + 1] * scale) + 127; + pp[96 + 4] = float2int8(p0[12 + 16] * scale) + 127; + pp[96 + 5] = float2int8(p0[12 + 17] * scale) + 127; + pp[96 + 6] = float2int8(p0[14 + 16] * scale) + 127; + pp[96 + 7] = float2int8(p0[14 + 17] * scale) + 127; + pp[96 + 8] = float2int8(p0[12 + 32] * scale) + 127; + pp[96 + 9] = float2int8(p0[12 + 33] * scale) + 127; + pp[96 + 10] = float2int8(p0[14 + 32] * scale) + 127; + pp[96 + 11] = float2int8(p0[14 + 33] * scale) + 127; + pp[96 + 12] = float2int8(p0[12 + 48] * scale) + 127; + pp[96 + 13] = float2int8(p0[12 + 49] * scale) + 127; + pp[96 + 14] = float2int8(p0[14 + 48] * scale) + 127; + pp[96 + 15] = float2int8(p0[14 + 49] * scale) + 127; + pp[96 + 16] = float2int8(p0[12 + 64] * scale) + 127; + pp[96 + 17] = float2int8(p0[12 + 65] * scale) + 127; + pp[96 + 18] = float2int8(p0[14 + 64] * scale) + 127; + pp[96 + 19] = float2int8(p0[14 + 65] * scale) + 127; + pp[96 + 20] = float2int8(p0[12 + 80] * scale) + 127; + pp[96 + 21] = float2int8(p0[12 + 81] * scale) + 127; + pp[96 + 22] = float2int8(p0[14 + 80] * scale) + 127; + pp[96 + 23] = float2int8(p0[14 + 81] * scale) + 127; + pp[96 + 24] = float2int8(p0[12 + 96] * scale) + 127; + pp[96 + 25] = float2int8(p0[12 + 97] * scale) + 127; + pp[96 + 26] = float2int8(p0[14 + 96] * scale) + 127; + pp[96 + 27] = float2int8(p0[14 + 97] * scale) + 127; + pp[96 + 28] = float2int8(p0[12 + 112] * scale) + 127; + pp[96 + 29] = float2int8(p0[12 + 113] * scale) + 127; + pp[96 + 30] = float2int8(p0[14 + 112] * scale) + 127; + pp[96 + 31] = float2int8(p0[14 + 113] * scale) + 127; + + pp += 128; + p0 += B_hstep * 16; + } +#else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { pp[0] = float2int8(p0[0] * scale); @@ -4643,11 +10688,85 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 128; p0 += B_hstep * 16; } +#endif // __AVX512VNNI__ } #endif // __AVX512F__ if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[8] * scale) + 127; + pp[5] = float2int8(p0[9] * scale) + 127; + pp[6] = float2int8(p0[10] * scale) + 127; + pp[7] = float2int8(p0[11] * scale) + 127; + pp[8] = float2int8(p0[16] * scale) + 127; + pp[9] = float2int8(p0[17] * scale) + 127; + pp[10] = float2int8(p0[18] * scale) + 127; + pp[11] = float2int8(p0[19] * scale) + 127; + pp[12] = float2int8(p0[24] * scale) + 127; + pp[13] = float2int8(p0[25] * scale) + 127; + pp[14] = float2int8(p0[26] * scale) + 127; + pp[15] = float2int8(p0[27] * scale) + 127; + pp[16] = float2int8(p0[32] * scale) + 127; + pp[17] = float2int8(p0[33] * scale) + 127; + pp[18] = float2int8(p0[34] * scale) + 127; + pp[19] = float2int8(p0[35] * scale) + 127; + pp[20] = float2int8(p0[40] * scale) + 127; + pp[21] = float2int8(p0[41] * scale) + 127; + pp[22] = float2int8(p0[42] * scale) + 127; + pp[23] = float2int8(p0[43] * scale) + 127; + pp[24] = float2int8(p0[48] * scale) + 127; + pp[25] = float2int8(p0[49] * scale) + 127; + pp[26] = float2int8(p0[50] * scale) + 127; + pp[27] = float2int8(p0[51] * scale) + 127; + pp[28] = float2int8(p0[56] * scale) + 127; + pp[29] = float2int8(p0[57] * scale) + 127; + pp[30] = float2int8(p0[58] * scale) + 127; + pp[31] = float2int8(p0[59] * scale) + 127; + + pp[32 + 0] = float2int8(p0[4] * scale) + 127; + pp[32 + 1] = float2int8(p0[5] * scale) + 127; + pp[32 + 2] = float2int8(p0[6] * scale) + 127; + pp[32 + 3] = float2int8(p0[7] * scale) + 127; + pp[32 + 4] = float2int8(p0[12] * scale) + 127; + pp[32 + 5] = float2int8(p0[13] * scale) + 127; + pp[32 + 6] = float2int8(p0[14] * scale) + 127; + pp[32 + 7] = float2int8(p0[15] * scale) + 127; + pp[32 + 8] = float2int8(p0[20] * scale) + 127; + pp[32 + 9] = float2int8(p0[21] * scale) + 127; + pp[32 + 10] = float2int8(p0[22] * scale) + 127; + pp[32 + 11] = float2int8(p0[23] * scale) + 127; + pp[32 + 12] = float2int8(p0[28] * scale) + 127; + pp[32 + 13] = float2int8(p0[29] * scale) + 127; + pp[32 + 14] = float2int8(p0[30] * scale) + 127; + pp[32 + 15] = float2int8(p0[31] * scale) + 127; + pp[32 + 16] = float2int8(p0[36] * scale) + 127; + pp[32 + 17] = float2int8(p0[37] * scale) + 127; + pp[32 + 18] = float2int8(p0[38] * scale) + 127; + pp[32 + 19] = float2int8(p0[39] * scale) + 127; + pp[32 + 20] = float2int8(p0[44] * scale) + 127; + pp[32 + 21] = float2int8(p0[45] * scale) + 127; + pp[32 + 22] = float2int8(p0[46] * scale) + 127; + pp[32 + 23] = float2int8(p0[47] * scale) + 127; + pp[32 + 24] = float2int8(p0[52] * scale) + 127; + pp[32 + 25] = float2int8(p0[53] * scale) + 127; + pp[32 + 26] = float2int8(p0[54] * scale) + 127; + pp[32 + 27] = float2int8(p0[55] * scale) + 127; + pp[32 + 28] = float2int8(p0[60] * scale) + 127; + pp[32 + 29] = float2int8(p0[61] * scale) + 127; + pp[32 + 30] = float2int8(p0[62] * scale) + 127; + pp[32 + 31] = float2int8(p0[63] * scale) + 127; + + pp += 64; + p0 += B_hstep * 8; + } +#else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { pp[0] = float2int8(p0[0] * scale); @@ -4724,11 +10843,51 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } +#endif // __AVX512VNNI__ } #endif // __AVX__ if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[4] * scale) + 127; + pp[5] = float2int8(p0[5] * scale) + 127; + pp[6] = float2int8(p0[6] * scale) + 127; + pp[7] = float2int8(p0[7] * scale) + 127; + pp[8] = float2int8(p0[8] * scale) + 127; + pp[9] = float2int8(p0[9] * scale) + 127; + pp[10] = float2int8(p0[10] * scale) + 127; + pp[11] = float2int8(p0[11] * scale) + 127; + pp[12] = float2int8(p0[12] * scale) + 127; + pp[13] = float2int8(p0[13] * scale) + 127; + pp[14] = float2int8(p0[14] * scale) + 127; + pp[15] = float2int8(p0[15] * scale) + 127; + pp[16] = float2int8(p0[16] * scale) + 127; + pp[17] = float2int8(p0[17] * scale) + 127; + pp[18] = float2int8(p0[18] * scale) + 127; + pp[19] = float2int8(p0[19] * scale) + 127; + pp[20] = float2int8(p0[20] * scale) + 127; + pp[21] = float2int8(p0[21] * scale) + 127; + pp[22] = float2int8(p0[22] * scale) + 127; + pp[23] = float2int8(p0[23] * scale) + 127; + pp[24] = float2int8(p0[24] * scale) + 127; + pp[25] = float2int8(p0[25] * scale) + 127; + pp[26] = float2int8(p0[26] * scale) + 127; + pp[27] = float2int8(p0[27] * scale) + 127; + pp[28] = float2int8(p0[28] * scale) + 127; + pp[29] = float2int8(p0[29] * scale) + 127; + pp[30] = float2int8(p0[30] * scale) + 127; + pp[31] = float2int8(p0[31] * scale) + 127; + pp += 32; + p0 += B_hstep * 4; + } +#else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale); @@ -4768,10 +10927,50 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 32; p0 += B_hstep * 4; } +#endif // __AVX512VNNI__ } if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[B_hstep] * scale) + 127; + pp[2] = float2int8(p0[B_hstep * 2] * scale) + 127; + pp[3] = float2int8(p0[B_hstep * 3] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; + pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[B_hstep + 2] * scale) + 127; + pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; + pp[11] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[B_hstep + 3] * scale) + 127; + pp[14] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; + pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; + pp[16] = float2int8(p0[4] * scale) + 127; + pp[17] = float2int8(p0[B_hstep + 4] * scale) + 127; + pp[18] = float2int8(p0[B_hstep * 2 + 4] * scale) + 127; + pp[19] = float2int8(p0[B_hstep * 3 + 4] * scale) + 127; + pp[20] = float2int8(p0[5] * scale) + 127; + pp[21] = float2int8(p0[B_hstep + 5] * scale) + 127; + pp[22] = float2int8(p0[B_hstep * 2 + 5] * scale) + 127; + pp[23] = float2int8(p0[B_hstep * 3 + 5] * scale) + 127; + pp[24] = float2int8(p0[6] * scale) + 127; + pp[25] = float2int8(p0[B_hstep + 6] * scale) + 127; + pp[26] = float2int8(p0[B_hstep * 2 + 6] * scale) + 127; + pp[27] = float2int8(p0[B_hstep * 3 + 6] * scale) + 127; + pp[28] = float2int8(p0[7] * scale) + 127; + pp[29] = float2int8(p0[B_hstep + 7] * scale) + 127; + pp[30] = float2int8(p0[B_hstep * 2 + 7] * scale) + 127; + pp[31] = float2int8(p0[B_hstep * 3 + 7] * scale) + 127; + pp += 32; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -4819,6 +11018,81 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int if (elempack == 16) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2 + 0] * scale) + 127; + pp[3] = float2int8(p0[2 + 1] * scale) + 127; + pp[4] = float2int8(p0[16] * scale) + 127; + pp[5] = float2int8(p0[17] * scale) + 127; + pp[6] = float2int8(p0[2 + 16] * scale) + 127; + pp[7] = float2int8(p0[2 + 17] * scale) + 127; + pp[8] = float2int8(p0[32] * scale) + 127; + pp[9] = float2int8(p0[33] * scale) + 127; + pp[10] = float2int8(p0[2 + 32] * scale) + 127; + pp[11] = float2int8(p0[2 + 33] * scale) + 127; + pp[12] = float2int8(p0[48] * scale) + 127; + pp[13] = float2int8(p0[49] * scale) + 127; + pp[14] = float2int8(p0[2 + 48] * scale) + 127; + pp[15] = float2int8(p0[2 + 49] * scale) + 127; + + pp[16 + 0] = float2int8(p0[4 + 0] * scale) + 127; + pp[16 + 1] = float2int8(p0[4 + 1] * scale) + 127; + pp[16 + 2] = float2int8(p0[6 + 0] * scale) + 127; + pp[16 + 3] = float2int8(p0[6 + 1] * scale) + 127; + pp[16 + 4] = float2int8(p0[4 + 16] * scale) + 127; + pp[16 + 5] = float2int8(p0[4 + 17] * scale) + 127; + pp[16 + 6] = float2int8(p0[6 + 16] * scale) + 127; + pp[16 + 7] = float2int8(p0[6 + 17] * scale) + 127; + pp[16 + 8] = float2int8(p0[4 + 32] * scale) + 127; + pp[16 + 9] = float2int8(p0[4 + 33] * scale) + 127; + pp[16 + 10] = float2int8(p0[6 + 32] * scale) + 127; + pp[16 + 11] = float2int8(p0[6 + 33] * scale) + 127; + pp[16 + 12] = float2int8(p0[4 + 48] * scale) + 127; + pp[16 + 13] = float2int8(p0[4 + 49] * scale) + 127; + pp[16 + 14] = float2int8(p0[6 + 48] * scale) + 127; + pp[16 + 15] = float2int8(p0[6 + 49] * scale) + 127; + + pp[32 + 0] = float2int8(p0[8 + 0] * scale) + 127; + pp[32 + 1] = float2int8(p0[8 + 1] * scale) + 127; + pp[32 + 2] = float2int8(p0[10 + 0] * scale) + 127; + pp[32 + 3] = float2int8(p0[10 + 1] * scale) + 127; + pp[32 + 4] = float2int8(p0[8 + 16] * scale) + 127; + pp[32 + 5] = float2int8(p0[8 + 17] * scale) + 127; + pp[32 + 6] = float2int8(p0[10 + 16] * scale) + 127; + pp[32 + 7] = float2int8(p0[10 + 17] * scale) + 127; + pp[32 + 8] = float2int8(p0[8 + 32] * scale) + 127; + pp[32 + 9] = float2int8(p0[8 + 33] * scale) + 127; + pp[32 + 10] = float2int8(p0[10 + 32] * scale) + 127; + pp[32 + 11] = float2int8(p0[10 + 33] * scale) + 127; + pp[32 + 12] = float2int8(p0[8 + 48] * scale) + 127; + pp[32 + 13] = float2int8(p0[8 + 49] * scale) + 127; + pp[32 + 14] = float2int8(p0[10 + 48] * scale) + 127; + pp[32 + 15] = float2int8(p0[10 + 49] * scale) + 127; + + pp[48 + 0] = float2int8(p0[12 + 0] * scale) + 127; + pp[48 + 1] = float2int8(p0[12 + 1] * scale) + 127; + pp[48 + 2] = float2int8(p0[14 + 0] * scale) + 127; + pp[48 + 3] = float2int8(p0[14 + 1] * scale) + 127; + pp[48 + 4] = float2int8(p0[12 + 16] * scale) + 127; + pp[48 + 5] = float2int8(p0[12 + 17] * scale) + 127; + pp[48 + 6] = float2int8(p0[14 + 16] * scale) + 127; + pp[48 + 7] = float2int8(p0[14 + 17] * scale) + 127; + pp[48 + 8] = float2int8(p0[12 + 32] * scale) + 127; + pp[48 + 9] = float2int8(p0[12 + 33] * scale) + 127; + pp[48 + 10] = float2int8(p0[14 + 32] * scale) + 127; + pp[48 + 11] = float2int8(p0[14 + 33] * scale) + 127; + pp[48 + 12] = float2int8(p0[12 + 48] * scale) + 127; + pp[48 + 13] = float2int8(p0[12 + 49] * scale) + 127; + pp[48 + 14] = float2int8(p0[14 + 48] * scale) + 127; + pp[48 + 15] = float2int8(p0[14 + 49] * scale) + 127; + + pp += 64; + p0 += B_hstep * 16; + } +#else // __AVX512VNNI__ for (; kk + 15 < max_kk; kk += 16) { pp[0] = float2int8(p0[0] * scale); @@ -4896,11 +11170,53 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 64; p0 += B_hstep * 16; } +#endif // __AVX512VNNI__ } #endif // __AVX512F__ if (elempack == 8) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[8] * scale) + 127; + pp[5] = float2int8(p0[9] * scale) + 127; + pp[6] = float2int8(p0[10] * scale) + 127; + pp[7] = float2int8(p0[11] * scale) + 127; + pp[8] = float2int8(p0[16] * scale) + 127; + pp[9] = float2int8(p0[17] * scale) + 127; + pp[10] = float2int8(p0[18] * scale) + 127; + pp[11] = float2int8(p0[19] * scale) + 127; + pp[12] = float2int8(p0[24] * scale) + 127; + pp[13] = float2int8(p0[25] * scale) + 127; + pp[14] = float2int8(p0[26] * scale) + 127; + pp[15] = float2int8(p0[27] * scale) + 127; + + pp[16 + 0] = float2int8(p0[4] * scale) + 127; + pp[16 + 1] = float2int8(p0[5] * scale) + 127; + pp[16 + 2] = float2int8(p0[6] * scale) + 127; + pp[16 + 3] = float2int8(p0[7] * scale) + 127; + pp[16 + 4] = float2int8(p0[12] * scale) + 127; + pp[16 + 5] = float2int8(p0[13] * scale) + 127; + pp[16 + 6] = float2int8(p0[14] * scale) + 127; + pp[16 + 7] = float2int8(p0[15] * scale) + 127; + pp[16 + 8] = float2int8(p0[20] * scale) + 127; + pp[16 + 9] = float2int8(p0[21] * scale) + 127; + pp[16 + 10] = float2int8(p0[22] * scale) + 127; + pp[16 + 11] = float2int8(p0[23] * scale) + 127; + pp[16 + 12] = float2int8(p0[28] * scale) + 127; + pp[16 + 13] = float2int8(p0[29] * scale) + 127; + pp[16 + 14] = float2int8(p0[30] * scale) + 127; + pp[16 + 15] = float2int8(p0[31] * scale) + 127; + + pp += 32; + p0 += B_hstep * 8; + } +#else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { pp[0] = float2int8(p0[0] * scale); @@ -4942,11 +11258,36 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 32; p0 += B_hstep * 8; } +#endif // __AVX512VNNI__ } #endif // __AVX__ if (elempack == 4) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[4] * scale) + 127; + pp[5] = float2int8(p0[5] * scale) + 127; + pp[6] = float2int8(p0[6] * scale) + 127; + pp[7] = float2int8(p0[7] * scale) + 127; + pp[8] = float2int8(p0[8] * scale) + 127; + pp[9] = float2int8(p0[9] * scale) + 127; + pp[10] = float2int8(p0[10] * scale) + 127; + pp[11] = float2int8(p0[11] * scale) + 127; + pp[12] = float2int8(p0[12] * scale) + 127; + pp[13] = float2int8(p0[13] * scale) + 127; + pp[14] = float2int8(p0[14] * scale) + 127; + pp[15] = float2int8(p0[15] * scale) + 127; + + pp += 16; + p0 += B_hstep * 4; + } +#else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale); @@ -4969,10 +11310,35 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 16; p0 += B_hstep * 4; } +#endif // __AVX512VNNI__ } if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[B_hstep] * scale) + 127; + pp[2] = float2int8(p0[B_hstep * 2] * scale) + 127; + pp[3] = float2int8(p0[B_hstep * 3] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; + pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; + pp[8] = float2int8(p0[2] * scale) + 127; + pp[9] = float2int8(p0[B_hstep + 2] * scale) + 127; + pp[10] = float2int8(p0[B_hstep * 2 + 2] * scale) + 127; + pp[11] = float2int8(p0[B_hstep * 3 + 2] * scale) + 127; + pp[12] = float2int8(p0[3] * scale) + 127; + pp[13] = float2int8(p0[B_hstep + 3] * scale) + 127; + pp[14] = float2int8(p0[B_hstep * 2 + 3] * scale) + 127; + pp[15] = float2int8(p0[B_hstep * 3 + 3] * scale) + 127; + + pp += 16; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -5011,6 +11377,43 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int int kk = 0; for (; kk + 15 < max_kk; kk += 16) { +#if __AVX512VNNI__ + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[16] * scale) + 127; + pp[5] = float2int8(p0[17] * scale) + 127; + pp[6] = float2int8(p0[18] * scale) + 127; + pp[7] = float2int8(p0[19] * scale) + 127; + + pp[8] = float2int8(p0[4] * scale) + 127; + pp[9] = float2int8(p0[5] * scale) + 127; + pp[10] = float2int8(p0[6] * scale) + 127; + pp[11] = float2int8(p0[7] * scale) + 127; + pp[12] = float2int8(p0[20] * scale) + 127; + pp[13] = float2int8(p0[21] * scale) + 127; + pp[14] = float2int8(p0[22] * scale) + 127; + pp[15] = float2int8(p0[23] * scale) + 127; + + pp[16 + 0] = float2int8(p0[8] * scale) + 127; + pp[16 + 1] = float2int8(p0[9] * scale) + 127; + pp[16 + 2] = float2int8(p0[10] * scale) + 127; + pp[16 + 3] = float2int8(p0[11] * scale) + 127; + pp[16 + 4] = float2int8(p0[24] * scale) + 127; + pp[16 + 5] = float2int8(p0[25] * scale) + 127; + pp[16 + 6] = float2int8(p0[26] * scale) + 127; + pp[16 + 7] = float2int8(p0[27] * scale) + 127; + + pp[16 + 8] = float2int8(p0[12] * scale) + 127; + pp[16 + 9] = float2int8(p0[13] * scale) + 127; + pp[16 + 10] = float2int8(p0[14] * scale) + 127; + pp[16 + 11] = float2int8(p0[15] * scale) + 127; + pp[16 + 12] = float2int8(p0[28] * scale) + 127; + pp[16 + 13] = float2int8(p0[29] * scale) + 127; + pp[16 + 14] = float2int8(p0[30] * scale) + 127; + pp[16 + 15] = float2int8(p0[31] * scale) + 127; +#else // __AVX512VNNI__ pp[0] = float2int8(p0[0] * scale); pp[1] = float2int8(p0[1] * scale); pp[2] = float2int8(p0[16] * scale); @@ -5050,7 +11453,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[16 + 13] = float2int8(p0[15] * scale); pp[16 + 14] = float2int8(p0[30] * scale); pp[16 + 15] = float2int8(p0[31] * scale); - +#endif // __AVX512VNNI__ pp += 32; p0 += B_hstep * 16; } @@ -5061,6 +11464,24 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int int kk = 0; for (; kk + 7 < max_kk; kk += 8) { +#if __AVX512VNNI__ + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[8] * scale) + 127; + pp[5] = float2int8(p0[9] * scale) + 127; + pp[6] = float2int8(p0[10] * scale) + 127; + pp[7] = float2int8(p0[11] * scale) + 127; + pp[8] = float2int8(p0[4] * scale) + 127; + pp[9] = float2int8(p0[5] * scale) + 127; + pp[10] = float2int8(p0[6] * scale) + 127; + pp[11] = float2int8(p0[7] * scale) + 127; + pp[12] = float2int8(p0[12] * scale) + 127; + pp[13] = float2int8(p0[13] * scale) + 127; + pp[14] = float2int8(p0[14] * scale) + 127; + pp[15] = float2int8(p0[15] * scale) + 127; +#else // __AVX512VNNI__ pp[0] = float2int8(p0[0] * scale); pp[1] = float2int8(p0[1] * scale); pp[2] = float2int8(p0[8] * scale); @@ -5077,7 +11498,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[13] = float2int8(p0[7] * scale); pp[14] = float2int8(p0[14] * scale); pp[15] = float2int8(p0[15] * scale); - +#endif // __AVX512VNNI__ pp += 16; p0 += B_hstep * 8; } @@ -5088,6 +11509,16 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int int kk = 0; for (; kk + 3 < max_kk; kk += 4) { +#if __AVX512VNNI__ + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[1] * scale) + 127; + pp[2] = float2int8(p0[2] * scale) + 127; + pp[3] = float2int8(p0[3] * scale) + 127; + pp[4] = float2int8(p0[4] * scale) + 127; + pp[5] = float2int8(p0[5] * scale) + 127; + pp[6] = float2int8(p0[6] * scale) + 127; + pp[7] = float2int8(p0[7] * scale) + 127; +#else // __AVX512VNNI__ pp[0] = float2int8(p0[0] * scale); pp[1] = float2int8(p0[1] * scale); pp[2] = float2int8(p0[4] * scale); @@ -5096,7 +11527,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[5] = float2int8(p0[3] * scale); pp[6] = float2int8(p0[6] * scale); pp[7] = float2int8(p0[7] * scale); - +#endif // __AVX512VNNI__ pp += 8; p0 += B_hstep * 4; } @@ -5106,6 +11537,21 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int { int kk = 0; #if __SSE2__ +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[B_hstep + 0] * scale) + 127; + pp[2] = float2int8(p0[B_hstep * 2 + 0] * scale) + 127; + pp[3] = float2int8(p0[B_hstep * 3 + 0] * scale) + 127; + pp[4] = float2int8(p0[1] * scale) + 127; + pp[5] = float2int8(p0[B_hstep + 1] * scale) + 127; + pp[6] = float2int8(p0[B_hstep * 2 + 1] * scale) + 127; + pp[7] = float2int8(p0[B_hstep * 3 + 1] * scale) + 127; + pp += 8; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale); @@ -5153,6 +11599,24 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[13] = float2int8(p0[13] * scale); pp[14] = float2int8(p0[14] * scale); pp[15] = float2int8(p0[15] * scale); +#if __AVX512VNNI__ + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; + pp[4] += 127; + pp[5] += 127; + pp[6] += 127; + pp[7] += 127; + pp[8] += 127; + pp[9] += 127; + pp[10] += 127; + pp[11] += 127; + pp[12] += 127; + pp[13] += 127; + pp[14] += 127; + pp[15] += 127; +#endif // __AVX512VNNI__ pp += 16; p0 += B_hstep * 16; } @@ -5171,6 +11635,16 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[5] = float2int8(p0[5] * scale); pp[6] = float2int8(p0[6] * scale); pp[7] = float2int8(p0[7] * scale); +#if __AVX512VNNI__ + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; + pp[4] += 127; + pp[5] += 127; + pp[6] += 127; + pp[7] += 127; +#endif // __AVX512VNNI__ pp += 8; p0 += B_hstep * 8; } @@ -5185,6 +11659,12 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp[1] = float2int8(p0[1] * scale); pp[2] = float2int8(p0[2] * scale); pp[3] = float2int8(p0[3] * scale); +#if __AVX512VNNI__ + pp[0] += 127; + pp[1] += 127; + pp[2] += 127; + pp[3] += 127; +#endif // __AVX512VNNI__ pp += 4; p0 += B_hstep * 4; } @@ -5193,6 +11673,17 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int if (elempack == 1) { int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale) + 127; + pp[1] = float2int8(p0[B_hstep] * scale) + 127; + pp[2] = float2int8(p0[B_hstep * 2] * scale) + 127; + pp[3] = float2int8(p0[B_hstep * 3] * scale) + 127; + pp += 4; + p0 += B_hstep * 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -10381,6 +16872,22 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) { +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + gemm_transB_packed_tile_int8_avx512vnni(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + gemm_transB_packed_tile_int8_avxvnni(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { @@ -10475,6 +16982,61 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pA3 = _mm512_shuffle_epi32(_pA2, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_i32x4(_pB0, _pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm512_dpbusd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm512_dpbusd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm512_dpbusd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm512_dpbusd_epi32(_sum7, _pB3, _pA1); + _sum8 = _mm512_dpbusd_epi32(_sum8, _pB0, _pA2); + _sum9 = _mm512_dpbusd_epi32(_sum9, _pB1, _pA2); + _suma = _mm512_dpbusd_epi32(_suma, _pB0, _pA3); + _sumb = _mm512_dpbusd_epi32(_sumb, _pB1, _pA3); + _sumc = _mm512_dpbusd_epi32(_sumc, _pB2, _pA2); + _sumd = _mm512_dpbusd_epi32(_sumd, _pB3, _pA2); + _sume = _mm512_dpbusd_epi32(_sume, _pB2, _pA3); + _sumf = _mm512_dpbusd_epi32(_sumf, _pB3, _pA3); + pA += 64; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + __m512i _w_shift2 = _mm512_shuffle_i32x4(_w_shift0, _w_shift0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _w_shift3 = _mm512_shuffle_epi32(_w_shift2, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + _sum4 = _mm512_sub_epi32(_sum4, _w_shift0); + _sum5 = _mm512_sub_epi32(_sum5, _w_shift0); + _sum6 = _mm512_sub_epi32(_sum6, _w_shift1); + _sum7 = _mm512_sub_epi32(_sum7, _w_shift1); + _sum8 = _mm512_sub_epi32(_sum8, _w_shift2); + _sum9 = _mm512_sub_epi32(_sum9, _w_shift2); + _suma = _mm512_sub_epi32(_suma, _w_shift3); + _sumb = _mm512_sub_epi32(_sumb, _w_shift3); + _sumc = _mm512_sub_epi32(_sumc, _w_shift2); + _sumd = _mm512_sub_epi32(_sumd, _w_shift2); + _sume = _mm512_sub_epi32(_sume, _w_shift3); + _sumf = _mm512_sub_epi32(_sumf, _w_shift3); + pA += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); @@ -10636,6 +17198,42 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm512_dpbusd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm512_dpbusd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm512_dpbusd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm512_dpbusd_epi32(_sum7, _pB3, _pA1); + pA += 64; + pB += 32; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + _sum4 = _mm512_sub_epi32(_sum4, _w_shift0); + _sum5 = _mm512_sub_epi32(_sum5, _w_shift0); + _sum6 = _mm512_sub_epi32(_sum6, _w_shift1); + _sum7 = _mm512_sub_epi32(_sum7, _w_shift1); + pA += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); @@ -10748,6 +17346,31 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pB)); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + pA += 64; + pB += 16; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + pA += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); @@ -10827,6 +17450,25 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 64; + pB += 8; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_loadu_si512((const __m512i*)pA); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift); + pA += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); @@ -10891,6 +17533,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_set1_epi32(((const int*)pB)[0]); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB, _pA); + pA += 64; + pB += 4; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_loadu_si512((const __m512i*)pA); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + pA += 64; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); @@ -10927,6 +17585,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } pAT += max_kk * 16; +#if __AVX512VNNI__ + if (max_kk >= 4) + { + pAT += 64; + } +#endif // __AVX512VNNI__ } #endif // __AVX512F__ for (; ii + 7 < max_ii; ii += 8) @@ -10973,6 +17637,43 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA00); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA00); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA11); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA11); + _sum4 = _mm512_dpbusd_epi32(_sum4, _pB2, _pA00); + _sum5 = _mm512_dpbusd_epi32(_sum5, _pB3, _pA00); + _sum6 = _mm512_dpbusd_epi32(_sum6, _pB2, _pA11); + _sum7 = _mm512_dpbusd_epi32(_sum7, _pB3, _pA11); + pA += 32; + pB += 64; + } + if (max_kk >= 4) + { + __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _w_shift00 = _mm512_inserti32x8(_mm512_castsi256_si512(_w_shift0), _w_shift0, 1); + __m512i _w_shift11 = _mm512_shuffle_epi32(_w_shift00, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift00); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift00); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift11); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift11); + _sum4 = _mm512_sub_epi32(_sum4, _w_shift00); + _sum5 = _mm512_sub_epi32(_sum5, _w_shift00); + _sum6 = _mm512_sub_epi32(_sum6, _w_shift11); + _sum7 = _mm512_sub_epi32(_sum7, _w_shift11); + pA += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadu_si128((const __m128i*)pA); @@ -11114,6 +17815,41 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(0, 1, 2, 3)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + _sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm256_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm256_dpbusd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm256_dpbusd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm256_dpbusd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm256_dpbusd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm256_dpbusd_epi32(_sum7, _pB3, _pA1); + pA += 32; + pB += 32; + } + if (max_kk >= 4) + { + __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _w_shift1 = _mm256_permute4x64_epi64(_w_shift0, _MM_SHUFFLE(0, 1, 2, 3)); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm256_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm256_sub_epi32(_sum3, _w_shift1); + _sum4 = _mm256_sub_epi32(_sum4, _w_shift0); + _sum5 = _mm256_sub_epi32(_sum5, _w_shift0); + _sum6 = _mm256_sub_epi32(_sum6, _w_shift1); + _sum7 = _mm256_sub_epi32(_sum7, _w_shift1); + pA += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadu_si128((const __m128i*)pA); @@ -11236,6 +17972,32 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm256_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm256_dpbusd_epi32(_sum3, _pB1, _pA1); + pA += 32; + pB += 16; + } + if (max_kk >= 4) + { + __m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _w_shift1 = _mm256_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm256_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm256_sub_epi32(_sum3, _w_shift1); + pA += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadu_si128((const __m128i*)pA); @@ -11318,6 +18080,25 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + _sum0 = _mm256_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm256_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 32; + pB += 8; + } + if (max_kk >= 4) + { + __m256i _w_shift = _mm256_loadu_si256((const __m256i*)pA); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift); + _sum1 = _mm256_sub_epi32(_sum1, _w_shift); + pA += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadu_si128((const __m128i*)pA); @@ -11384,6 +18165,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); + _sum0 = _mm256_dpbusd_epi32(_sum0, _pB, _pA); + pA += 32; + pB += 4; + } + if (max_kk >= 4) + { + __m256i _w_shift = _mm256_loadu_si256((const __m256i*)pA); + _sum0 = _mm256_sub_epi32(_sum0, _w_shift); + pA += 32; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadu_si128((const __m128i*)pA); @@ -11421,6 +18218,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } pAT += max_kk * 8; +#if __AVX512VNNI__ + if (max_kk >= 4) + { + pAT += 32; + } +#endif // __AVX512VNNI__ } #endif // __AVX2__ for (; ii + 3 < max_ii; ii += 4) @@ -11455,6 +18258,31 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pA)); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm512_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm512_dpbusd_epi32(_sum3, _pB1, _pA1); + pA += 16; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift0 = _mm512_broadcast_i32x4(_mm_loadu_si128((const __m128i*)pA)); + __m512i _w_shift1 = _mm512_shuffle_epi32(_w_shift0, _MM_PERM_BADC); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm512_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm512_sub_epi32(_sum3, _w_shift1); + pA += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); @@ -11554,6 +18382,41 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 16)); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1); + _sum4 = _mm_dpbusd_epi32(_sum4, _pB2, _pA0); + _sum5 = _mm_dpbusd_epi32(_sum5, _pB3, _pA0); + _sum6 = _mm_dpbusd_epi32(_sum6, _pB2, _pA1); + _sum7 = _mm_dpbusd_epi32(_sum7, _pB3, _pA1); + pA += 16; + pB += 32; + } + if (max_kk >= 4) + { + __m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _w_shift1 = _mm_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + _sum0 = _mm_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm_sub_epi32(_sum3, _w_shift1); + _sum4 = _mm_sub_epi32(_sum4, _w_shift0); + _sum5 = _mm_sub_epi32(_sum5, _w_shift0); + _sum6 = _mm_sub_epi32(_sum6, _w_shift1); + _sum7 = _mm_sub_epi32(_sum7, _w_shift1); + pA += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); @@ -11719,6 +18582,31 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1); + pA += 16; + pB += 16; + } + if (max_kk >= 4) + { + __m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _w_shift1 = _mm_shuffle_epi32(_w_shift0, _MM_SHUFFLE(1, 0, 3, 2)); + _sum0 = _mm_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm_sub_epi32(_sum3, _w_shift1); + pA += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); @@ -11839,6 +18727,25 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 16; + pB += 8; + } + if (max_kk >= 4) + { + __m128i _w_shift = _mm_loadu_si128((const __m128i*)pA); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + _sum1 = _mm_sub_epi32(_sum1, _w_shift); + pA += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); @@ -11933,6 +18840,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB, _pA); + pA += 16; + pB += 4; + } + if (max_kk >= 4) + { + __m128i _w_shift = _mm_loadu_si128((const __m128i*)pA); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + pA += 16; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); @@ -11988,6 +18911,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } pAT += max_kk * 4; +#if __AVX512VNNI__ + if (max_kk >= 4) + { + pAT += 16; + } +#endif // __AVX512VNNI__ } #endif // __SSE2__ for (; ii + 1 < max_ii; ii += 2) @@ -12016,6 +18945,25 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pA)[0])); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm512_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 8; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pA)[0])); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + _sum1 = _mm512_sub_epi32(_sum1, _w_shift); + pA += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pA)); @@ -12090,6 +19038,31 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA0 = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 16)); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA0); + _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA0); + _sum2 = _mm_dpbusd_epi32(_sum2, _pB0, _pA1); + _sum3 = _mm_dpbusd_epi32(_sum3, _pB1, _pA1); + pA += 8; + pB += 32; + } + if (max_kk >= 4) + { + __m128i _w_shift0 = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _w_shift1 = _mm_shuffle_epi32(_w_shift0, _MM_SHUFFLE(2, 3, 0, 1)); + _sum0 = _mm_sub_epi32(_sum0, _w_shift0); + _sum1 = _mm_sub_epi32(_sum1, _w_shift0); + _sum2 = _mm_sub_epi32(_sum2, _w_shift1); + _sum3 = _mm_sub_epi32(_sum3, _w_shift1); + pA += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); @@ -12192,6 +19165,25 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 8; + pB += 16; + } + if (max_kk >= 4) + { + __m128i _w_shift = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + _sum1 = _mm_sub_epi32(_sum1, _w_shift); + pA += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); @@ -12283,6 +19275,39 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + sum00 += pA[0] * ((unsigned char*)pB)[0]; + sum00 += pA[1] * ((unsigned char*)pB)[1]; + sum00 += pA[2] * ((unsigned char*)pB)[2]; + sum00 += pA[3] * ((unsigned char*)pB)[3]; + sum01 += pA[0] * ((unsigned char*)pB)[4]; + sum01 += pA[1] * ((unsigned char*)pB)[5]; + sum01 += pA[2] * ((unsigned char*)pB)[6]; + sum01 += pA[3] * ((unsigned char*)pB)[7]; + sum10 += pA[4] * ((unsigned char*)pB)[0]; + sum10 += pA[5] * ((unsigned char*)pB)[1]; + sum10 += pA[6] * ((unsigned char*)pB)[2]; + sum10 += pA[7] * ((unsigned char*)pB)[3]; + sum11 += pA[4] * ((unsigned char*)pB)[4]; + sum11 += pA[5] * ((unsigned char*)pB)[5]; + sum11 += pA[6] * ((unsigned char*)pB)[6]; + sum11 += pA[7] * ((unsigned char*)pB)[7]; + pA += 8; + pB += 8; + } + if (max_kk >= 4) + { + int w_shift0 = ((int*)pA)[0]; + int w_shift1 = ((int*)pA)[1]; + sum00 = sum00 - w_shift0; + sum01 = sum01 - w_shift0; + sum10 = sum10 - w_shift1; + sum11 = sum11 - w_shift1; + pA += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { sum00 += pA[0] * pB[0]; @@ -12332,6 +19357,29 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + sum0 += pA[0] * ((unsigned char*)pB)[0]; + sum0 += pA[1] * ((unsigned char*)pB)[1]; + sum0 += pA[2] * ((unsigned char*)pB)[2]; + sum0 += pA[3] * ((unsigned char*)pB)[3]; + sum1 += pA[4] * ((unsigned char*)pB)[0]; + sum1 += pA[5] * ((unsigned char*)pB)[1]; + sum1 += pA[6] * ((unsigned char*)pB)[2]; + sum1 += pA[7] * ((unsigned char*)pB)[3]; + pA += 8; + pB += 4; + } + if (max_kk >= 4) + { + int w_shift0 = ((int*)pA)[0]; + int w_shift1 = ((int*)pA)[1]; + sum0 = sum0 - w_shift0; + sum1 = sum1 - w_shift1; + pA += 8; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { sum0 += pA[0] * pB[0]; @@ -12356,6 +19404,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } pAT += max_kk * 2; +#if __AVX512VNNI__ + if (max_kk >= 4) + { + pAT += 8; + } +#endif // __AVX512VNNI__ } for (; ii < max_ii; ii += 1) { @@ -12380,6 +19434,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m512i _pA = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pB = _mm512_loadu_si512((const __m512i*)pB); + _sum0 = _mm512_dpbusd_epi32(_sum0, _pB, _pA); + pA += 4; + pB += 64; + } + if (max_kk >= 4) + { + __m512i _w_shift = _mm512_set1_epi32(((const int*)pA)[0]); + _sum0 = _mm512_sub_epi32(_sum0, _w_shift); + pA += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m256i _pA = _mm256_set1_epi16(((const short*)pA)[0]); @@ -12431,6 +19501,25 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 16)); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB0, _pA); + _sum1 = _mm_dpbusd_epi32(_sum1, _pB1, _pA); + pA += 4; + pB += 32; + } + if (max_kk >= 4) + { + __m128i _w_shift = _mm_set1_epi32(((const int*)pA)[0]); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + _sum1 = _mm_sub_epi32(_sum1, _w_shift); + pA += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_set1_epi16(((const short*)pA)[0]); @@ -12501,6 +19590,22 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + _sum0 = _mm_dpbusd_epi32(_sum0, _pB, _pA); + pA += 4; + pB += 16; + } + if (max_kk >= 4) + { + __m128i _w_shift = _mm_set1_epi32(((const int*)pA)[0]); + _sum0 = _mm_sub_epi32(_sum0, _w_shift); + pA += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); @@ -12570,6 +19675,28 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + sum0 += pA[0] * ((unsigned char*)pB)[0]; + sum0 += pA[1] * ((unsigned char*)pB)[1]; + sum0 += pA[2] * ((unsigned char*)pB)[2]; + sum0 += pA[3] * ((unsigned char*)pB)[3]; + sum1 += pA[0] * ((unsigned char*)pB)[4]; + sum1 += pA[1] * ((unsigned char*)pB)[5]; + sum1 += pA[2] * ((unsigned char*)pB)[6]; + sum1 += pA[3] * ((unsigned char*)pB)[7]; + pA += 4; + pB += 8; + } + if (max_kk >= 4) + { + int w_shift = ((const int*)pA)[0]; + sum0 = sum0 - w_shift; + sum1 = sum1 - w_shift; + pA += 4; + } +#endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { sum0 += pA[0] * pB[0]; @@ -12607,6 +19734,23 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, const signed char* pA = pAT; int kk = 0; +#if __AVX512VNNI__ + for (; kk + 3 < max_kk; kk += 4) + { + sum += pA[0] * ((unsigned char*)pB)[0]; + sum += pA[1] * ((unsigned char*)pB)[1]; + sum += pA[2] * ((unsigned char*)pB)[2]; + sum += pA[3] * ((unsigned char*)pB)[3]; + pA += 4; + pB += 4; + } + if (max_kk >= 4) + { + int w_shift = ((const int*)pA)[0]; + sum = sum - w_shift; + pA += 4; + } +#endif // __AVX512VNNI__ for (; kk < max_kk; kk += 1) { sum += pA[0] * pB[0]; @@ -12620,6 +19764,12 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, } pAT += max_kk; +#if __AVX512VNNI__ + if (max_kk >= 4) + { + pAT += 4; + } +#endif // __AVX512VNNI__ } } diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index d0727d717a80..9eb9671c795e 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -7638,7 +7638,19 @@ static int gemm_x86_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob int nn_N = (N + TILE_N - 1) / TILE_N; int nn_K = (K + TILE_K - 1) / TILE_K; - Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + Mat ATX; +#if NCNN_AVX512VNNI || NCNN_AVXVNNI + if (TILE_K >= 4 && (ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni())) + { + int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 2 : 1; + // NCNN_LOGE("w_shift_count = %d", w_shift_count); + ATX.create(TILE_K * TILE_M + w_shift_count * 4, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } + else +#endif // NCNN_AVX512VNNI || NCNN_AVXVNNI + { + ATX.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } if (ATX.empty()) return -100; Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 1u, opt.workspace_allocator); @@ -7891,7 +7903,18 @@ static int gemm_BT_x86_int8(const Mat& A, const Mat& BT, float B_int8_scale, con // NCNN_LOGE("scale %.4f %.4f", A_int8_scale, B_int8_scale); - Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + Mat ATX; +#if NCNN_AVX512VNNI || NCNN_AVXVNNI + if (TILE_K >= 4 && (ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni())) + { + int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 2 : 1; + ATX.create(TILE_K * TILE_M + w_shift_count * 4, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } + else +#endif // NCNN_AVX512VNNI || NCNN_AVXVNNI + { + ATX.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 1u, opt.workspace_allocator); + } if (ATX.empty()) return -100; @@ -8053,7 +8076,17 @@ int Gemm_x86::create_pipeline_int8(const Option& opt) const int nn_M = (M + TILE_M - 1) / TILE_M; - AT_data.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 1u, (Allocator*)0); +#if NCNN_AVX512VNNI || NCNN_AVXVNNI + if (TILE_K >= 4 && (ncnn::cpu_support_x86_avx512_vnni() || ncnn::cpu_support_x86_avx_vnni())) + { + int w_shift_count = TILE_M >= 16 ? 16 : TILE_M >= 8 ? 8 : TILE_M >= 4 ? 4 : TILE_M >= 2 ? 2 : 1; + AT_data.create(TILE_K * TILE_M + w_shift_count * 4, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 1u, (Allocator*)0); + } + else +#endif // NCNN_AVX512VNNI || NCNN_AVXVNNI + { + AT_data.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 1u, (Allocator*)0); + } if (AT_data.empty()) return -100; diff --git a/src/layer/x86/gemm_x86_avx512vnni.cpp b/src/layer/x86/gemm_x86_avx512vnni.cpp new file mode 100644 index 000000000000..fd72dd66d205 --- /dev/null +++ b/src/layer/x86/gemm_x86_avx512vnni.cpp @@ -0,0 +1,79 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +void pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_avx512vnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_avx512vnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void gemm_transB_packed_tile_int8_avx512vnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn diff --git a/src/layer/x86/gemm_x86_avxvnni.cpp b/src/layer/x86/gemm_x86_avxvnni.cpp new file mode 100644 index 000000000000..b738df821793 --- /dev/null +++ b/src/layer/x86/gemm_x86_avxvnni.cpp @@ -0,0 +1,76 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +void pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_avxvnni(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_avxvnni(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void gemm_transB_packed_tile_int8_avxvnni(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +} // namespace ncnn