diff --git a/src/layer/x86/convolution_3x3_int8.h b/src/layer/x86/convolution_3x3_int8.h index a5c5dfe4e71d..ceaf75b92e1f 100644 --- a/src/layer/x86/convolution_3x3_int8.h +++ b/src/layer/x86/convolution_3x3_int8.h @@ -78,833 +78,6 @@ static void conv3x3s1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& } } -static void conv3x3s1_winograd23_transform_kernel_int8_sse(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - kernel_tm.create(4 * 4, inch, outch, (size_t)2u); - - // G - const short ktm[4][3] = { - {2, 0, 0}, - {1, 1, 1}, - {1, -1, 1}, - {0, 0, 2} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[4][3]; - for (int i = 0; i < 4; i++) - { - tmp[i][0] = (short)k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = (short)k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = (short)k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 4; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 4; i++) - { - kernel_tm0[j * 4 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } -} - -static void conv3x3s1_winograd23_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 2n+2, winograd F(2,3) - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 1) / 2 * 2; - outh = (outh + 1) / 2 * 2; - - w = outw + 2; - h = outh + 2; - Option opt_b = opt; - opt_b.blob_allocator = opt.workspace_allocator; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - const int tiles = nColBlocks * nRowBlocks; - - bottom_blob_tm.create(4 * 4, tiles, inch, 2u, opt.workspace_allocator); - - // BT - // const float itm[4][4] = { - // {1.0f, 0.0f, -1.0f, 0.0f}, - // {0.0f, 1.0f, 1.00f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 0.0f}, - // {0.0f, -1.0f, 0.00f, 1.0f} - // }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const signed char* img = bottom_blob_bordered.channel(q); - short* out_tm0 = bottom_blob_tm.channel(q); - - for (int j = 0; j < nColBlocks; j++) - { - const signed char* r0 = img + w * j * 2; - const signed char* r1 = r0 + w; - const signed char* r2 = r1 + w; - const signed char* r3 = r2 + w; - - for (int i = 0; i < nRowBlocks; i++) - { - short d0[4], d1[4], d2[4], d3[4]; - short w0[4], w1[4], w2[4], w3[4]; - short t0[4], t1[4], t2[4], t3[4]; - // load - for (int n = 0; n < 4; n++) - { - d0[n] = r0[n]; - d1[n] = r1[n]; - d2[n] = r2[n]; - d3[n] = r3[n]; - } - // w = B_t * d - for (int n = 0; n < 4; n++) - { - w0[n] = d0[n] - d2[n]; - w1[n] = d1[n] + d2[n]; - w2[n] = d2[n] - d1[n]; - w3[n] = d3[n] - d1[n]; - } - // transpose d to d_t - { - t0[0] = w0[0]; - t1[0] = 
w0[1]; - t2[0] = w0[2]; - t3[0] = w0[3]; - t0[1] = w1[0]; - t1[1] = w1[1]; - t2[1] = w1[2]; - t3[1] = w1[3]; - t0[2] = w2[0]; - t1[2] = w2[1]; - t2[2] = w2[2]; - t3[2] = w2[3]; - t0[3] = w3[0]; - t1[3] = w3[1]; - t2[3] = w3[2]; - t3[3] = w3[3]; - } - // U = B_t * d_t - for (int n = 0; n < 4; n++) - { - d0[n] = t0[n] - t2[n]; - d1[n] = t1[n] + t2[n]; - d2[n] = t2[n] - t1[n]; - d3[n] = t3[n] - t1[n]; - } - // save to out_tm - for (int n = 0; n < 4; n++) - { - out_tm0[n] = d0[n]; - out_tm0[n + 4] = d1[n]; - out_tm0[n + 8] = d2[n]; - out_tm0[n + 12] = d3[n]; - } - - r0 += 2; - r1 += 2; - r2 += 2; - r3 += 2; - - out_tm0 += 16; - } - } - } - } - bottom_blob_bordered = Mat(); - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - const int tiles = nColBlocks * nRowBlocks; - - top_blob_tm.create(16, tiles, outch, 4u, opt.workspace_allocator); - - int nn_outch = outch >> 2; - int remain_outch_start = nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 4; - - Mat out0_tm = top_blob_tm.channel(p); - Mat out1_tm = top_blob_tm.channel(p + 1); - Mat out2_tm = top_blob_tm.channel(p + 2); - Mat out3_tm = top_blob_tm.channel(p + 3); - - const Mat kernel0_tm = kernel_tm.channel(p); - const Mat kernel1_tm = kernel_tm.channel(p + 1); - const Mat kernel2_tm = kernel_tm.channel(p + 2); - const Mat kernel3_tm = kernel_tm.channel(p + 3); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - int* output1_tm = out1_tm.row(i); - int* output2_tm = out2_tm.row(i); - int* output3_tm = out3_tm.row(i); - - int sum0[16] = {0}; - int sum1[16] = {0}; - int sum2[16] = {0}; - int sum3[16] = {0}; - - int q = 0; - for (; q + 3 < inch; q += 4) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* r1 = bottom_blob_tm.channel(q + 1).row(i); - const short* r2 = bottom_blob_tm.channel(q + 2).row(i); - const short* r3 = bottom_blob_tm.channel(q + 3).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel1_tm.row(q); - const short* k2 = kernel2_tm.row(q); - const short* k3 = kernel3_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r1[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r2[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r3[n] * k0[n]; - k0 -= 16 * 3; - - sum1[n] += (int)r0[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r1[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r2[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r3[n] * k1[n]; - k1 -= 16 * 3; - - sum2[n] += (int)r0[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r1[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r2[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r3[n] * k2[n]; - k2 -= 16 * 3; - - sum3[n] += (int)r0[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r1[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r2[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r3[n] * k3[n]; - k3 -= 16 * 3; - } - } - - for (; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel1_tm.row(q); - const short* k2 = kernel2_tm.row(q); - const short* k3 = kernel3_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - sum1[n] += (int)r0[n] * k1[n]; - sum2[n] += (int)r0[n] * k2[n]; - sum3[n] += (int)r0[n] * k3[n]; - } - } - - for (int n = 0; n < 16; n++) - { - output0_tm[n] = sum0[n]; - output1_tm[n] = 
sum1[n]; - output2_tm[n] = sum2[n]; - output3_tm[n] = sum3[n]; - } - } - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - Mat out0_tm = top_blob_tm.channel(p); - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - - int sum0[16] = {0}; - - int q = 0; - for (; q + 3 < inch; q += 4) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* r1 = bottom_blob_tm.channel(q + 1).row(i); - const short* r2 = bottom_blob_tm.channel(q + 2).row(i); - const short* r3 = bottom_blob_tm.channel(q + 3).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel0_tm.row(q + 1); - const short* k2 = kernel0_tm.row(q + 2); - const short* k3 = kernel0_tm.row(q + 3); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - sum0[n] += (int)r1[n] * k1[n]; - sum0[n] += (int)r2[n] * k2[n]; - sum0[n] += (int)r3[n] * k3[n]; - } - } - - for (; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* k0 = kernel0_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - } - } - - for (int n = 0; n < 16; n++) - { - output0_tm[n] = sum0[n]; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator); - { - // AT - // const float itm[2][4] = { - // {1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 1.0f} - // }; - - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out_tm = top_blob_tm.channel(p); - Mat out = top_blob_bordered.channel(p); - - for (int j = 0; j < nColBlocks; j++) - { - int* outRow0 = out.row(j * 2); - int* outRow1 = out.row(j * 2 + 1); - - for (int i = 0; i < nRowBlocks; i++) - { - int* out_tile = out_tm.row(j * nRowBlocks + i); - - int s0[4], s1[4], s2[4], s3[4]; - int w0[4], w1[4]; - int d0[2], d1[2], d2[2], d3[2]; - int o0[2], o1[2]; - // load - for (int n = 0; n < 4; n++) - { - s0[n] = out_tile[n]; - s1[n] = out_tile[n + 4]; - s2[n] = out_tile[n + 8]; - s3[n] = out_tile[n + 12]; - } - // w = A_T * W - for (int n = 0; n < 4; n++) - { - w0[n] = s0[n] + s1[n] + s2[n]; - w1[n] = s1[n] - s2[n] + s3[n]; - } - // transpose w to w_t - { - d0[0] = w0[0]; - d0[1] = w1[0]; - d1[0] = w0[1]; - d1[1] = w1[1]; - d2[0] = w0[2]; - d2[1] = w1[2]; - d3[0] = w0[3]; - d3[1] = w1[3]; - } - // Y = A_T * w_t - for (int n = 0; n < 2; n++) - { - o0[n] = d0[n] + d1[n] + d2[n]; - o1[n] = d1[n] - d2[n] + d3[n]; - } - // save to top blob tm,why right 2,because the G' = G*2 - outRow0[0] = o0[0] >> 2; - outRow0[1] = o0[1] >> 2; - outRow1[0] = o1[0] >> 2; - outRow1[1] = o1[1] >> 2; - - outRow0 += 2; - outRow1 += 2; - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - -static void conv3x3s1_winograd43_transform_kernel_int8_sse(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - kernel_tm.create(6 * 6, inch, outch, (size_t)2u); - - // G - // const float ktm[6][3] = { - // { 1.0f/4, 0.0f, 0.0f}, - // { -1.0f/6, -1.0f/6, -1.0f/6}, - // { -1.0f/6, 1.0f/6, -1.0f/6}, - // { 1.0f/24, 1.0f/12, 1.0f/6}, - // { 1.0f/24, 
-1.0f/12, 1.0f/6}, - // { 0.0f, 0.0f, 1.0f} - // }; - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 24} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } -} - -static void conv3x3s1_winograd43_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2, winograd F(4,3) - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - Option opt_b = opt; - opt_b.blob_allocator = opt.workspace_allocator; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - const int tiles = nColBlocks * nRowBlocks; - - bottom_blob_tm.create(6 * 6, tiles, inch, 2u, opt.workspace_allocator); - - // BT - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r03 + r04 - // 2 = 4 * (r01 - r02) - r03 + r04 - // 3 = -2 * r01 - r02 + 2 * r03 + r04 - // 4 = 2 * r01 - r02 - 2 * r03 + r04 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const signed char* img = bottom_blob_bordered.channel(q); - short* out_tm0 = bottom_blob_tm.channel(q); - - for (int j = 0; j < nColBlocks; j++) - { - const signed char* r0 = img + w * j * 4; - const signed char* r1 = r0 + w; - const signed char* r2 = r1 + w; - const signed char* r3 = r2 + w; - const signed char* r4 = r3 + w; - const signed char* r5 = r4 + w; - - for (int i = 0; i < nRowBlocks; i++) - { - short d0[6], d1[6], d2[6], d3[6], d4[6], d5[6]; - short w0[6], w1[6], w2[6], w3[6], w4[6], w5[6]; - short t0[6], t1[6], t2[6], t3[6], t4[6], t5[6]; - - // load - for (int n = 0; n < 6; n++) - { - d0[n] = r0[n]; - d1[n] = r1[n]; - d2[n] = r2[n]; - d3[n] = r3[n]; - d4[n] = r4[n]; - d5[n] = r5[n]; - } - // w = B_t * d - for (int n = 0; n < 6; n++) - { - w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n]; - w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n]; - w2[n] = 4 * d1[n] - 4 * 
d2[n] - d3[n] + d4[n]; - w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n]; - w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n]; - w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n]; - } - // transpose d to d_t - { - t0[0] = w0[0]; - t1[0] = w0[1]; - t2[0] = w0[2]; - t3[0] = w0[3]; - t4[0] = w0[4]; - t5[0] = w0[5]; - t0[1] = w1[0]; - t1[1] = w1[1]; - t2[1] = w1[2]; - t3[1] = w1[3]; - t4[1] = w1[4]; - t5[1] = w1[5]; - t0[2] = w2[0]; - t1[2] = w2[1]; - t2[2] = w2[2]; - t3[2] = w2[3]; - t4[2] = w2[4]; - t5[2] = w2[5]; - t0[3] = w3[0]; - t1[3] = w3[1]; - t2[3] = w3[2]; - t3[3] = w3[3]; - t4[3] = w3[4]; - t5[3] = w3[5]; - t0[4] = w4[0]; - t1[4] = w4[1]; - t2[4] = w4[2]; - t3[4] = w4[3]; - t4[4] = w4[4]; - t5[4] = w4[5]; - t0[5] = w5[0]; - t1[5] = w5[1]; - t2[5] = w5[2]; - t3[5] = w5[3]; - t4[5] = w5[4]; - t5[5] = w5[5]; - } - // d = B_t * d_t - for (int n = 0; n < 6; n++) - { - d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n]; - d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n]; - d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n]; - d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n]; - d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n]; - d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n]; - } - // save to out_tm - for (int n = 0; n < 6; n++) - { - out_tm0[n] = d0[n]; - out_tm0[n + 6] = d1[n]; - out_tm0[n + 12] = d2[n]; - out_tm0[n + 18] = d3[n]; - out_tm0[n + 24] = d4[n]; - out_tm0[n + 30] = d5[n]; - } - - r0 += 4; - r1 += 4; - r2 += 4; - r3 += 4; - r4 += 4; - r5 += 4; - - out_tm0 += 36; - } - } - } - } - bottom_blob_bordered = Mat(); - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - const int tiles = nColBlocks * nRowBlocks; - - top_blob_tm.create(36, tiles, outch, 4u, opt.workspace_allocator); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out0_tm = top_blob_tm.channel(p); - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - - int sum0[36] = {0}; - - for (int q = 0; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* k0 = kernel0_tm.row(q); - - for (int n = 0; n < 36; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - } - } - - for (int n = 0; n < 36; n++) - { - output0_tm[n] = sum0[n]; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator); - { - // AT - // const float itm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + r01 + r02 + r03 + r04 - // 1 = r01 - r02 + 2 * (r03 - r04) - // 2 = r01 + r02 + 4 * (r03 + r04) - // 3 = r01 - r02 + 8 * (r03 - r04) + r05 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out_tm = top_blob_tm.channel(p); - Mat out = top_blob_bordered.channel(p); - - for (int j = 0; j < nColBlocks; j++) - { - int* outRow0 = out.row(j * 4); - int* outRow1 = out.row(j * 4 + 1); - int* outRow2 = out.row(j * 4 + 2); - int* outRow3 = out.row(j * 4 + 3); - - for (int i = 0; i < nRowBlocks; i++) - { - int* out_tile = out_tm.row(j * nRowBlocks + i); - - int 
s0[6], s1[6], s2[6], s3[6], s4[6], s5[6]; - int w0[6], w1[6], w2[6], w3[6]; - int d0[4], d1[4], d2[4], d3[4], d4[4], d5[4]; - int o0[4], o1[4], o2[4], o3[4]; - // load - for (int n = 0; n < 6; n++) - { - s0[n] = out_tile[n]; - s1[n] = out_tile[n + 6]; - s2[n] = out_tile[n + 12]; - s3[n] = out_tile[n + 18]; - s4[n] = out_tile[n + 24]; - s5[n] = out_tile[n + 30]; - } - // w = A_T * W - for (int n = 0; n < 6; n++) - { - w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n]; - w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n]; - w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n]; - w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + s5[n]; - } - // transpose w to w_t - { - d0[0] = w0[0]; - d0[1] = w1[0]; - d0[2] = w2[0]; - d0[3] = w3[0]; - d1[0] = w0[1]; - d1[1] = w1[1]; - d1[2] = w2[1]; - d1[3] = w3[1]; - d2[0] = w0[2]; - d2[1] = w1[2]; - d2[2] = w2[2]; - d2[3] = w3[2]; - d3[0] = w0[3]; - d3[1] = w1[3]; - d3[2] = w2[3]; - d3[3] = w3[3]; - d4[0] = w0[4]; - d4[1] = w1[4]; - d4[2] = w2[4]; - d4[3] = w3[4]; - d5[0] = w0[5]; - d5[1] = w1[5]; - d5[2] = w2[5]; - d5[3] = w3[5]; - } - // Y = A_T * w_t - for (int n = 0; n < 4; n++) - { - o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n]; - o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n]; - o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n]; - o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n]; - } - // save to top blob tm - for (int n = 0; n < 4; n++) - { - outRow0[n] = o0[n] / 576; - outRow1[n] = o1[n] / 576; - outRow2[n] = o2[n] / 576; - outRow3[n] = o3[n] / 576; - } - - outRow0 += 4; - outRow1 += 4; - outRow2 += 4; - outRow3 += 4; - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - static void conv3x3s2_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt) { int w = bottom_blob.w; diff --git a/src/layer/x86/convolution_3x3_pack8to1_int8.h b/src/layer/x86/convolution_3x3_pack8to1_int8.h deleted file mode 100644 index d5957faf6d89..000000000000 --- a/src/layer/x86/convolution_3x3_pack8to1_int8.h +++ /dev/null @@ -1,1125 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif -#endif - -static void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif -#endif - - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * 
ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8to1.create(8 * inch / 8, 36, outch / 4 + outch % 4, (size_t)2u * 4, 4); - - int p = 0; - for (; p + 3 < outch; p += 4) - { - Mat g0 = kernel_tm_pack8to1.channel(p / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 4; i++) - { - for (int j = 0; j < 8; j++) - { - const short* k00 = kernel_tm.channel(p + i).row(q + j); - g00[0] = k00[k]; - g00 += 1; - } - } - } - } - } - for (; p < outch; p++) - { - const Mat k0 = kernel_tm.channel(p); - - Mat g0 = kernel_tm_pack8to1.channel(p / 4 + p % 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 8; i++) - { - g00[0] = k0.row(q + i)[k]; - - g00 += 1; - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avx2(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_xop(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif -#endif - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for 
num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - // TODO use _mm_cvtepi8_epi16 on sse4.1 - __m128i _r00_01 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r02_03 = _mm_loadu_si128((const __m128i*)(r0 + 16)); - __m128i _r04_05 = _mm_loadu_si128((const __m128i*)(r0 + 32)); - __m128i _extr0001 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r00_01); - __m128i _extr0203 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r02_03); - __m128i _extr0405 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r04_05); - __m128i _r00 = _mm_unpacklo_epi8(_r00_01, _extr0001); - __m128i _r01 = _mm_unpackhi_epi8(_r00_01, _extr0001); - __m128i _r02 = _mm_unpacklo_epi8(_r02_03, _extr0203); - __m128i _r03 = _mm_unpackhi_epi8(_r02_03, _extr0203); - __m128i _r04 = _mm_unpacklo_epi8(_r04_05, _extr0405); - __m128i _r05 = _mm_unpackhi_epi8(_r04_05, _extr0405); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _tmp0m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r00, 2), _r04), _mm_mullo_epi16(_r02, _v5)); - __m128i _tmp1m = _mm_sub_epi16(_mm_add_epi16(_r04, _r03), _mm_slli_epi16(_mm_add_epi16(_r01, _r02), 2)); - __m128i _tmp2m = _mm_add_epi16(_mm_sub_epi16(_r04, _r03), _mm_slli_epi16(_mm_sub_epi16(_r01, _r02), 2)); - __m128i _tmp3m = _mm_sub_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp4m = _mm_add_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp5m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r01, 2), _r05), _mm_mullo_epi16(_r03, _v5)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4m); - _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tm / 6 + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _r0tm0 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp00, 2), _tmp04), _mm_mullo_epi16(_tmp02, _v5)); - __m128i _r0tm1 = _mm_sub_epi16(_mm_add_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_add_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm2 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm3 = _mm_sub_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm4 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm5 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp01, 2), _tmp05), _mm_mullo_epi16(_tmp03, _v5)); - - 
_mm_storeu_si128((__m128i*)r0_tm_0, _r0tm0); - _mm_storeu_si128((__m128i*)r0_tm_1, _r0tm1); - _mm_storeu_si128((__m128i*)r0_tm_2, _r0tm2); - _mm_storeu_si128((__m128i*)r0_tm_3, _r0tm3); - _mm_storeu_si128((__m128i*)r0_tm_4, _r0tm4); - _mm_storeu_si128((__m128i*)r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = h_tm / 6 * w_tm / 6; - - // permute - // bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); - Mat bottom_blob_tm2; -#if __AVX2__ - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#else - if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < 36; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - short* tmpptr = tm2.row(i / 4); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - _mm256_storeu_si256((__m256i*)tmpptr, _r0); - _mm256_storeu_si256((__m256i*)(tmpptr + 16), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 32; - } - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#else - short* tmpptr = tm2.row(i / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - _mm_storeu_si128((__m128i*)(tmpptr + 8), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 16; - } - } - for (; i < tiles; i++) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, 36, outch, 4u, 1, opt.workspace_allocator); - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 4; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - - const 
Mat kernel0_tm = kernel_tm.channel(p / 4); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); - - __m256i _sum04_15 = _mm256_setzero_si256(); - __m256i _sum14_05 = _mm256_setzero_si256(); - __m256i _sum06_17 = _mm256_setzero_si256(); - __m256i _sum16_07 = _mm256_setzero_si256(); - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif - - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m256i _val32 = _mm256_permute4x64_epi64(_val23, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum04_15 = _mm256_dpwssd_epi32(_sum04_15, _val23, _w01); - _sum14_05 = _mm256_dpwssd_epi32(_sum14_05, _val32, _w01); - _sum06_17 = _mm256_dpwssd_epi32(_sum06_17, _val23, _w23); - _sum16_07 = _mm256_dpwssd_epi32(_sum16_07, _val32, _w23); -#else - _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_madd_epi16(_val23, _w01)); - _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_madd_epi16(_val32, _w01)); - _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_madd_epi16(_val23, _w23)); - _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_madd_epi16(_val32, _w23)); -#endif - - r0 += 32; - k0 += 32; - } - - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); - _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); - _tmp2 = _mm256_unpackhi_epi32(_sum04_15, _sum14_05); - _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); - _sum06_17 = 
_mm256_add_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); - - int sum[16]; - _mm256_storeu_si256((__m256i*)sum, _sum00_11); - _mm256_storeu_si256((__m256i*)(sum + 8), _sum04_15); - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm[1] = sum[4]; - output1_tm[1] = sum[5]; - output2_tm[1] = sum[6]; - output3_tm[1] = sum[7]; - output0_tm[2] = sum[8]; - output1_tm[2] = sum[9]; - output2_tm[2] = sum[10]; - output3_tm[2] = sum[11]; - output0_tm[3] = sum[12]; - output1_tm[3] = sum[13]; - output2_tm[3] = sum[14]; - output3_tm[3] = sum[15]; - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); -#else - __m128i _sum00 = _mm_setzero_si128(); - __m128i _sum01 = _mm_setzero_si128(); - __m128i _sum02 = _mm_setzero_si128(); - __m128i _sum03 = _mm_setzero_si128(); - __m128i _sum10 = _mm_setzero_si128(); - __m128i _sum11 = _mm_setzero_si128(); - __m128i _sum12 = _mm_setzero_si128(); - __m128i _sum13 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { -#if __AVX2__ - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif -#else - // 0 1 2 3 4 5 6 7 - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum00 = _mm_maddd_epi16(_val0, _w0, _sum00); - _sum01 = _mm_maddd_epi16(_val0, _w1, _sum01); - _sum02 = _mm_maddd_epi16(_val0, _w2, _sum02); - _sum03 = _mm_maddd_epi16(_val0, _w3, _sum03); - _sum10 = _mm_maddd_epi16(_val1, _w0, _sum10); - _sum11 = _mm_maddd_epi16(_val1, _w1, _sum11); - _sum12 = _mm_maddd_epi16(_val1, _w2, _sum12); - _sum13 = _mm_maddd_epi16(_val1, _w3, _sum13); -#else - _sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00); - _sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01); - _sum02 = 
_mm_add_epi32(_mm_madd_epi16(_val0, _w2), _sum02); - _sum03 = _mm_add_epi32(_mm_madd_epi16(_val0, _w3), _sum03); - _sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10); - _sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11); - _sum12 = _mm_add_epi32(_mm_madd_epi16(_val1, _w2), _sum12); - _sum13 = _mm_add_epi32(_mm_madd_epi16(_val1, _w3), _sum13); -#endif -#endif - - r0 += 16; - k0 += 32; - } - -#if __AVX2__ - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - - int sum[8]; - _mm256_storeu_si256((__m256i*)sum, _sum00_11); -#else - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); - _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); - _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); - _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); - _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); - _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); - _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); - _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); - _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00 = _mm_add_epi32(_sum00, _sum01); - _sum02 = _mm_add_epi32(_sum02, _sum03); - _sum10 = _mm_add_epi32(_sum10, _sum11); - _sum12 = _mm_add_epi32(_sum12, _sum13); - - _sum00 = _mm_add_epi32(_sum00, _sum02); - _sum10 = _mm_add_epi32(_sum10, _sum12); - - int sum[8]; - _mm_storeu_si128((__m128i*)sum, _sum00); - _mm_storeu_si128((__m128i*)(sum + 4), _sum10); -#endif - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm[1] = sum[4]; - output1_tm[1] = sum[5]; - output2_tm[1] = sum[6]; - output3_tm[1] = sum[7]; - output0_tm += 2; - output1_tm += 2; - output2_tm += 2; - output3_tm += 2; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum0_1 = _mm256_setzero_si256(); - __m256i _sum2_3 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); - __m128i _sum2 = _mm_setzero_si128(); - __m128i _sum3 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 - __m128i _val = _mm_loadu_si128((const __m128i*)r0); - -#if __AVX2__ - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i 
_valval = _mm256_inserti128_si256(_mm256_castsi128_si256(_val), _val, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0_1 = _mm256_dpwssd_epi32(_sum0_1, _valval, _w01); - _sum2_3 = _mm256_dpwssd_epi32(_sum2_3, _valval, _w23); -#else - _sum0_1 = _mm256_add_epi32(_sum0_1, _mm256_madd_epi16(_valval, _w01)); - _sum2_3 = _mm256_add_epi32(_sum2_3, _mm256_madd_epi16(_valval, _w23)); -#endif -#else - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val, _w1, _sum1); - _sum2 = _mm_maddd_epi16(_val, _w2, _sum2); - _sum3 = _mm_maddd_epi16(_val, _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val, _w1), _sum1); - _sum2 = _mm_add_epi32(_mm_madd_epi16(_val, _w2), _sum2); - _sum3 = _mm_add_epi32(_mm_madd_epi16(_val, _w3), _sum3); -#endif -#endif - - r0 += 8; - k0 += 32; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum0_1, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum0_1, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum2_3, 0); - __m128i _sum3 = _mm256_extracti128_si256(_sum2_3, 1); -#endif - - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); - _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); - _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); - _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); - _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum0 = _mm_add_epi32(_sum0, _sum1); - _sum2 = _mm_add_epi32(_sum2, _sum3); - - _sum0 = _mm_add_epi32(_sum0, _sum2); - - int sum[4]; - _mm_storeu_si128((__m128i*)sum, _sum0); - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - } - } - } - - remain_outch_start += nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p / 4 + p % 4); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - __m256i _sum01 = _mm256_setzero_si256(); - __m256i _sum23 = _mm256_setzero_si256(); - - for (int q = 0; q < inch; q++) - { - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m256i _w01 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w0), _w0, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum01 = _mm256_dpwssd_epi32(_sum01, _val01, _w01); - _sum23 = _mm256_dpwssd_epi32(_sum23, _val23, _w01); -#else - _sum01 = _mm256_add_epi32(_sum01, _mm256_madd_epi16(_val01, _w01)); - _sum23 = _mm256_add_epi32(_sum23, _mm256_madd_epi16(_val23, _w01)); -#endif - - k0 += 8; - r0 += 32; - } - - __m128i _sum0 = _mm256_extracti128_si256(_sum01, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum01, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum23, 0); - __m128i _sum3 = 
_mm256_extracti128_si256(_sum23, 1); - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm[1] = _mm_reduce_add_epi32(_sum1); - output0_tm[2] = _mm_reduce_add_epi32(_sum2); - output0_tm[3] = _mm_reduce_add_epi32(_sum3); - output0_tm += 4; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - -#if __AVX2__ - __m256i _sum01 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); -#endif - - for (int q = 0; q < inch; q++) - { -#if __AVX2__ - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m256i _w01 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w0), _w0, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum01 = _mm256_dpwssd_epi32(_sum01, _val01, _w01); -#else - _sum01 = _mm256_add_epi32(_sum01, _mm256_madd_epi16(_val01, _w01)); -#endif -#else - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val0, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val1, _w0, _sum1); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum1); -#endif -#endif - - k0 += 8; - r0 += 16; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum01, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum01, 1); -#endif - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm[1] = _mm_reduce_add_epi32(_sum1); - output0_tm += 2; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - __m128i _sum0 = _mm_setzero_si128(); - - for (int q = 0; q < inch; q++) - { - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val0, _w0, _sum0); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum0); -#endif - - k0 += 8; - r0 += 8; - } - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm++; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); - } - { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - int tmp[4][6]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, 4u, 1, opt.workspace_allocator); - - const int* output0_tm_0 = (const int*)out0_tm + (i * 
w_tm / 6 + j) * 1; - const int* output0_tm_1 = output0_tm_0 + tiles * 1; - const int* output0_tm_2 = output0_tm_0 + tiles * 2; - const int* output0_tm_3 = output0_tm_0 + tiles * 3; - const int* output0_tm_4 = output0_tm_0 + tiles * 4; - const int* output0_tm_5 = output0_tm_0 + tiles * 5; - - int* output0 = out0.row(i * 4) + j * 4; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - // TODO sse optimize - for (int m = 0; m < 5; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = output0_tm_0[0] + tmp02a + tmp02b; - tmp[1][m] = tmp13a + tmp13b * 2; - tmp[2][m] = tmp02a + tmp02b * 4; - tmp[3][m] = output0_tm_5[0] * 4 + tmp13a + tmp13b * 8; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - for (int m = 5; m < 6; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = (output0_tm_0[0] + tmp02a + tmp02b) * 4; - tmp[1][m] = (tmp13a + tmp13b * 2) * 4; - tmp[2][m] = (tmp02a + tmp02b * 4) * 4; - tmp[3][m] = (output0_tm_5[0] * 4 + tmp13a + tmp13b * 8) * 4; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - - for (int m = 0; m < 4; m++) - { - const int* tmp0 = tmp[m]; - - int tmp02a = tmp0[1] + tmp0[2]; - int tmp13a = tmp0[1] - tmp0[2]; - - int tmp02b = tmp0[3] + tmp0[4]; - int tmp13b = tmp0[3] - tmp0[4]; - - output0[0] = (tmp0[0] + tmp02a + tmp02b) / 576; - output0[1] = (tmp13a + tmp13b * 2) / 576; - output0[2] = (tmp02a + tmp02b * 4) / 576; - output0[3] = (tmp0[5] + tmp13a + tmp13b * 8) / 576; - - output0 += outw; - } - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/x86/convolution_3x3_pack8to4_int8.h b/src/layer/x86/convolution_3x3_pack8to4_int8.h deleted file mode 100644 index 2bb48ce1903a..000000000000 --- a/src/layer/x86/convolution_3x3_pack8to4_int8.h +++ /dev/null @@ -1,945 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif -#endif - -static void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif -#endif - - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * 
ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8.create(inch / 8, 36, outch / 4, (size_t)2u * 32, 32); - - int q = 0; - for (; q + 3 < outch; q += 4) - { - Mat g0 = kernel_tm_pack8.channel(q / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int p = 0; p + 7 < inch; p += 8) - { - for (int i = 0; i < 4; i++) - { - for (int j = 0; j < 8; j++) - { - const short* k00 = kernel_tm.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00 += 1; - } - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avx2(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_xop(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif -#endif - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - // TODO use 
_mm_cvtepi8_epi16 on sse4.1 - __m128i _r00_01 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r02_03 = _mm_loadu_si128((const __m128i*)(r0 + 16)); - __m128i _r04_05 = _mm_loadu_si128((const __m128i*)(r0 + 32)); - __m128i _extr0001 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r00_01); - __m128i _extr0203 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r02_03); - __m128i _extr0405 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r04_05); - __m128i _r00 = _mm_unpacklo_epi8(_r00_01, _extr0001); - __m128i _r01 = _mm_unpackhi_epi8(_r00_01, _extr0001); - __m128i _r02 = _mm_unpacklo_epi8(_r02_03, _extr0203); - __m128i _r03 = _mm_unpackhi_epi8(_r02_03, _extr0203); - __m128i _r04 = _mm_unpacklo_epi8(_r04_05, _extr0405); - __m128i _r05 = _mm_unpackhi_epi8(_r04_05, _extr0405); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _tmp0m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r00, 2), _r04), _mm_mullo_epi16(_r02, _v5)); - __m128i _tmp1m = _mm_sub_epi16(_mm_add_epi16(_r04, _r03), _mm_slli_epi16(_mm_add_epi16(_r01, _r02), 2)); - __m128i _tmp2m = _mm_add_epi16(_mm_sub_epi16(_r04, _r03), _mm_slli_epi16(_mm_sub_epi16(_r01, _r02), 2)); - __m128i _tmp3m = _mm_sub_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp4m = _mm_add_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp5m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r01, 2), _r05), _mm_mullo_epi16(_r03, _v5)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4m); - _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tm / 6 + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _r0tm0 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp00, 2), _tmp04), _mm_mullo_epi16(_tmp02, _v5)); - __m128i _r0tm1 = _mm_sub_epi16(_mm_add_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_add_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm2 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm3 = _mm_sub_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm4 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm5 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp01, 2), _tmp05), _mm_mullo_epi16(_tmp03, _v5)); - - _mm_storeu_si128((__m128i*)r0_tm_0, _r0tm0); - _mm_storeu_si128((__m128i*)r0_tm_1, _r0tm1); - _mm_storeu_si128((__m128i*)r0_tm_2, _r0tm2); - _mm_storeu_si128((__m128i*)r0_tm_3, _r0tm3); - _mm_storeu_si128((__m128i*)r0_tm_4, _r0tm4); - _mm_storeu_si128((__m128i*)r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - 
r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = h_tm / 6 * w_tm / 6; - - // permute - // bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); - Mat bottom_blob_tm2; -#if __AVX2__ - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#else - if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < 36; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - short* tmpptr = tm2.row(i / 4); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - _mm256_storeu_si256((__m256i*)tmpptr, _r0); - _mm256_storeu_si256((__m256i*)(tmpptr + 16), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 32; - } - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#else - short* tmpptr = tm2.row(i / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - _mm_storeu_si128((__m128i*)(tmpptr + 8), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 16; - } - } - for (; i < tiles; i++) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, 36, outch, 4u * 4, 4, opt.workspace_allocator); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); - - __m256i _sum04_15 = _mm256_setzero_si256(); - __m256i _sum14_05 = _mm256_setzero_si256(); - __m256i _sum06_17 = _mm256_setzero_si256(); - 
__m256i _sum16_07 = _mm256_setzero_si256(); - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif - - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m256i _val32 = _mm256_permute4x64_epi64(_val23, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum04_15 = _mm256_dpwssd_epi32(_sum04_15, _val23, _w01); - _sum14_05 = _mm256_dpwssd_epi32(_sum14_05, _val32, _w01); - _sum06_17 = _mm256_dpwssd_epi32(_sum06_17, _val23, _w23); - _sum16_07 = _mm256_dpwssd_epi32(_sum16_07, _val32, _w23); -#else - _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_madd_epi16(_val23, _w01)); - _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_madd_epi16(_val32, _w01)); - _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_madd_epi16(_val23, _w23)); - _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_madd_epi16(_val32, _w23)); -#endif - - r0 += 32; - k0 += 32; - } - - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); - _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); - _tmp2 = _mm256_unpackhi_epi32(_sum04_15, _sum14_05); - _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); - _sum06_17 = _mm256_add_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); - - _mm256_storeu_si256((__m256i*)output0_tm, _sum00_11); - _mm256_storeu_si256((__m256i*)(output0_tm + 8), _sum04_15); - output0_tm += 16; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = 
inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); -#else - __m128i _sum00 = _mm_setzero_si128(); - __m128i _sum01 = _mm_setzero_si128(); - __m128i _sum02 = _mm_setzero_si128(); - __m128i _sum03 = _mm_setzero_si128(); - __m128i _sum10 = _mm_setzero_si128(); - __m128i _sum11 = _mm_setzero_si128(); - __m128i _sum12 = _mm_setzero_si128(); - __m128i _sum13 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { -#if __AVX2__ - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif -#else - // 0 1 2 3 4 5 6 7 - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum00 = _mm_maddd_epi16(_val0, _w0, _sum00); - _sum01 = _mm_maddd_epi16(_val0, _w1, _sum01); - _sum02 = _mm_maddd_epi16(_val0, _w2, _sum02); - _sum03 = _mm_maddd_epi16(_val0, _w3, _sum03); - _sum10 = _mm_maddd_epi16(_val1, _w0, _sum10); - _sum11 = _mm_maddd_epi16(_val1, _w1, _sum11); - _sum12 = _mm_maddd_epi16(_val1, _w2, _sum12); - _sum13 = _mm_maddd_epi16(_val1, _w3, _sum13); -#else - _sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00); - _sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01); - _sum02 = _mm_add_epi32(_mm_madd_epi16(_val0, _w2), _sum02); - _sum03 = _mm_add_epi32(_mm_madd_epi16(_val0, _w3), _sum03); - _sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10); - _sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11); - _sum12 = _mm_add_epi32(_mm_madd_epi16(_val1, _w2), _sum12); - _sum13 = _mm_add_epi32(_mm_madd_epi16(_val1, _w3), _sum13); -#endif -#endif - - r0 += 16; - k0 += 32; - } - -#if __AVX2__ - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - 
_sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - - _mm256_storeu_si256((__m256i*)output0_tm, _sum00_11); -#else - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); - _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); - _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); - _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); - _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); - _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); - _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); - _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); - _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00 = _mm_add_epi32(_sum00, _sum01); - _sum02 = _mm_add_epi32(_sum02, _sum03); - _sum10 = _mm_add_epi32(_sum10, _sum11); - _sum12 = _mm_add_epi32(_sum12, _sum13); - - _sum00 = _mm_add_epi32(_sum00, _sum02); - _sum10 = _mm_add_epi32(_sum10, _sum12); - - _mm_storeu_si128((__m128i*)output0_tm, _sum00); - _mm_storeu_si128((__m128i*)(output0_tm + 4), _sum10); -#endif - output0_tm += 8; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum0_1 = _mm256_setzero_si256(); - __m256i _sum2_3 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); - __m128i _sum2 = _mm_setzero_si128(); - __m128i _sum3 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 - __m128i _val = _mm_loadu_si128((const __m128i*)r0); -#if __AVX2__ - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _valval = _mm256_inserti128_si256(_mm256_castsi128_si256(_val), _val, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0_1 = _mm256_dpwssd_epi32(_sum0_1, _valval, _w01); - _sum2_3 = _mm256_dpwssd_epi32(_sum2_3, _valval, _w23); -#else - _sum0_1 = _mm256_add_epi32(_sum0_1, _mm256_madd_epi16(_valval, _w01)); - _sum2_3 = _mm256_add_epi32(_sum2_3, _mm256_madd_epi16(_valval, _w23)); -#endif -#else - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val, _w1, _sum1); - _sum2 = _mm_maddd_epi16(_val, _w2, _sum2); - _sum3 = _mm_maddd_epi16(_val, _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val, _w1), _sum1); - _sum2 = _mm_add_epi32(_mm_madd_epi16(_val, _w2), _sum2); - _sum3 = _mm_add_epi32(_mm_madd_epi16(_val, _w3), _sum3); -#endif -#endif - - r0 += 8; - k0 += 32; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum0_1, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum0_1, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum2_3, 0); - __m128i _sum3 = _mm256_extracti128_si256(_sum2_3, 1); -#endif - - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - 
_tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); - _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); - _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); - _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); - _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum0 = _mm_add_epi32(_sum0, _sum1); - _sum2 = _mm_add_epi32(_sum2, _sum3); - - _sum0 = _mm_add_epi32(_sum0, _sum2); - - _mm_storeu_si128((__m128i*)output0_tm, _sum0); - output0_tm += 4; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u * 4, 4, opt.workspace_allocator); - } - { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - int tmp[4][6][4]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tm / 6 + j) * 4; - const int* output0_tm_1 = output0_tm_0 + tiles * 4; - const int* output0_tm_2 = output0_tm_0 + tiles * 8; - const int* output0_tm_3 = output0_tm_0 + tiles * 12; - const int* output0_tm_4 = output0_tm_0 + tiles * 16; - const int* output0_tm_5 = output0_tm_0 + tiles * 20; - - int* output0 = out0.row(i * 4) + (j * 4) * 4; - - // TODO sse optimize - for (int m = 0; m < 5; m++) - { - __m128i _out0tm0 = _mm_loadu_si128((const __m128i*)output0_tm_0); - __m128i _out0tm1 = _mm_loadu_si128((const __m128i*)output0_tm_1); - __m128i _out0tm2 = _mm_loadu_si128((const __m128i*)output0_tm_2); - __m128i _out0tm3 = _mm_loadu_si128((const __m128i*)output0_tm_3); - __m128i _out0tm4 = _mm_loadu_si128((const __m128i*)output0_tm_4); - __m128i _out0tm5 = _mm_loadu_si128((const __m128i*)output0_tm_5); - - __m128i _tmp02a = _mm_add_epi32(_out0tm1, _out0tm2); - __m128i _tmp13a = _mm_sub_epi32(_out0tm1, _out0tm2); - - __m128i _tmp02b = _mm_add_epi32(_out0tm3, _out0tm4); - __m128i _tmp13b = _mm_sub_epi32(_out0tm3, _out0tm4); - - __m128i _tmp0m = _mm_add_epi32(_mm_add_epi32(_out0tm0, _tmp02a), _tmp02b); - __m128i _tmp1m = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _tmp2m = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _tmp3m = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_out0tm5, 2)), _mm_slli_epi32(_tmp13b, 3)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - for (int m = 5; m < 6; m++) - { - __m128i _out0tm0 = 
_mm_loadu_si128((const __m128i*)output0_tm_0); - __m128i _out0tm1 = _mm_loadu_si128((const __m128i*)output0_tm_1); - __m128i _out0tm2 = _mm_loadu_si128((const __m128i*)output0_tm_2); - __m128i _out0tm3 = _mm_loadu_si128((const __m128i*)output0_tm_3); - __m128i _out0tm4 = _mm_loadu_si128((const __m128i*)output0_tm_4); - __m128i _out0tm5 = _mm_loadu_si128((const __m128i*)output0_tm_5); - - __m128i _tmp02a = _mm_add_epi32(_out0tm1, _out0tm2); - __m128i _tmp13a = _mm_sub_epi32(_out0tm1, _out0tm2); - - __m128i _tmp02b = _mm_add_epi32(_out0tm3, _out0tm4); - __m128i _tmp13b = _mm_sub_epi32(_out0tm3, _out0tm4); - - __m128i _tmp0m = _mm_add_epi32(_mm_add_epi32(_out0tm0, _tmp02a), _tmp02b); - __m128i _tmp1m = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _tmp2m = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _tmp3m = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_out0tm5, 2)), _mm_slli_epi32(_tmp13b, 3)); - - _tmp0m = _mm_slli_epi32(_tmp0m, 2); - _tmp1m = _mm_slli_epi32(_tmp1m, 2); - _tmp2m = _mm_slli_epi32(_tmp2m, 2); - _tmp3m = _mm_slli_epi32(_tmp3m, 2); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - - for (int m = 0; m < 4; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _tmp02a = _mm_add_epi32(_tmp01, _tmp02); - __m128i _tmp13a = _mm_sub_epi32(_tmp01, _tmp02); - - __m128i _tmp02b = _mm_add_epi32(_tmp03, _tmp04); - __m128i _tmp13b = _mm_sub_epi32(_tmp03, _tmp04); - - __m128i _out00 = _mm_add_epi32(_mm_add_epi32(_tmp00, _tmp02a), _tmp02b); - __m128i _out01 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _out02 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _out03 = _mm_add_epi32(_mm_add_epi32(_tmp05, _tmp13a), _mm_slli_epi32(_tmp13b, 3)); - - // TODO use integer trick for division by 576 - __m128 _v576 = _mm_set1_ps(1.0 / 576); - _out00 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out00), _v576)); - _out01 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out01), _v576)); - _out02 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out02), _v576)); - _out03 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out03), _v576)); - - _mm_storeu_si128((__m128i*)output0, _out00); - _mm_storeu_si128((__m128i*)(output0 + 4), _out01); - _mm_storeu_si128((__m128i*)(output0 + 8), _out02); - _mm_storeu_si128((__m128i*)(output0 + 12), _out03); - - output0 += outw * 4; - } - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/x86/convolution_3x3_winograd_int8.h b/src/layer/x86/convolution_3x3_winograd_int8.h new file mode 100644 index 000000000000..8c7b891b0dda --- /dev/null +++ b/src/layer/x86/convolution_3x3_winograd_int8.h @@ -0,0 +1,6407 @@ +// Tencent is pleased to support the open source community by making 
ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ +void conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +void conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +void conv3x3s1_winograd23_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); +void conv3x3s1_winograd23_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); +void conv3x3s1_winograd43_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +void conv3x3s1_winograd23_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif +#endif + +static void pack_A_tile_int8(const Mat& A, Mat& AT, int batch, int max_ii, int max_kk) +{ + const int N = max_kk * batch; + + for (int b = 0; b < batch; b++) + { + short* pp = AT.row(b); + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + pp[8] = p0[N * 4]; + pp[9] = p0[N * 4 + batch]; + pp[10] = p0[N * 5]; + pp[11] = p0[N * 5 + batch]; + pp[12] = p0[N * 6]; + pp[13] = p0[N * 6 + batch]; + pp[14] = p0[N * 7]; + pp[15] = p0[N * 7 + batch]; + pp[16] = p0[N * 8]; + pp[17] = p0[N * 8 + batch]; + pp[18] = p0[N * 9]; + pp[19] = p0[N * 9 + batch]; + pp[20] = p0[N * 10]; + pp[21] = p0[N * 10 + batch]; + pp[22] = p0[N * 11]; + pp[23] = p0[N * 11 + batch]; + pp[24] = p0[N * 12]; + pp[25] = p0[N * 12 + batch]; + pp[26] = p0[N * 13]; + pp[27] = p0[N * 13 + batch]; + pp[28] = p0[N * 14]; + pp[29] = p0[N * 14 + batch]; + pp[30] = p0[N * 15]; + pp[31] = p0[N * 15 + batch]; + p0 += batch * 2; + pp += 32; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] 
= p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + pp[8] = p0[N * 8]; + pp[9] = p0[N * 9]; + pp[10] = p0[N * 10]; + pp[11] = p0[N * 11]; + pp[12] = p0[N * 12]; + pp[13] = p0[N * 13]; + pp[14] = p0[N * 14]; + pp[15] = p0[N * 15]; + p0 += batch; + pp += 16; + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + pp[8] = p0[N * 4]; + pp[9] = p0[N * 4 + batch]; + pp[10] = p0[N * 5]; + pp[11] = p0[N * 5 + batch]; + pp[12] = p0[N * 6]; + pp[13] = p0[N * 6 + batch]; + pp[14] = p0[N * 7]; + pp[15] = p0[N * 7 + batch]; + p0 += batch * 2; + pp += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + p0 += batch; + pp += 8; + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + p0 += batch * 2; + pp += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + p0 += batch; + pp += 4; + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + p0 += batch * 2; + pp += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + p0 += batch; + pp += 2; + } + } + for (; ii < max_ii; ii++) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += batch; + pp += 1; + } + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int batch, int max_jj, int max_kk, int nT) +{ + #pragma omp parallel for num_threads(nT) + for (int b = 0; b < batch; b++) + { + short* pp = BT.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _r4 = _mm512_loadu_si512((const __m512i*)(p0 + 128)); + __m512i _r5 = _mm512_loadu_si512((const __m512i*)(p0 + 160)); + __m512i _r6 = _mm512_loadu_si512((const __m512i*)(p0 + 192)); + __m512i _r7 = _mm512_loadu_si512((const __m512i*)(p0 + 224)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_r4, _r6, 
_MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_r4, _r6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_r5, _r7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_r5, _r7, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _r4 = _mm512_unpacklo_epi32(_tmp4, _tmp5); + _r5 = _mm512_unpackhi_epi32(_tmp4, _tmp5); + _r6 = _mm512_unpacklo_epi32(_tmp6, _tmp7); + _r7 = _mm512_unpackhi_epi32(_tmp6, _tmp7); + _tmp0 = _mm512_unpacklo_epi64(_r0, _r2); + _tmp1 = _mm512_unpackhi_epi64(_r0, _r2); + _tmp2 = _mm512_unpacklo_epi64(_r1, _r3); + _tmp3 = _mm512_unpackhi_epi64(_r1, _r3); + _tmp4 = _mm512_unpacklo_epi64(_r4, _r6); + _tmp5 = _mm512_unpackhi_epi64(_r4, _r6); + _tmp6 = _mm512_unpacklo_epi64(_r5, _r7); + _tmp7 = _mm512_unpackhi_epi64(_r5, _r7); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_i32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_i32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _r4 = _mm512_shuffle_i32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 1, 3, 1)); + _r5 = _mm512_shuffle_i32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _r6 = _mm512_shuffle_i32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _r7 = _mm512_shuffle_i32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + _mm512_storeu_si512((__m512i*)(pp + 128), _r4); + _mm512_storeu_si512((__m512i*)(pp + 160), _r5); + _mm512_storeu_si512((__m512i*)(pp + 192), _r6); + _mm512_storeu_si512((__m512i*)(pp + 224), _r7); + p0 += max_jj * batch * 16; + pp += 256; + } + p0 -= (b * max_jj + jj) * 16; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r2, _r3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r2, _r3, _MM_SHUFFLE(3, 1, 3, 1)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _tmp0 = _mm512_permutex_epi64(_r0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm512_permutex_epi64(_r1, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp2 = _mm512_permutex_epi64(_r2, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp3 = _mm512_permutex_epi64(_r3, _MM_SHUFFLE(3, 1, 2, 0)); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 3, 1)); + _r2 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + p0 += max_jj * batch * 8; + pp += 128; + } + p0 -= (b * max_jj + 
jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + _mm512_storeu_si512((__m512i*)pp, _r0); + p0 += max_jj * batch * 2; + pp += 32; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); + _mm256_store_si256((__m256i*)pp, _r0); + p0 += max_jj * batch; + pp += 16; + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* p0 = B; + + int kk = 0; +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _tmp0 = _mm512_unpacklo_epi64(_r0, _r2); + _tmp1 = _mm512_unpackhi_epi64(_r0, _r2); + _tmp2 = _mm512_unpacklo_epi64(_r1, _r3); + _tmp3 = _mm512_unpackhi_epi64(_r1, _r3); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _r3 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + p0 += max_jj * batch * 16; + pp += 128; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if __AVX__ + __m256 _r0 = _mm256_loadu_ps((const float*)p0); + __m256 _r1 = _mm256_loadu_ps((const float*)(p0 + 16)); + __m256 _r2 = _mm256_loadu_ps((const float*)(p0 + 32)); + __m256 _r3 = _mm256_loadu_ps((const float*)(p0 + 48)); + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp2 = _mm256_permute2f128_ps(_r1, _r3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_r1, _r3, _MM_SHUFFLE(0, 3, 0, 1)); + _r0 = _mm256_unpacklo_ps(_tmp0, _tmp1); + _r1 = _mm256_unpackhi_ps(_tmp0, _tmp1); + _r2 = _mm256_unpacklo_ps(_tmp2, _tmp3); + _r3 = _mm256_unpackhi_ps(_tmp2, _tmp3); + _tmp0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_r0), _mm256_castps_pd(_r2))); + _tmp1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_r0), _mm256_castps_pd(_r2))); + _tmp2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_r1), _mm256_castps_pd(_r3))); + _tmp3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_r1), _mm256_castps_pd(_r3))); + _mm256_storeu_ps((float*)pp, _tmp0); + _mm256_storeu_ps((float*)(pp + 16), _tmp1); + _mm256_storeu_ps((float*)(pp + 32), _tmp2); + _mm256_storeu_ps((float*)(pp + 48), _tmp3); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + 
__m128i _r2 = _mm_load_si128((const __m128i*)(p0 + 8 * 2)); + __m128i _r3 = _mm_load_si128((const __m128i*)(p0 + 8 * 3)); + __m128i _r4 = _mm_load_si128((const __m128i*)(p0 + 8 * 4)); + __m128i _r5 = _mm_load_si128((const __m128i*)(p0 + 8 * 5)); + __m128i _r6 = _mm_load_si128((const __m128i*)(p0 + 8 * 6)); + __m128i _r7 = _mm_load_si128((const __m128i*)(p0 + 8 * 7)); + transpose4x8_epi32(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm_store_si128((__m128i*)pp, _r0); + _mm_store_si128((__m128i*)(pp + 8), _r1); + _mm_store_si128((__m128i*)(pp + 8 * 2), _r2); + _mm_store_si128((__m128i*)(pp + 8 * 3), _r3); + _mm_store_si128((__m128i*)(pp + 8 * 4), _r4); + _mm_store_si128((__m128i*)(pp + 8 * 5), _r5); + _mm_store_si128((__m128i*)(pp + 8 * 6), _r6); + _mm_store_si128((__m128i*)(pp + 8 * 7), _r7); +#endif // __AVX__ + p0 += max_jj * batch * 8; + pp += 64; + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX__ + __m256 _r0 = _mm256_loadu_ps((const float*)p0); + _mm256_storeu_ps((float*)pp, _r0); +#else + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + _mm_store_si128((__m128i*)pp, _r0); + _mm_store_si128((__m128i*)(pp + 8), _r1); +#endif // __AVX__ + p0 += max_jj * batch * 2; + pp += 16; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + _mm_store_si128((__m128i*)pp, _r0); + p0 += max_jj * batch; + pp += 8; + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* p0 = B; + + int kk = 0; +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r0 = _mm512_permutex_epi64(_r0, _MM_SHUFFLE(3, 1, 2, 0)); + _r1 = _mm512_permutex_epi64(_r1, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(1, 0, 1, 0)); + _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi64(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi64(_tmp0, _tmp1); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + p0 += max_jj * batch * 16; + pp += 64; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_load_si128((const __m128i*)(p0 + 8 * 2)); + __m128i _r3 = _mm_load_si128((const __m128i*)(p0 + 8 * 3)); + transpose4x4_epi32(_r0, _r1, _r2, _r3); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); + _mm_storeu_si128((__m128i*)(pp + 8 * 2), _r2); + _mm_storeu_si128((__m128i*)(pp + 8 * 3), _r3); + p0 += max_jj * batch * 8; + pp += 32; + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _r0); + p0 += max_jj * batch * 2; + pp += 8; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < 
max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + p0 += max_jj * batch; + pp += 4; + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* p0 = B; + + int kk = 0; +#if __SSE2__ +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)p0); + __m256i _r1 = _mm256_load_si256((const __m256i*)(p0 + 16)); + transpose8x2_epi32(_r0, _r1); + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + p0 += max_jj * batch * 16; + pp += 32; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _tmp0); + _mm_storeu_si128((__m128i*)(pp + 8), _tmp1); + p0 += max_jj * batch * 8; + pp += 16; + } + p0 -= (b * max_jj + jj) * 8; +#endif // __SSE2__ + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + p0 += max_jj * batch * 2; + pp += 4; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch; + pp += 2; + } + } + for (; jj < max_jj; jj++) + { + const short* p0 = B; + + int kk = 0; +#if __SSE2__ +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)p0); + _mm256_storeu_si256((__m256i*)pp, _r0); + p0 += max_jj * batch * 16; + pp += 16; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _r0); + p0 += max_jj * batch * 8; + pp += 8; + } + p0 -= (b * max_jj + jj) * 8; +#endif // __SSE2__ + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch * 2; + pp += 2; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += max_jj * batch; + pp += 1; + } + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& top_blob, int batch, int max_ii, int max_jj, int k, int max_kk, bool k_end) +{ + int* outptr = top_blob; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + __m512i _sum8; + __m512i _sum9; + __m512i _suma; + __m512i _sumb; + __m512i _sumc; + __m512i _sumd; + __m512i _sume; + __m512i _sumf; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + _sum8 = 
_mm512_setzero_si512(); + _sum9 = _mm512_setzero_si512(); + _suma = _mm512_setzero_si512(); + _sumb = _mm512_setzero_si512(); + _sumc = _mm512_setzero_si512(); + _sumd = _mm512_setzero_si512(); + _sume = _mm512_setzero_si512(); + _sumf = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + _sum8 = _mm512_load_si512((const __m512i*)(outptr + 128)); + _sum9 = _mm512_load_si512((const __m512i*)(outptr + 128 + 16)); + _suma = _mm512_load_si512((const __m512i*)(outptr + 128 + 32)); + _sumb = _mm512_load_si512((const __m512i*)(outptr + 128 + 48)); + _sumc = _mm512_load_si512((const __m512i*)(outptr + 128 + 64)); + _sumd = _mm512_load_si512((const __m512i*)(outptr + 128 + 80)); + _sume = _mm512_load_si512((const __m512i*)(outptr + 128 + 96)); + _sumf = _mm512_load_si512((const __m512i*)(outptr + 128 + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 1, 0, 3)); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pA3 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); + _sum8 = _mm512_dpwssd_epi32(_sum8, _pA2, _pB0); + _sum9 = _mm512_dpwssd_epi32(_sum9, _pA2, _pB1); + _suma = _mm512_dpwssd_epi32(_suma, _pA2, _pB2); + _sumb = _mm512_dpwssd_epi32(_sumb, _pA2, _pB3); + _sumc = _mm512_dpwssd_epi32(_sumc, _pA3, _pB0); + _sumd = _mm512_dpwssd_epi32(_sumd, _pA3, _pB1); + _sume = _mm512_dpwssd_epi32(_sume, _pA3, _pB2); + _sumf = _mm512_dpwssd_epi32(_sumf, _pA3, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); + _sum8 = _mm512_add_epi32(_sum8, _mm512_madd_epi16(_pA2, _pB0)); + _sum9 = _mm512_add_epi32(_sum9, _mm512_madd_epi16(_pA2, _pB1)); + _suma = _mm512_add_epi32(_suma, _mm512_madd_epi16(_pA2, _pB2)); + _sumb = _mm512_add_epi32(_sumb, _mm512_madd_epi16(_pA2, _pB3)); + _sumc = _mm512_add_epi32(_sumc, _mm512_madd_epi16(_pA3, _pB0)); + _sumd = 
_mm512_add_epi32(_sumd, _mm512_madd_epi16(_pA3, _pB1)); + _sume = _mm512_add_epi32(_sume, _mm512_madd_epi16(_pA3, _pB2)); + _sumf = _mm512_add_epi32(_sumf, _mm512_madd_epi16(_pA3, _pB3)); +#endif + + pA += 32; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pB)); + + __m512i _pA1 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 1, 0, 3)); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pA3 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA0, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA0, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA1, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA1, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA1, _pB3); + __m512i _s8 = _mm512_mullo_epi32(_pA2, _pB0); + __m512i _s9 = _mm512_mullo_epi32(_pA2, _pB1); + __m512i _sa = _mm512_mullo_epi32(_pA2, _pB2); + __m512i _sb = _mm512_mullo_epi32(_pA2, _pB3); + __m512i _sc = _mm512_mullo_epi32(_pA3, _pB0); + __m512i _sd = _mm512_mullo_epi32(_pA3, _pB1); + __m512i _se = _mm512_mullo_epi32(_pA3, _pB2); + __m512i _sf = _mm512_mullo_epi32(_pA3, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + _sum8 = _mm512_add_epi32(_sum8, _s8); + _sum9 = _mm512_add_epi32(_sum9, _s9); + _suma = _mm512_add_epi32(_suma, _sa); + _sumb = _mm512_add_epi32(_sumb, _sb); + _sumc = _mm512_add_epi32(_sumc, _sc); + _sumd = _mm512_add_epi32(_sumd, _sd); + _sume = _mm512_add_epi32(_sume, _se); + _sumf = _mm512_add_epi32(_sumf, _sf); + + pA += 16; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff + // 01 12 23 30 45 56 67 74 89 9a ab b8 cd de ef fc + // 02 13 20 31 46 57 64 75 8a 9b a8 b9 ce df ec fd + // 03 10 21 32 47 54 65 76 8b 98 a9 ba cf dc ed fe + // c0 d1 e2 f3 04 15 26 37 48 59 6a 7b 8c 9d ae bf + // c1 d2 e3 f0 05 16 27 34 49 5a 6b 78 8d 9e af bc + // c2 d3 e0 f1 06 17 24 35 4a 5b 68 79 8e 9f ac bd + // c3 d0 e1 f2 07 14 25 36 4b 58 69 7a 8f 9c ad be + // 80 91 a2 b3 c4 d5 e6 f7 08 19 2a 3b 4c 5d 6e 7f + // 81 92 a3 b0 c5 d6 e7 f4 09 1a 2b 38 4d 5e 6f 7c + // 82 93 a0 b1 c6 d7 e4 f5 0a 1b 28 39 4e 5f 6c 7d + // 83 90 a1 b2 c7 d4 e5 f6 0b 18 29 3a 4f 5c 6d 7e + // 40 51 62 73 84 95 a6 b7 c8 d9 ea fb 0c 1d 2e 3f + // 41 52 63 70 85 96 a7 b4 c9 da eb f8 0d 1e 2f 3c + // 42 53 60 71 86 97 a4 b5 ca db e8 f9 0e 1f 2c 3d + // 43 50 61 72 87 94 a5 b6 cb d8 e9 fa 0f 1c 2d 3e + // to + // 00 10 20 30 44 54 64 74 88 98 a8 b8 cc dc ec fc + // 01 11 21 31 45 55 65 75 89 99 a9 b9 cd dd ed fd + // 02 12 22 32 46 56 66 76 8a 9a aa ba ce de ee fe + // 03 13 23 33 47 57 67 77 8b 9b ab bb cf df ef ff + // c0 d0 e0 f0 04 14 24 34 48 58 68 78 8c 9c ac bc + // c1 d1 e1 f1 05 15 25 35 49 59 69 79 8d 9d ad bd + // c2 d2 e2 f2 06 16 26 36 4a 5a 6a 7a 8e 9e ae be + // c3 d3 e3 f3 07 
17 27 37 4b 5b 6b 7b 8f 9f af bf + // 80 90 a0 b0 c4 d4 e4 f4 08 18 28 38 4c 5c 6c 7c + // 81 91 a1 b1 c5 d5 e5 f5 09 19 29 39 4d 5d 6d 7d + // 82 92 a2 b2 c6 d6 e6 f6 0a 1a 2a 3a 4e 5e 6e 7e + // 83 93 a3 b3 c7 d7 e7 f7 0b 1b 2b 3b 4f 5f 6f 7f + // 40 50 60 70 84 94 a4 b4 c8 d8 e8 f8 0c 1c 2c 3c + // 41 51 61 71 85 95 a5 b5 c9 d9 e9 f9 0d 1d 2d 3d + // 42 52 62 72 86 96 a6 b6 ca da ea fa 0e 1e 2e 3e + // 43 53 63 73 87 97 a7 b7 cb db eb fb 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + _sum9 = _mm512_shuffle_epi32(_sum9, _MM_PERM_CBAD); + _suma = _mm512_shuffle_epi32(_suma, _MM_PERM_BADC); + _sumb = _mm512_shuffle_epi32(_sumb, _MM_PERM_ADCB); + _sumd = _mm512_shuffle_epi32(_sumd, _MM_PERM_CBAD); + _sume = _mm512_shuffle_epi32(_sume, _MM_PERM_BADC); + _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + __m512i _tmp8 = _mm512_unpacklo_epi32(_sum8, _sumb); + __m512i _tmp9 = _mm512_unpackhi_epi32(_sum8, _sumb); + __m512i _tmpa = _mm512_unpacklo_epi32(_suma, _sum9); + __m512i _tmpb = _mm512_unpackhi_epi32(_suma, _sum9); + __m512i _tmpc = _mm512_unpacklo_epi32(_sumc, _sumf); + __m512i _tmpd = _mm512_unpackhi_epi32(_sumc, _sumf); + __m512i _tmpe = _mm512_unpacklo_epi32(_sume, _sumd); + __m512i _tmpf = _mm512_unpackhi_epi32(_sume, _sumd); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum8 = _mm512_unpacklo_epi64(_tmp8, _tmpa); + _sum9 = _mm512_unpackhi_epi64(_tmp8, _tmpa); + _suma = _mm512_unpacklo_epi64(_tmpb, _tmp9); + _sumb = _mm512_unpackhi_epi64(_tmpb, _tmp9); + _sumc = _mm512_unpacklo_epi64(_tmpc, _tmpe); + _sumd = _mm512_unpackhi_epi64(_tmpc, _tmpe); + _sume = _mm512_unpacklo_epi64(_tmpf, _tmpd); + _sumf = _mm512_unpackhi_epi64(_tmpf, _tmpd); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + _sum9 = _mm512_shuffle_epi32(_sum9, _MM_PERM_CBAD); + _sumb = _mm512_shuffle_epi32(_sumb, _MM_PERM_CBAD); + _sumd = _mm512_shuffle_epi32(_sumd, _MM_PERM_CBAD); + _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sumc, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum8, _sum4, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sumc, _sum8, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum1, _sumd, 
_MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum5, _sum1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum9, _sum5, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sumd, _sum9, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmp8 = _mm512_shuffle_i32x4(_sum2, _sume, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp9 = _mm512_shuffle_i32x4(_sum6, _sum2, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmpa = _mm512_shuffle_i32x4(_suma, _sum6, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmpb = _mm512_shuffle_i32x4(_sume, _suma, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmpc = _mm512_shuffle_i32x4(_sum3, _sumf, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmpd = _mm512_shuffle_i32x4(_sum7, _sum3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmpe = _mm512_shuffle_i32x4(_sumb, _sum7, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmpf = _mm512_shuffle_i32x4(_sumf, _sumb, _MM_SHUFFLE(1, 3, 1, 3)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp8, _tmpa, _MM_SHUFFLE(3, 1, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmpc, _tmpe, _MM_SHUFFLE(3, 1, 2, 0)); + _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp9, _tmpb, _MM_SHUFFLE(3, 1, 2, 0)); + _sum7 = _mm512_shuffle_i32x4(_tmpd, _tmpf, _MM_SHUFFLE(3, 1, 2, 0)); + _sum8 = _mm512_shuffle_i32x4(_tmp2, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum9 = _mm512_shuffle_i32x4(_tmp6, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _suma = _mm512_shuffle_i32x4(_tmpa, _tmp8, _MM_SHUFFLE(3, 1, 2, 0)); + _sumb = _mm512_shuffle_i32x4(_tmpe, _tmpc, _MM_SHUFFLE(3, 1, 2, 0)); + _sumc = _mm512_shuffle_i32x4(_tmp3, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _sumd = _mm512_shuffle_i32x4(_tmp7, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _sume = _mm512_shuffle_i32x4(_tmpb, _tmp9, _MM_SHUFFLE(3, 1, 2, 0)); + _sumf = _mm512_shuffle_i32x4(_tmpf, _tmpd, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + _mm512_store_si512((__m512i*)(outptr + 128), _sum8); + _mm512_store_si512((__m512i*)(outptr + 128 + 16), _sum9); + _mm512_store_si512((__m512i*)(outptr + 128 + 32), _suma); + _mm512_store_si512((__m512i*)(outptr + 128 + 48), _sumb); + _mm512_store_si512((__m512i*)(outptr + 128 + 64), _sumc); + _mm512_store_si512((__m512i*)(outptr + 128 + 80), _sumd); + _mm512_store_si512((__m512i*)(outptr + 128 + 96), _sume); + _mm512_store_si512((__m512i*)(outptr + 128 + 112), _sumf); + outptr += 256; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = 
_mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 16 * 2)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 16 * 3)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 16 * 4)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 16 * 5)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 16 * 6)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 16 * 7)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); +#endif + + pA += 32; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m256i _pB = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pB)); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA0, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA0, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA1, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA1, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA1, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 16; + pB += 8; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 80 91 a2 b3 c4 d5 e6 f7 + // 01 12 23 30 45 56 67 74 81 92 a3 b0 c5 d6 e7 f4 + // 02 13 20 31 46 57 64 75 82 93 a0 b1 c6 d7 e4 f5 + // 03 10 21 32 47 54 65 76 83 90 a1 b2 c7 d4 e5 f6 + // 40 51 62 73 04 15 26 37 c0 d1 e2 f3 84 95 a6 b7 + // 41 52 63 70 05 16 27 34 c1 d2 e3 f0 85 96 a7 b4 + // 42 53 60 71 06 17 
24 35 c2 d3 e0 f1 86 97 a4 b5 + // 43 50 61 72 07 14 25 36 c3 d0 e1 f2 87 94 a5 b6 + // to + // 00 10 20 30 44 54 64 74 80 90 a0 b0 c4 d4 e4 f4 + // 01 11 21 31 45 55 65 75 81 91 a1 b1 c5 d5 e5 f5 + // 02 12 22 32 46 56 66 76 82 92 a2 b2 c6 d6 e6 f6 + // 03 13 23 33 47 57 67 77 83 93 a3 b3 c7 d7 e7 f7 + // 40 50 60 70 04 14 24 34 c0 d0 e0 f0 84 94 a4 b4 + // 41 51 61 71 05 15 25 35 c1 d1 e1 f1 85 95 a5 b5 + // 42 52 62 72 06 16 26 36 c2 d2 e2 f2 86 96 a6 b6 + // 43 53 63 73 07 17 27 37 c3 d3 e3 f3 87 97 a7 b7 + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + } + + // TODO + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum5, _sum1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum6, _sum2, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sum7, _sum3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _sum4 = _mm512_shuffle_i32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _sum7 = _mm512_shuffle_i32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 16 * 2), _sum2); + _mm512_store_si512((__m512i*)(outptr + 16 * 3), _sum3); + _mm512_store_si512((__m512i*)(outptr + 16 * 4), _sum4); + _mm512_store_si512((__m512i*)(outptr + 16 * 5), _sum5); + _mm512_store_si512((__m512i*)(outptr + 16 * 6), _sum6); + _mm512_store_si512((__m512i*)(outptr + 16 * 7), _sum7); + outptr 
+= 16 * 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 16 * 2)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 16 * 3)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_castsi128_si512(_mm_loadu_si128((const __m128i*)pB)); + __m512i _pB0 = _mm512_shuffle_i32x4(_pB, _pB, _MM_SHUFFLE(0, 0, 0, 0)); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); +#endif + + pA += 32; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB))); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s3 = _mm512_mullo_epi32(_pA1, _pB1); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 16; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 40 51 62 73 80 91 a2 b3 c0 d1 e2 f3 + // 01 12 23 30 41 52 63 70 81 92 a3 b0 c1 d2 e3 f0 + // 20 31 02 13 60 71 42 53 a0 b1 82 93 e0 f1 c2 d3 + // 21 32 03 10 61 72 43 50 a1 b2 83 90 e1 f2 c3 d0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 16 * 2), _sum2); + _mm512_store_si512((__m512i*)(outptr + 
16 * 3), _sum3); + outptr += 16 * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA, _pB1)); +#endif + + pA += 32; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_set1_epi32(((const int*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ABAB); + + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA, _pB1); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 16; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 40 51 60 71 80 91 a0 b1 c0 d1 e0 f1 + // 01 10 21 30 41 50 61 70 81 90 a1 b0 c1 d0 e1 f0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + { + __m512i _tmp0 = _mm512_shuffle_epi32(_sum0, _MM_PERM_DBCA); + __m512i _tmp1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_ACDB); + _sum0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + } + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + outptr += 16 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_set1_epi32(((const int*)pB)[0]); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB)); +#endif + + pA += 32; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB = _mm512_set1_epi32(pB[0]); + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB); + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 16; + pB += 1; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + outptr += 16; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = 
_mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_loadu_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_loadu_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_loadu_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_loadu_si512((const __m512i*)(outptr + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA00, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA00, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA11, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA11, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA11, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA11, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA00, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA00, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA11, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA11, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA11, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA11, _pB3)); +#endif // __AVX512VNNI__ + + pA += 16; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA00, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA00, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA00, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA00, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA11, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA11, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA11, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA11, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 8; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 08 19 2a 3b 4c 5d 
6e 7f + // 01 12 23 30 45 56 67 74 09 1a 2b 38 4d 5e 6f 7c + // 02 13 20 31 46 57 64 75 0a 1b 28 39 4e 5f 6c 7d + // 03 10 21 32 47 54 65 76 0b 18 29 3a 4f 5c 6d 7e + // 40 51 62 73 04 15 26 37 48 59 6a 7b 0c 1d 2e 3f + // 41 52 63 70 05 16 27 34 49 5a 6b 78 0d 1e 2f 3c + // 42 53 60 71 06 17 24 35 4a 5b 68 79 0e 1f 2c 3d + // 43 50 61 72 07 14 25 36 4b 58 69 7a 0f 1c 2d 3e + // to + // 00 10 20 30 44 54 64 74 08 18 28 38 4c 5c 6c 7c + // 01 11 21 31 45 55 65 75 09 19 29 39 4d 5d 6d 7d + // 02 12 22 32 46 56 66 76 0a 1a 2a 3a 4e 5e 6e 7e + // 03 13 23 33 47 57 67 77 0b 1b 2b 3b 4f 5f 6f 7f + // 40 50 60 70 04 14 24 34 48 58 68 78 0c 1c 2c 3c + // 41 51 61 71 05 15 25 35 49 59 69 79 0d 1d 2d 3d + // 42 52 62 72 06 16 26 36 4a 5a 6a 7a 0e 1e 2e 3e + // 43 53 63 73 07 17 27 37 4b 5b 6b 7b 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(1, 3, 1, 3)); + _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); + _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + 
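// note: the 8x16 int32 tile is written back as eight 16-lane stores and outptr advances by 128; unaligned storeu is used here, presumably because only 32-byte alignment is guaranteed on this 8-row path +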
_mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + _mm512_storeu_si512((__m512i*)(outptr + 64), _sum4); + _mm512_storeu_si512((__m512i*)(outptr + 80), _sum5); + _mm512_storeu_si512((__m512i*)(outptr + 96), _sum6); + _mm512_storeu_si512((__m512i*)(outptr + 112), _sum7); + outptr += 128; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX512F__ + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; +#else + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + __m256i _sum4; + __m256i _sum5; + __m256i _sum6; + __m256i _sum7; +#endif // __AVX512F__ + + if (k == 0) + { +#if __AVX512F__ + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); +#else + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + _sum4 = _mm256_setzero_si256(); + _sum5 = _mm256_setzero_si256(); + _sum6 = _mm256_setzero_si256(); + _sum7 = _mm256_setzero_si256(); +#endif // __AVX512F__ + } + else + { +#if __AVX512F__ + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); +#else + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + _sum4 = _mm256_load_si256((const __m256i*)(outptr + 32)); + _sum5 = _mm256_load_si256((const __m256i*)(outptr + 40)); + _sum6 = _mm256_load_si256((const __m256i*)(outptr + 48)); + _sum7 = _mm256_load_si256((const __m256i*)(outptr + 56)); +#endif // __AVX512F__ + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); +#if __AVX512F__ + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); + __m512i _pB23 = _mm512_shuffle_epi32(_pB01, _MM_PERM_BADC); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB01); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB23); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA11, _pB01); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA11, _pB23); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB01)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB23)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB01)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB23)); +#endif // __AVX512VNNI__ +#else // __AVX512F__ + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + +#if __AVXVNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm256_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = 
_mm256_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm256_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm256_dpwssd_epi32(_sum7, _pA1, _pB3); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA0, _pB2)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA0, _pB3)); + _sum4 = _mm256_add_epi32(_sum4, _mm256_madd_epi16(_pA1, _pB0)); + _sum5 = _mm256_add_epi32(_sum5, _mm256_madd_epi16(_pA1, _pB1)); + _sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2)); + _sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3)); +#endif // __AVXVNNI__ +#endif // __AVX512F__ + + pA += 16; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); +#if __AVX512F__ + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_shuffle_i32x4(_pA00, _pA00, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); + __m512i _pB23 = _mm512_permutex_epi64(_pB01, _MM_SHUFFLE(2, 3, 0, 1)); + + __m512i _s01 = _mm512_mullo_epi32(_pA00, _pB01); + __m512i _s23 = _mm512_mullo_epi32(_pA00, _pB23); + __m512i _s45 = _mm512_mullo_epi32(_pA11, _pB01); + __m512i _s67 = _mm512_mullo_epi32(_pA11, _pB23); + _sum0 = _mm512_add_epi32(_sum0, _s01); + _sum1 = _mm512_add_epi32(_sum1, _s23); + _sum2 = _mm512_add_epi32(_sum2, _s45); + _sum3 = _mm512_add_epi32(_sum3, _s67); +#else + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA0, _pB2); + __m256i _s3 = _mm256_mullo_epi32(_pA0, _pB3); + __m256i _s4 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s5 = _mm256_mullo_epi32(_pA1, _pB1); + __m256i _s6 = _mm256_mullo_epi32(_pA1, _pB2); + __m256i _s7 = _mm256_mullo_epi32(_pA1, _pB3); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + _sum4 = _mm256_add_epi32(_sum4, _s4); + _sum5 = _mm256_add_epi32(_sum5, _s5); + _sum6 = _mm256_add_epi32(_sum6, _s6); + _sum7 = _mm256_add_epi32(_sum7, _s7); +#endif // __AVX512F__ + + pA += 8; + pB += 8; + } + +#if __AVX512F__ + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 01 12 23 30 45 56 67 74 + // 02 13 20 31 46 57 64 75 03 10 21 32 47 54 65 76 + // 40 51 62 73 04 15 26 37 41 52 63 70 05 16 27 34 + // 42 53 60 71 06 17 24 35 43 50 61 72 07 14 25 36 + // to + // 00 10 20 30 44 54 64 74 04 14 24 34 40 50 60 70 + // 01 11 21 31 45 55 65 75 05 15 25 35 41 51 61 71 + // 02 12 22 32 46 56 66 76 06 16 26 36 42 52 62 72 + // 03 13 23 33 47 57 67 77 07 17 27 37 43 53 63 73 + { + __m512i _s0 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(0, 1, 1, 0)); + __m512i _s1 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(2, 3, 3, 2)); + __m512i _s2 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(0, 1, 1, 0)); + __m512i _s3 = _mm512_shuffle_i32x4(_sum0, 
_sum2, _MM_SHUFFLE(2, 3, 3, 2)); + _s1 = _mm512_shuffle_epi32(_s1, _MM_PERM_ADCB); + _s2 = _mm512_shuffle_epi32(_s2, _MM_PERM_BADC); + _s3 = _mm512_shuffle_epi32(_s3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_s0, _s1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_s0, _s1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_s2, _s3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_s2, _s3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 0, 3, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 0, 3, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 2, 1, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 2, 1, 2)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; +#else + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 + // 01 12 23 30 45 56 67 74 + // 02 13 20 31 46 57 64 75 + // 03 10 21 32 47 54 65 76 + // 40 51 62 73 04 15 26 37 + // 41 52 63 70 05 16 27 34 + // 42 53 60 71 06 17 24 35 + // 43 50 61 72 07 14 25 36 + // to + // 00 10 20 30 44 54 64 74 + // 01 11 21 31 45 55 65 75 + // 02 12 22 32 46 56 66 76 + // 03 13 23 33 47 57 67 77 + // 40 50 60 70 04 14 24 34 + // 41 51 61 71 05 15 25 35 + // 42 52 62 72 06 16 26 36 + // 43 53 63 73 07 17 27 37 + { + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(0, 3, 2, 1)); + _sum5 = _mm256_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum6 = _mm256_shuffle_epi32(_sum6, _MM_SHUFFLE(1, 0, 3, 2)); + _sum7 = _mm256_shuffle_epi32(_sum7, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + __m256i _tmp4 = _mm256_unpacklo_epi32(_sum4, _sum7); + __m256i _tmp5 = _mm256_unpackhi_epi32(_sum4, _sum7); + __m256i _tmp6 = _mm256_unpacklo_epi32(_sum6, _sum5); + __m256i _tmp7 = _mm256_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm256_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm256_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm256_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm256_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum5 = _mm256_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm256_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + } + + __m256i _tmp0 = _mm256_permute2x128_si256(_sum0, _sum4, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp1 = _mm256_permute2x128_si256(_sum1, _sum5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp2 = _mm256_permute2x128_si256(_sum2, _sum6, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp3 = 
_mm256_permute2x128_si256(_sum3, _sum7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp4 = _mm256_permute2x128_si256(_sum4, _sum0, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp5 = _mm256_permute2x128_si256(_sum5, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp6 = _mm256_permute2x128_si256(_sum6, _sum2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp7 = _mm256_permute2x128_si256(_sum7, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + _sum4 = _tmp4; + _sum5 = _tmp5; + _sum6 = _tmp6; + _sum7 = _tmp7; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 8 * 2), _sum2); + _mm256_store_si256((__m256i*)(outptr + 8 * 3), _sum3); + _mm256_store_si256((__m256i*)(outptr + 8 * 4), _sum4); + _mm256_store_si256((__m256i*)(outptr + 8 * 5), _sum5); + _mm256_store_si256((__m256i*)(outptr + 8 * 6), _sum6); + _mm256_store_si256((__m256i*)(outptr + 8 * 7), _sum7); + outptr += 8 * 8; +#endif // __AVX512F__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); +#endif + + pA += 16; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m256i _pA0 = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB0 = _mm256_cvtepi16_epi32(_mm_castpd_si128(_mm_load1_pd((const double*)pB))); + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + + pA += 8; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 40 51 62 73 + // 01 12 23 30 41 52 63 70 + // 20 31 02 13 60 71 42 53 + // 21 32 03 10 61 72 43 50 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + { + 
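// note: the products above were accumulated along rotated diagonals within each 128-bit lane (see the from/to map above); the shuffles and 32/64-bit unpacks below re-sort them into row-major order +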
_sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 16), _sum2); + _mm256_store_si256((__m256i*)(outptr + 24), _sum3); + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA, _pB1)); +#endif + + pA += 16; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB0 = _mm256_cvtepi16_epi32(_mm_castps_si128(_mm_load1_ps((const float*)pB))); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA, _pB1); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + + pA += 8; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 40 51 60 71 + // 01 10 21 30 41 50 61 70 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + { + __m256i _tmp0 = _mm256_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _tmp1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1)); + _sum0 = _mm256_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm256_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + outptr += 16; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m256i _sum0; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA, _pB)); + + pA += 16; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB = _mm256_set1_epi32(pB[0]); + __m256i _s0 = _mm256_mullo_epi32(_pA, _pB); + _sum0 = 
_mm256_add_epi32(_sum0, _s0); + + pA += 8; + pB += 1; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + outptr += 8; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m256i _pAA = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1); + __m512i _pA0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pAA), _pAA, 1); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); +#endif + + pA += 8; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi16_epi32(_pA); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s3 = _mm512_mullo_epi32(_pA1, _pB1); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 4; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 08 19 2a 3b 0c 1d 2e 3f + // 01 12 23 30 05 16 27 34 09 1a 2b 38 0d 1e 2f 3c + // 20 31 02 13 24 35 06 17 28 3a 0a 1b 2c 3d 0e 1f + // 21 32 03 10 25 36 07 14 29 3a 0b 18 2d 3e 0f 1c + // to + // 00 10 20 30 04 14 24 34 08 18 28 38 0c 1c 2c 3c + // 01 11 21 31 05 15 25 35 09 19 29 39 0d 1d 2d 3d + // 02 12 22 32 06 16 26 36 0a 1a 2a 3a 0e 1e 2e 3e + // 03 13 23 33 07 17 27 37 0b 1b 2b 3b 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = 
_mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; +#else + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + __m128i _sum4; + __m128i _sum5; + __m128i _sum6; + __m128i _sum7; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + _sum4 = _mm_setzero_si128(); + _sum5 = _mm_setzero_si128(); + _sum6 = _mm_setzero_si128(); + _sum7 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); + _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_loadu_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_loadu_si256((const __m256i*)(outptr + 24)); +#else + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + _sum4 = _mm_load_si128((const __m128i*)(outptr + 16)); + _sum5 = _mm_load_si128((const __m128i*)(outptr + 20)); + _sum6 = _mm_load_si128((const __m128i*)(outptr + 24)); + _sum7 = _mm_load_si128((const __m128i*)(outptr + 28)); +#endif + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); +#endif +#else // __AVX2__ + __m128i _pA0 = _mm_loadu_si128((const 
__m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maddd_epi16(_pA0, _pB2, _sum2); + _sum3 = _mm_maddd_epi16(_pA0, _pB3, _sum3); + _sum4 = _mm_maddd_epi16(_pA1, _pB0, _sum4); + _sum5 = _mm_maddd_epi16(_pA1, _pB1, _sum5); + _sum6 = _mm_maddd_epi16(_pA1, _pB2, _sum6); + _sum7 = _mm_maddd_epi16(_pA1, _pB3, _sum7); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA0, _pB2)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA0, _pB3)); + _sum4 = _mm_add_epi32(_sum4, _mm_madd_epi16(_pA1, _pB0)); + _sum5 = _mm_add_epi32(_sum5, _mm_madd_epi16(_pA1, _pB1)); + _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2)); + _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); +#endif +#endif // __AVX2__ + + pA += 8; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + +#if __AVX2__ + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1); + + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); +#else // __AVX2__ +#if __XOP__ + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_unpackhi_epi16(_pB, _pB); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA0, _pB2, _sum2); + _sum3 = _mm_maccd_epi16(_pA0, _pB3, _sum3); + _sum4 = _mm_maccd_epi16(_pA1, _pB0, _sum4); + _sum5 = _mm_maccd_epi16(_pA1, _pB1, _sum5); + _sum6 = _mm_maccd_epi16(_pA1, _pB2, _sum6); + _sum7 = _mm_maccd_epi16(_pA1, _pB3, _sum7); +#else + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB01 = _pB; + __m128i _pB23 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA0, _pB23); + __m128i _sh1 = _mm_mulhi_epi16(_pA0, _pB23); + __m128i _sl2 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh2 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _sl3 = _mm_mullo_epi16(_pA1, _pB23); + __m128i _sh3 = _mm_mulhi_epi16(_pA1, _pB23); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, 
_sh1); + __m128i _s4 = _mm_unpacklo_epi16(_sl2, _sh2); + __m128i _s5 = _mm_unpackhi_epi16(_sl2, _sh2); + __m128i _s6 = _mm_unpacklo_epi16(_sl3, _sh3); + __m128i _s7 = _mm_unpackhi_epi16(_sl3, _sh3); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); + _sum4 = _mm_add_epi32(_sum4, _s4); + _sum5 = _mm_add_epi32(_sum5, _s5); + _sum6 = _mm_add_epi32(_sum6, _s6); + _sum7 = _mm_add_epi32(_sum7, _s7); +#endif +#endif // __AVX2__ + + pA += 4; + pB += 8; + } + +#if __AVX2__ + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + // to + // 00 10 20 30 04 14 24 34 + // 01 11 21 31 05 15 25 35 + // 02 12 22 32 06 16 26 36 + // 03 13 23 33 07 17 27 37 + { + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp0 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + _tmp1 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + _tmp2 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + _tmp3 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + } + + _mm256_storeu_si256((__m256i*)outptr, _sum0); + _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1); + _mm256_storeu_si256((__m256i*)(outptr + 16), _sum2); + _mm256_storeu_si256((__m256i*)(outptr + 24), _sum3); + outptr += 32; +#else // __AVX2__ + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + // to + // 00 10 20 30 04 14 24 34 + // 01 11 21 31 05 15 25 35 + // 02 12 22 32 06 16 26 36 + // 03 13 23 33 07 17 27 37 + { + _sum2 = _mm_shuffle_epi32(_sum2, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum6 = _mm_shuffle_epi32(_sum6, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum6); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum6); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum7); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum7); + __m128i _tmp4 = _mm_unpacklo_epi32(_sum4, _sum2); + __m128i _tmp5 = _mm_unpackhi_epi32(_sum4, _sum2); + __m128i _tmp6 = _mm_unpacklo_epi32(_sum5, _sum3); + __m128i _tmp7 = _mm_unpackhi_epi32(_sum5, _sum3); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp4); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp4); + _sum2 = _mm_unpacklo_epi64(_tmp5, _tmp1); + _sum3 = _mm_unpackhi_epi64(_tmp5, _tmp1); + _sum4 = _mm_unpacklo_epi64(_tmp2, _tmp6); + _sum5 = _mm_unpackhi_epi64(_tmp2, _tmp6); + _sum6 = _mm_unpacklo_epi64(_tmp7, _tmp3); + _sum7 = _mm_unpackhi_epi64(_tmp7, _tmp3); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum5 = 
_mm_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + _mm_store_si128((__m128i*)(outptr + 16), _sum4); + _mm_store_si128((__m128i*)(outptr + 20), _sum5); + _mm_store_si128((__m128i*)(outptr + 24), _sum6); + _mm_store_si128((__m128i*)(outptr + 28), _sum7); + outptr += 32; +#endif // __AVX2__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); +#endif + + pA += 8; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); +#if __XOP__ + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maccd_epi16(_pA1, _pB1, _sum3); +#else + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif + + pA += 4; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 + // 01 12 23 30 + // 20 31 02 13 + // 21 32 03 10 + // to + // 00 10 20 30 + // 01 11 21 31 + // 02 12 22 32 + // 03 13 23 33 + { + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m128i _tmp0 = 
_mm_unpacklo_epi32(_sum0, _sum3); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum3); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum2, _sum1); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); +#endif + + pA += 8; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __XOP__ + _pA = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + _sum0 = _mm_maccd_epi16(_pA, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA, _pB1, _sum1); +#else + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 1, 0, 1)); + __m128i _sl = _mm_mullo_epi16(_pA, _pB01); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); +#endif + + pA += 4; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 + // 01 10 21 30 + // to + // 00 10 20 30 + // 01 11 21 31 + { + __m128i _tmp0 = _mm_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _tmp1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1)); + _sum0 = _mm_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + outptr += 8; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA, _pB, _sum0); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); +#endif + + pA += 8; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(pB[0]); + +#if __XOP__ + _pA = _mm_unpacklo_epi16(_pA, _pA); + _sum0 = 
_mm_maccd_epi16(_pA, _pB, _sum0); +#else + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); +#endif + + pA += 4; + pB += 1; + } + + _mm_store_si128((__m128i*)outptr, _sum0); + outptr += 4; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pA1 = _mm512_set1_epi32(((const int*)pA)[1]); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA1, _pB0); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA1, _pB0)); +#endif // __AVX512VNNI__ + + pA += 4; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_set1_epi32(pA[0]); + __m512i _pA1 = _mm512_set1_epi32(pA[1]); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA1, _pB0); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 2; + pB += 16; + } + + if (k_end) + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(1, 0, 1, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_sum0, _sum0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_sum1, _sum1, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + outptr += 32; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; +#else + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); + _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8)); +#else + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_loadu_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_loadu_si128((const __m128i*)(outptr + 12)); +#endif + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m256i _pA0 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pA)); + __m256i _pA1 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(pA + 2))); + __m256i _pB0 = _mm256_loadu_si256((const 
__m256i*)pB); + + // vs2019 internal compiler error with avx512 vnni intrinsics here + // fallback to avx2 madd anyway as a workaround --- nihui + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA1, _pB0)); +#else // __AVX2__ + __m128i _pA0 = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pA1 = _mm_castps_si128(_mm_load1_ps((const float*)(pA + 2))); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); +#endif // __AVX2__ + + pA += 4; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pB = _mm_load_si128((const __m128i*)pB); +#if __AVX2__ + __m256i _pA0 = _mm256_set1_epi32(pA[0]); + __m256i _pA1 = _mm256_set1_epi32(pA[1]); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA1, _pB0); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); +#else // __AVX2__ + __m128i _pA0 = _mm_set1_epi16(pA[0]); + __m128i _pA1 = _mm_set1_epi16(pA[1]); + + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif // __AVX2__ + pA += 2; + pB += 8; + } + +#if __AVX2__ + if (k_end) + { + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum1); + _sum0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); + } + + _mm256_storeu_si256((__m256i*)outptr, _sum0); + _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1); + outptr += 16; +#else // __AVX2__ + if (k_end) + { + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum2); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum2); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum3); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum3); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + _mm_storeu_si128((__m128i*)(outptr + 8), _sum2); + _mm_storeu_si128((__m128i*)(outptr + 12), _sum3); + outptr += 16; +#endif // __AVX2__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pA1 = _mm_castps_si128(_mm_load1_ps((const float*)(pA + 2))); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB)); 
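// A minimal scalar sketch (reference only, not part of the patch; the helper
// name is illustrative) of what one _mm_madd_epi16(_pA, _pB) step above
// contributes: each 32-bit lane n receives a[2n]*b[2n] + a[2n+1]*b[2n+1],
// which is exactly why pA and pB are packed with pairs of consecutive kk
// values interleaved.
static inline void madd_epi16_reference(const short a[8], const short b[8], int sum[4])
{
    for (int n = 0; n < 4; n++)
    {
        // multiply adjacent int16 pairs and accumulate into one int32 lane
        sum[n] += (int)a[2 * n] * b[2 * n] + (int)a[2 * n + 1] * b[2 * n + 1];
    }
}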
+ _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA1, _pB)); + + pA += 4; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA0 = _mm_set1_epi16(pA[0]); + __m128i _pA1 = _mm_set1_epi16(pA[1]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpacklo_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + pA += 2; + pB += 4; + } + + if (k_end) + { + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum1); + _sum0 = _tmp0; + _sum1 = _tmp1; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + outptr += 2 * 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum00 = 0; + int sum01 = 0; + int sum10 = 0; + int sum11 = 0; + + if (k == 0) + { + sum00 = 0; + sum01 = 0; + sum10 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum00 += pA[0] * pB[0]; + sum00 += pA[1] * pB[1]; + sum01 += pA[2] * pB[0]; + sum01 += pA[3] * pB[1]; + sum10 += pA[0] * pB[2]; + sum10 += pA[1] * pB[3]; + sum11 += pA[2] * pB[2]; + sum11 += pA[3] * pB[3]; + + pA += 4; + pB += 4; + } + for (; kk < max_kk; kk++) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + outptr += 2 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[2] * pB[0]; + sum1 += pA[3] * pB[1]; + pA += 4; + pB += 2; + } + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + } + } + for (; ii < max_ii; ii++) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); +#endif + + pA += 2; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_set1_epi32(pA[0]); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_loadu_si256((const __m256i*)pB)); + + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB0); + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 1; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + outptr += 16; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; 
jj += 8) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); + + pA += 2; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + pA += 1; + pB += 8; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + outptr += 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); + + pA += 2; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + pA += 1; + pB += 4; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + outptr += 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[0] * pB[2]; + sum1 += pA[1] * pB[3]; + pA += 2; + pB += 4; + } + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum = 0; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + outptr += 1; + } + } + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size_int8 = (int)(get_cpu_level2_cache_size() / sizeof(short)); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + // solve M + { + int tile_size = (int)sqrt((float)l2_cache_size_int8 / 3); + +#if __AVX512F__ + TILE_M = std::max(16, tile_size / 16 * 16); +#elif __AVX2__ + TILE_M = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_M = std::max(4, tile_size / 4 * 4); +#else + TILE_M = std::max(2, tile_size / 2 * 
2); +#endif + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __AVX512F__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); +#elif __AVX2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + + if (nT > 1) + { +#if __AVX512F__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); +#elif __AVX2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + } + + // solve K + { + int tile_size = (int)(sqrt((float)l2_cache_size_int8) - TILE_M); + +#if __AVX512F__ + TILE_K = std::max(16, tile_size / 16 * 16); +#elif __AVX2__ + TILE_K = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_K = std::max(4, tile_size / 4 * 4); +#else + TILE_K = std::max(2, tile_size / 2 * 2); +#endif + + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __AVX512F__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); +#elif __AVX2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __SSE2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + } + + if (N > 0) + { + int tile_size = (int)((l2_cache_size_int8 - TILE_M * TILE_K) / (TILE_M * 2 + TILE_K)); + +#if __SSE2__ + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_N = std::max(1, tile_size); +#endif + + int nn_N = (N + TILE_N - 1) / TILE_N; +#if __SSE2__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + } +} + +static inline void conv3x3s1_winograd23_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const signed char ktm[4][3] = { + // {2, 0, 0}, + // {1, 1, 1}, + // {1, -1, 1}, + // {0, 0, 2} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[4][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 2; + tmp[1][m] = r0 + r1 + r2; + tmp[2][m] = r0 - r1 + r2; + tmp[3][m] = r2 * 2; + + k0 += 3; + } + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 2; + short z1 = r0 + r1 + r2; + short z2 = r0 - r1 + r2; + short z3 = r2 * 2; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp += 4; + } + } + } +} + +static void conv3x3s1_winograd23_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd23_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); + return; + } +#endif +#endif + + const int M = outch; + const int K = inch; + const int B = 16; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M 
- 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 2u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd23_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd23_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const signed char itm[4][4] = { + // {1, 0, -1, 0}, + // {0, 1, 1, 0}, + // {0, -1, 1, 0}, + // {0, -1, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w - 1) / 2; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __SSE2__ +#if __AVX512F__ + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + short tmp[4][4][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_setzero_si256(); + __m256i _r1 = _mm256_setzero_si256(); + __m256i _r2 = _mm256_setzero_si256(); + __m256i _r3 = _mm256_setzero_si256(); + + if (ti * 2 + m < h) + { + if (elempack == 16) + { + _r0 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)r0)); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 16))); + if (tj * 2 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 32))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 48))); + } + if (elempack == 8) + { + const signed char* r1 = r0 + N; + + _r0 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)r0), _mm_loadl_epi64((const __m128i*)r1))); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 8)), _mm_loadl_epi64((const __m128i*)(r1 + 8)))); + if (tj * 2 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 16)), _mm_loadl_epi64((const __m128i*)(r1 + 16)))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 24)), _mm_loadl_epi64((const __m128i*)(r1 + 24)))); + } + if (elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _r0 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)r0, sizeof(signed char)))); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 1), sizeof(signed char)))); + if (tj * 2 + 2 < w) _r2 = 
_mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 2), sizeof(signed char)))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 3), sizeof(signed char)))); + } + } + + __m256i _tmp0 = _mm256_sub_epi16(_r0, _r2); + __m256i _tmp1 = _mm256_add_epi16(_r1, _r2); + __m256i _tmp2 = _mm256_sub_epi16(_r2, _r1); + __m256i _tmp3 = _mm256_sub_epi16(_r3, _r1); + + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 16; + short* p1 = p0 + max_jj * 16; + short* p2 = p0 + max_jj * 16 * 2; + short* p3 = p0 + max_jj * 16 * 3; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + + __m256i _tmp0 = _mm256_sub_epi16(_r0, _r2); + __m256i _tmp1 = _mm256_add_epi16(_r1, _r2); + __m256i _tmp2 = _mm256_sub_epi16(_r2, _r1); + __m256i _tmp3 = _mm256_sub_epi16(_r3, _r1); + + _mm256_store_si256((__m256i*)p0, _tmp0); + _mm256_store_si256((__m256i*)p1, _tmp1); + _mm256_store_si256((__m256i*)p2, _tmp2); + _mm256_store_si256((__m256i*)p3, _tmp3); + + p0 += max_jj * 4 * 16; + p1 += max_jj * 4 * 16; + p2 += max_jj * 4 * 16; + p3 += max_jj * 4 * 16; + } + } + } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) +#endif // __AVX512F__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + short tmp[4][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + __m128i _r0 = _mm_setzero_si128(); + __m128i _r1 = _mm_setzero_si128(); + __m128i _r2 = _mm_setzero_si128(); + __m128i _r3 = _mm_setzero_si128(); + + if (ti * 2 + m < h) + { + if (elempack == 8) + { + _r0 = _mm_loadl_epi64((const __m128i*)r0); + _r0 = _mm_unpacklo_epi8(_r0, _mm_cmpgt_epi8(_mm_setzero_si128(), _r0)); + if (tj * 2 + 1 < w) + { + _r1 = _mm_loadl_epi64((const __m128i*)(r0 + 8)); + _r1 = _mm_unpacklo_epi8(_r1, _mm_cmpgt_epi8(_mm_setzero_si128(), _r1)); + } + if (tj * 2 + 2 < w) + { + _r2 = _mm_loadl_epi64((const __m128i*)(r0 + 16)); + _r2 = _mm_unpacklo_epi8(_r2, _mm_cmpgt_epi8(_mm_setzero_si128(), _r2)); + } + if (tj * 2 + 3 < w) + { + _r3 = _mm_loadl_epi64((const __m128i*)(r0 + 24)); + _r3 = _mm_unpacklo_epi8(_r3, _mm_cmpgt_epi8(_mm_setzero_si128(), _r3)); + } + } + if (elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); +#if __AVX512F__ + _r0 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)))); + if (tj * 2 + 1 < w) _r1 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)))); + if (tj * 2 + 2 
< w) _r2 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)))); + if (tj * 2 + 3 < w) _r3 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)))); +#else + __m128i _sindex8 = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m256i _sindex88 = _mm256_inserti128_si256(_mm256_castsi128_si256(_sindex8), _sindex8, 1); + __m256i _val0_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)), _sindex88); + _r0 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val0_32, 0), _mm256_extracti128_si256(_val0_32, 1))); + if (tj * 2 + 1 < w) + { + __m256i _val1_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)), _sindex88); + _r1 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val1_32, 0), _mm256_extracti128_si256(_val1_32, 1))); + } + if (tj * 2 + 2 < w) + { + __m256i _val2_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)), _sindex88); + _r2 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val2_32, 0), _mm256_extracti128_si256(_val2_32, 1))); + } + if (tj * 2 + 3 < w) + { + __m256i _val3_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)), _sindex88); + _r3 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val3_32, 0), _mm256_extracti128_si256(_val3_32, 1))); + } +#endif // __AVX512F__ +#else // __AVX2__ + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + __m128i _t0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _t1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _t2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _t3 = _mm_loadl_epi64((const __m128i*)r3); + __m128i _t4 = _mm_loadl_epi64((const __m128i*)r4); + __m128i _t5 = _mm_loadl_epi64((const __m128i*)r5); + __m128i _t6 = _mm_loadl_epi64((const __m128i*)r6); + __m128i _t7 = _mm_loadl_epi64((const __m128i*)r7); + + __m128i _t01 = _mm_unpacklo_epi8(_t0, _t1); + __m128i _t23 = _mm_unpacklo_epi8(_t2, _t3); + __m128i _t45 = _mm_unpacklo_epi8(_t4, _t5); + __m128i _t67 = _mm_unpacklo_epi8(_t6, _t7); + _t0 = _mm_unpacklo_epi16(_t01, _t23); + _t1 = _mm_unpacklo_epi16(_t45, _t67); + _t2 = _mm_unpacklo_epi32(_t0, _t1); + _t3 = _mm_unpackhi_epi32(_t0, _t1); + + __m128i _extt2 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t2); + __m128i _extt3 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t3); + + _r0 = _mm_unpacklo_epi8(_t2, _extt2); + if (tj * 2 + 1 < w) _r1 = _mm_unpackhi_epi8(_t2, _extt2); + if (tj * 2 + 2 < w) _r2 = _mm_unpacklo_epi8(_t3, _extt3); + if (tj * 2 + 3 < w) _r3 = _mm_unpackhi_epi8(_t3, _extt3); +#endif // __AVX2__ + } + } + + __m128i _tmp0 = _mm_sub_epi16(_r0, _r2); + __m128i _tmp1 = _mm_add_epi16(_r1, _r2); + __m128i _tmp2 = _mm_sub_epi16(_r2, _r1); + __m128i _tmp3 = _mm_sub_epi16(_r3, _r1); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + // old gcc breaks stack variable alignement + // ref https://gcc.gnu.org/bugzilla/show_bug.cgi?id=16660 + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], 
_tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); +#endif + + __m128i _tmp0 = _mm_sub_epi16(_r0, _r2); + __m128i _tmp1 = _mm_add_epi16(_r1, _r2); + __m128i _tmp2 = _mm_sub_epi16(_r2, _r1); + __m128i _tmp3 = _mm_sub_epi16(_r3, _r1); + + _mm_store_si128((__m128i*)p0, _tmp0); + _mm_store_si128((__m128i*)p1, _tmp1); + _mm_store_si128((__m128i*)p2, _tmp2); + _mm_store_si128((__m128i*)p3, _tmp3); + + p0 += max_jj * 4 * 8; + p1 += max_jj * 4 * 8; + p2 += max_jj * 4 * 8; + p3 += max_jj * 4 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __SSE2__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[4][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 2 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 2 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 2 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + } + } + + tmp[0][m][0] = r00 - r20; + tmp[0][m][1] = r01 - r21; + tmp[1][m][0] = r10 + r20; + tmp[1][m][1] = r11 + r21; + tmp[2][m][0] = r20 - r10; + tmp[2][m][1] = r21 - r11; + tmp[3][m][0] = r30 - r10; + tmp[3][m][1] = r31 - r11; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + + p0[0] = r00 - r20; + p0[1] = r01 - r21; + p1[0] = r10 + r20; + p1[1] = r11 + r21; + p2[0] = r20 - r10; + p2[1] = r21 - r11; + p3[0] = r30 - r10; + p3[1] = r31 - r11; + + p0 += max_jj * 4 * 2; + p1 += max_jj * 4 * 2; + p2 += max_jj * 4 * 2; + p3 += max_jj * 4 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short 
tmp[4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 2 + 1 < w) r1 = r0123[1]; + if (tj * 2 + 2 < w) r2 = r0123[2]; + if (tj * 2 + 3 < w) r3 = r0123[3]; + } + } + + tmp[0][m] = r0 - r2; + tmp[1][m] = r1 + r2; + tmp[2][m] = r2 - r1; + tmp[3][m] = r3 - r1; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + + p0[0] = r0 - r2; + p1[0] = r1 + r2; + p2[0] = r2 - r1; + p3[0] = r3 - r1; + + p0 += max_jj * 4; + p1 += max_jj * 4; + p2 += max_jj * 4; + p3 += max_jj * 4; + } + } + } +} + +static inline void conv3x3s1_winograd23_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[2][4] = { + // {1, 1, 1, 0}, + // {0, 1, -1, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 1) / 2; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + int tmp[2][4][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 16; + const int* r1 = r0 + max_jj * 16; + const int* r2 = r0 + max_jj * 16 * 2; + const int* r3 = r0 + max_jj * 16 * 3; + + for (int m = 0; m < 4; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_r0, _r1), _r2); + __m512i _tmp1 = _mm512_add_epi32(_mm512_sub_epi32(_r1, _r2), _r3); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + + r0 += max_jj * 4 * 16; + r1 += max_jj * 4 * 16; + r2 += max_jj * 4 * 16; + r3 += max_jj * 4 * 16; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + __m512i _r0 = _mm512_load_si512((const __m512i*)tmp[m][0]); + __m512i _r1 = _mm512_load_si512((const __m512i*)tmp[m][1]); + __m512i _r2 = _mm512_load_si512((const __m512i*)tmp[m][2]); + __m512i _r3 = _mm512_load_si512((const __m512i*)tmp[m][3]); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_r0, _r1), _r2); + __m512i _tmp1 = _mm512_add_epi32(_mm512_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm512_srai_epi32(_tmp0, 2); + _tmp1 = _mm512_srai_epi32(_tmp1, 2); + + if (out_elempack == 16) + { + _mm512_store_si512((__m512i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) + { + _mm512_store_si512((__m512i*)(outptr0 + 16), _tmp1); + } + } + if (out_elempack == 8) + { + int* outptr1 = outptr0 + N; + + _mm256_store_si256((__m256i*)outptr0, 
_mm512_extracti32x8_epi32(_tmp0, 0)); + _mm256_store_si256((__m256i*)outptr1, _mm512_extracti32x8_epi32(_tmp0, 1)); + if (tj * 2 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _mm512_extracti32x8_epi32(_tmp1, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 8), _mm512_extracti32x8_epi32(_tmp1, 1)); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + _mm_store_si128((__m128i*)outptr0, _mm512_extracti32x4_epi32(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm512_extracti32x4_epi32(_tmp0, 1)); + _mm_store_si128((__m128i*)outptr2, _mm512_extracti32x4_epi32(_tmp0, 2)); + _mm_store_si128((__m128i*)outptr3, _mm512_extracti32x4_epi32(_tmp0, 3)); + if (tj * 2 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm512_extracti32x4_epi32(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm512_extracti32x4_epi32(_tmp1, 1)); + _mm_store_si128((__m128i*)(outptr2 + 4), _mm512_extracti32x4_epi32(_tmp1, 2)); + _mm_store_si128((__m128i*)(outptr3 + 4), _mm512_extracti32x4_epi32(_tmp1, 3)); + } + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _mm512_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm512_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + int tmp[2][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_r0, _r1), _r2); + __m256i _tmp1 = _mm256_add_epi32(_mm256_sub_epi32(_r1, _r2), _r3); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); +#endif + + r0 += max_jj * 4 * 8; + r1 += max_jj * 4 * 8; + r2 += max_jj * 4 * 8; + r3 += max_jj * 4 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m256i _r0 = _mm256_loadu_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)tmp[m][3]); +#else + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); +#endif + + __m256i _tmp0 = 
_mm256_add_epi32(_mm256_add_epi32(_r0, _r1), _r2); + __m256i _tmp1 = _mm256_add_epi32(_mm256_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm256_srai_epi32(_tmp0, 2); + _tmp1 = _mm256_srai_epi32(_tmp1, 2); + + if (out_elempack == 8) + { + _mm256_store_si256((__m256i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _tmp1); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + _mm_store_si128((__m128i*)outptr0, _mm256_extracti128_si256(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm256_extracti128_si256(_tmp0, 1)); + if (tj * 2 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm256_extracti128_si256(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm256_extracti128_si256(_tmp1, 1)); + } + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); + _mm256_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm256_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); +#else + int tmp0[8]; + int tmp1[8]; + _mm256_storeu_si256((__m256i*)tmp0, _tmp0); + _mm256_storeu_si256((__m256i*)tmp1, _tmp1); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + outptr4[0] = tmp0[4]; + outptr5[0] = tmp0[5]; + outptr6[0] = tmp0[6]; + outptr7[0] = tmp0[7]; + + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + outptr4[1] = tmp1[4]; + outptr5[1] = tmp1[5]; + outptr6[1] = tmp1[6]; + outptr7[1] = tmp1[7]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + int tmp[2][4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + + for (int m = 0; m < 4; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_r0, _r1), _r2); + __m128i _tmp1 = _mm_add_epi32(_mm_sub_epi32(_r1, _r2), _r3); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); +#endif + + r0 += max_jj * 4 * 4; + r1 += max_jj * 4 * 4; + r2 += max_jj * 4 * 4; + r3 += max_jj * 4 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = 
_mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); +#endif + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_r0, _r1), _r2); + __m128i _tmp1 = _mm_add_epi32(_mm_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm_srai_epi32(_tmp0, 2); + _tmp1 = _mm_srai_epi32(_tmp1, 2); + + if (out_elempack == 4) + { + _mm_store_si128((__m128i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) _mm_store_si128((__m128i*)(outptr0 + 4), _tmp1); + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(N)); + _mm_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); +#else + int tmp0[4]; + int tmp1[4]; + _mm_storeu_si128((__m128i*)tmp0, _tmp0); + _mm_storeu_si128((__m128i*)tmp1, _tmp1); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[2][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m][0] = r0[0] + r1[0] + r2[0]; + tmp[0][m][1] = r0[1] + r1[1] + r2[1]; + tmp[1][m][0] = r1[0] - r2[0] + r3[0]; + tmp[1][m][1] = r1[1] - r2[1] + r3[1]; + + r0 += max_jj * 4 * 2; + r1 += max_jj * 4 * 2; + r2 += max_jj * 4 * 2; + r3 += max_jj * 4 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp00 = tmp[m][0][0] + tmp[m][1][0] + tmp[m][2][0]; + int tmp01 = tmp[m][0][1] + tmp[m][1][1] + tmp[m][2][1]; + int tmp10 = tmp[m][1][0] - tmp[m][2][0] + tmp[m][3][0]; + int tmp11 = tmp[m][1][1] - tmp[m][2][1] + tmp[m][3][1]; + + tmp00 = tmp00 >> 2; + tmp01 = tmp01 >> 2; + tmp10 = tmp10 >> 2; + tmp11 = tmp11 >> 2; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[2][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m] = r0[0] + r1[0] + r2[0]; + tmp[1][m] = r1[0] - r2[0] + r3[0]; + + r0 += max_jj * 4; + r1 += max_jj * 4; + r2 += max_jj * 4; + r3 += max_jj * 4; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + 
continue; + + int tmp0 = tmp[m][0] + tmp[m][1] + tmp[m][2]; + int tmp1 = tmp[m][1] - tmp[m][2] + tmp[m][3]; + + tmp0 = tmp0 >> 2; + tmp1 = tmp1 >> 2; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 2 + 1 < outw) outptr0[1] = tmp1; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + conv3x3s1_winograd23_int8_avx512vnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + conv3x3s1_winograd23_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd23_int8_avx2(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ + if (ncnn::cpu_support_x86_xop()) + { + conv3x3s1_winograd23_int8_xop(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif +#endif + + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 2n+2, winograd F(2,3) + int w_tiles = (outw + 1) / 2; + int h_tiles = (outh + 1) / 2; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 16; + + // NCNN_LOGE("conv3x3s1_winograd23_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 2u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 2u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 
4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + bool k_end = k + TILE_K >= K; + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk, k_end); + } + + // transform output + conv3x3s1_winograd23_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} + +static inline void conv3x3s1_winograd43_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const short ktm[6][3] = { + // {6, 0, 0}, + // {-4, -4, -4}, + // {-4, 4, -4}, + // {1, 2, 4}, + // {1, -2, 4}, + // {0, 0, 6} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[6][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 6; + tmp[1][m] = -r0 * 4 - r1 * 4 - r2 * 4; + tmp[2][m] = -r0 * 4 + r1 * 4 - r2 * 4; + tmp[3][m] = r0 + r1 * 2 + r2 * 4; + tmp[4][m] = r0 - r1 * 2 + r2 * 4; + tmp[5][m] = r2 * 6; + + k0 += 3; + } + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 6; + short z1 = -r0 * 4 - r1 * 4 - r2 * 4; + short z2 = -r0 * 4 + r1 * 4 - r2 * 4; + short z3 = r0 + r1 * 2 + r2 * 4; + short z4 = r0 - r1 * 2 + r2 * 4; + short z5 = r2 * 6; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp[4] = z4; + ptmp[5] = z5; + ptmp += 6; + } + } + } +} + +static void conv3x3s1_winograd43_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd43_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); + return; + } +#endif +#endif + + const int M = outch; + const int K = inch; + const int B = 36; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 4u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 4u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd43_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd43_transform_input_tile_int8(const 
Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const float itm[4][4] = { + // {4, 0, -5, 0, 1, 0}, + // {0, -4, -4, 1, 1, 0}, + // {0, 4, -4, -1, 1, 0}, + // {0, -2, -1, 2, 1, 0}, + // {0, 2, -1, -2, 1, 0}, + // {0, 4, 0, -5, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w + 1) / 4; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __SSE2__ +#if __AVX512F__ + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + short tmp[6][6][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + __m256i _v2 = _mm256_set1_epi16(2); + __m256i _v4 = _mm256_set1_epi16(4); + __m256i _v5 = _mm256_set1_epi16(5); + + for (int m = 0; m < 6; m++) + { + __m256i _r0 = _mm256_setzero_si256(); + __m256i _r1 = _mm256_setzero_si256(); + __m256i _r2 = _mm256_setzero_si256(); + __m256i _r3 = _mm256_setzero_si256(); + __m256i _r4 = _mm256_setzero_si256(); + __m256i _r5 = _mm256_setzero_si256(); + + if (ti * 4 + m < h) + { + if (elempack == 16) + { + _r0 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)r0)); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 16))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 32))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 48))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 64))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 80))); + } + if (elempack == 8) + { + const signed char* r1 = r0 + N; + + _r0 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)r0), _mm_loadl_epi64((const __m128i*)r1))); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 8)), _mm_loadl_epi64((const __m128i*)(r1 + 8)))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 16)), _mm_loadl_epi64((const __m128i*)(r1 + 16)))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 24)), _mm_loadl_epi64((const __m128i*)(r1 + 24)))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 32)), _mm_loadl_epi64((const __m128i*)(r1 + 32)))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 40)), _mm_loadl_epi64((const __m128i*)(r1 + 40)))); + } + if (elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _r0 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)r0, sizeof(signed char)))); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 1), sizeof(signed char)))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 
2), sizeof(signed char)))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 3), sizeof(signed char)))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 4), sizeof(signed char)))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 5), sizeof(signed char)))); + } + } + + __m256i _tmp12a = _mm256_sub_epi16(_r3, _mm256_mullo_epi16(_r1, _v4)); + __m256i _tmp12b = _mm256_sub_epi16(_r4, _mm256_mullo_epi16(_r2, _v4)); + __m256i _tmp34a = _mm256_mullo_epi16(_mm256_sub_epi16(_r3, _r1), _v2); + __m256i _tmp34b = _mm256_sub_epi16(_r4, _r2); + + __m256i _tmp0 = _mm256_add_epi16(_r4, _mm256_sub_epi16(_mm256_mullo_epi16(_r0, _v4), _mm256_mullo_epi16(_r2, _v5))); + __m256i _tmp1 = _mm256_add_epi16(_tmp12b, _tmp12a); + __m256i _tmp2 = _mm256_sub_epi16(_tmp12b, _tmp12a); + __m256i _tmp3 = _mm256_add_epi16(_tmp34b, _tmp34a); + __m256i _tmp4 = _mm256_sub_epi16(_tmp34b, _tmp34a); + __m256i _tmp5 = _mm256_add_epi16(_r5, _mm256_sub_epi16(_mm256_mullo_epi16(_r1, _v4), _mm256_mullo_epi16(_r3, _v5))); + + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); + _mm256_store_si256((__m256i*)tmp[4][m], _tmp4); + _mm256_store_si256((__m256i*)tmp[5][m], _tmp5); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 16; + short* p1 = p0 + max_jj * 16; + short* p2 = p0 + max_jj * 16 * 2; + short* p3 = p0 + max_jj * 16 * 3; + short* p4 = p0 + max_jj * 16 * 4; + short* p5 = p0 + max_jj * 16 * 5; + + for (int m = 0; m < 6; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_load_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_load_si256((const __m256i*)tmp[m][5]); + + __m256i _tmp12a = _mm256_sub_epi16(_r3, _mm256_mullo_epi16(_r1, _v4)); + __m256i _tmp12b = _mm256_sub_epi16(_r4, _mm256_mullo_epi16(_r2, _v4)); + __m256i _tmp34a = _mm256_mullo_epi16(_mm256_sub_epi16(_r3, _r1), _v2); + __m256i _tmp34b = _mm256_sub_epi16(_r4, _r2); + + __m256i _tmp0 = _mm256_add_epi16(_r4, _mm256_sub_epi16(_mm256_mullo_epi16(_r0, _v4), _mm256_mullo_epi16(_r2, _v5))); + __m256i _tmp1 = _mm256_add_epi16(_tmp12b, _tmp12a); + __m256i _tmp2 = _mm256_sub_epi16(_tmp12b, _tmp12a); + __m256i _tmp3 = _mm256_add_epi16(_tmp34b, _tmp34a); + __m256i _tmp4 = _mm256_sub_epi16(_tmp34b, _tmp34a); + __m256i _tmp5 = _mm256_add_epi16(_r5, _mm256_sub_epi16(_mm256_mullo_epi16(_r1, _v4), _mm256_mullo_epi16(_r3, _v5))); + + _mm256_store_si256((__m256i*)p0, _tmp0); + _mm256_store_si256((__m256i*)p1, _tmp1); + _mm256_store_si256((__m256i*)p2, _tmp2); + _mm256_store_si256((__m256i*)p3, _tmp3); + _mm256_store_si256((__m256i*)p4, _tmp4); + _mm256_store_si256((__m256i*)p5, _tmp5); + + p0 += max_jj * 6 * 16; + p1 += max_jj * 6 * 16; + p2 += max_jj * 6 * 16; + p3 += max_jj * 6 * 16; + p4 += max_jj * 6 * 16; + p5 += max_jj * 6 * 16; + } + } + } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) +#endif // __AVX512F__ + 
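+    // scalar reference for the B^T row transform applied by every lane-width path in
+    // this function (r0..r5 are the six inputs of one row, tmp0..tmp5 the outputs,
+    // i.e. the rows of the 6x6 matrix in the itm comment above):
+    //   tmp0 =  4*r0        - 5*r2        + r4
+    //   tmp1 =       - 4*r1 - 4*r2 +   r3 + r4
+    //   tmp2 =         4*r1 - 4*r2 -   r3 + r4
+    //   tmp3 =       - 2*r1 -   r2 + 2*r3 + r4
+    //   tmp4 =         2*r1 -   r2 - 2*r3 + r4
+    //   tmp5 =         4*r1        - 5*r3      + r5
+    // the tmp12a/tmp12b and tmp34a/tmp34b temporaries are the shared subexpressions
+    // of rows 1..4, so each +/- pair of rows costs only one extra add or sub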
for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + short tmp[6][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + __m128i _v2 = _mm_set1_epi16(2); + __m128i _v4 = _mm_set1_epi16(4); + __m128i _v5 = _mm_set1_epi16(5); + + for (int m = 0; m < 6; m++) + { + __m128i _r0 = _mm_setzero_si128(); + __m128i _r1 = _mm_setzero_si128(); + __m128i _r2 = _mm_setzero_si128(); + __m128i _r3 = _mm_setzero_si128(); + __m128i _r4 = _mm_setzero_si128(); + __m128i _r5 = _mm_setzero_si128(); + + if (ti * 4 + m < h) + { + if (elempack == 8) + { + _r0 = _mm_loadl_epi64((const __m128i*)r0); + _r0 = _mm_unpacklo_epi8(_r0, _mm_cmpgt_epi8(_mm_setzero_si128(), _r0)); + if (tj * 4 + 1 < w) + { + _r1 = _mm_loadl_epi64((const __m128i*)(r0 + 8)); + _r1 = _mm_unpacklo_epi8(_r1, _mm_cmpgt_epi8(_mm_setzero_si128(), _r1)); + } + if (tj * 4 + 2 < w) + { + _r2 = _mm_loadl_epi64((const __m128i*)(r0 + 16)); + _r2 = _mm_unpacklo_epi8(_r2, _mm_cmpgt_epi8(_mm_setzero_si128(), _r2)); + } + if (tj * 4 + 3 < w) + { + _r3 = _mm_loadl_epi64((const __m128i*)(r0 + 24)); + _r3 = _mm_unpacklo_epi8(_r3, _mm_cmpgt_epi8(_mm_setzero_si128(), _r3)); + } + if (tj * 4 + 4 < w) + { + _r4 = _mm_loadl_epi64((const __m128i*)(r0 + 32)); + _r4 = _mm_unpacklo_epi8(_r4, _mm_cmpgt_epi8(_mm_setzero_si128(), _r4)); + } + if (tj * 4 + 5 < w) + { + _r5 = _mm_loadl_epi64((const __m128i*)(r0 + 40)); + _r5 = _mm_unpacklo_epi8(_r5, _mm_cmpgt_epi8(_mm_setzero_si128(), _r5)); + } + } + if (elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); +#if __AVX512F__ + _r0 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)))); + if (tj * 4 + 1 < w) _r1 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)))); + if (tj * 4 + 2 < w) _r2 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)))); + if (tj * 4 + 3 < w) _r3 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)))); + if (tj * 4 + 4 < w) _r4 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 4), _vindex, sizeof(signed char)))); + if (tj * 4 + 5 < w) _r5 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 5), _vindex, sizeof(signed char)))); +#else + __m128i _sindex8 = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m256i _sindex88 = _mm256_inserti128_si256(_mm256_castsi128_si256(_sindex8), _sindex8, 1); + __m256i _val0_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)), _sindex88); + _r0 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val0_32, 0), _mm256_extracti128_si256(_val0_32, 1))); + if (tj * 4 + 1 < w) + { + __m256i _val1_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)), _sindex88); + _r1 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val1_32, 0), _mm256_extracti128_si256(_val1_32, 1))); + } + if (tj * 4 + 2 < w) + { + __m256i _val2_32 = 
_mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)), _sindex88); + _r2 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val2_32, 0), _mm256_extracti128_si256(_val2_32, 1))); + } + if (tj * 4 + 3 < w) + { + __m256i _val3_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)), _sindex88); + _r3 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val3_32, 0), _mm256_extracti128_si256(_val3_32, 1))); + } + if (tj * 4 + 4 < w) + { + __m256i _val4_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 4), _vindex, sizeof(signed char)), _sindex88); + _r4 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val4_32, 0), _mm256_extracti128_si256(_val4_32, 1))); + } + if (tj * 4 + 5 < w) + { + __m256i _val5_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 5), _vindex, sizeof(signed char)), _sindex88); + _r5 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val5_32, 0), _mm256_extracti128_si256(_val5_32, 1))); + } +#endif // __AVX512F__ +#else // __AVX2__ + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + __m128i _t0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _t1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _t2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _t3 = _mm_loadl_epi64((const __m128i*)r3); + __m128i _t4 = _mm_loadl_epi64((const __m128i*)r4); + __m128i _t5 = _mm_loadl_epi64((const __m128i*)r5); + __m128i _t6 = _mm_loadl_epi64((const __m128i*)r6); + __m128i _t7 = _mm_loadl_epi64((const __m128i*)r7); + + __m128i _t01 = _mm_unpacklo_epi8(_t0, _t1); + __m128i _t23 = _mm_unpacklo_epi8(_t2, _t3); + __m128i _t45 = _mm_unpacklo_epi8(_t4, _t5); + __m128i _t67 = _mm_unpacklo_epi8(_t6, _t7); + _t0 = _mm_unpacklo_epi16(_t01, _t23); + _t1 = _mm_unpacklo_epi16(_t45, _t67); + _t2 = _mm_unpacklo_epi32(_t0, _t1); + _t3 = _mm_unpackhi_epi32(_t0, _t1); + + __m128i _extt2 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t2); + __m128i _extt3 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t3); + + _r0 = _mm_unpacklo_epi8(_t2, _extt2); + if (tj * 4 + 1 < w) _r1 = _mm_unpackhi_epi8(_t2, _extt2); + if (tj * 4 + 2 < w) _r2 = _mm_unpacklo_epi8(_t3, _extt3); + if (tj * 4 + 3 < w) _r3 = _mm_unpackhi_epi8(_t3, _extt3); + if (tj * 4 + 4 < w) _r4 = _mm_setr_epi16(r0[4], r1[4], r2[4], r3[4], r4[4], r5[4], r6[4], r7[4]); + if (tj * 4 + 5 < w) _r5 = _mm_setr_epi16(r0[5], r1[5], r2[5], r3[5], r4[5], r5[5], r6[5], r7[5]); +#endif // __AVX2__ + } + } + + __m128i _tmp12a = _mm_sub_epi16(_r3, _mm_mullo_epi16(_r1, _v4)); + __m128i _tmp12b = _mm_sub_epi16(_r4, _mm_mullo_epi16(_r2, _v4)); + __m128i _tmp34a = _mm_mullo_epi16(_mm_sub_epi16(_r3, _r1), _v2); + __m128i _tmp34b = _mm_sub_epi16(_r4, _r2); + + __m128i _tmp0 = _mm_add_epi16(_r4, _mm_sub_epi16(_mm_mullo_epi16(_r0, _v4), _mm_mullo_epi16(_r2, _v5))); + __m128i _tmp1 = _mm_add_epi16(_tmp12b, _tmp12a); + __m128i _tmp2 = _mm_sub_epi16(_tmp12b, _tmp12a); + __m128i _tmp3 = _mm_add_epi16(_tmp34b, _tmp34a); + __m128i _tmp4 = _mm_sub_epi16(_tmp34b, _tmp34a); + __m128i _tmp5 = _mm_add_epi16(_r5, _mm_sub_epi16(_mm_mullo_epi16(_r1, _v4), _mm_mullo_epi16(_r3, _v5))); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + 
_mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); + _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4); + _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); + _mm_store_si128((__m128i*)tmp[4][m], _tmp4); + _mm_store_si128((__m128i*)tmp[5][m], _tmp5); +#endif + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + short* p4 = p0 + max_jj * 8 * 4; + short* p5 = p0 + max_jj * 8 * 5; + + for (int m = 0; m < 6; m++) + { +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_loadu_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_loadu_si128((const __m128i*)tmp[m][5]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_load_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_load_si128((const __m128i*)tmp[m][5]); +#endif + + __m128i _tmp12a = _mm_sub_epi16(_r3, _mm_mullo_epi16(_r1, _v4)); + __m128i _tmp12b = _mm_sub_epi16(_r4, _mm_mullo_epi16(_r2, _v4)); + __m128i _tmp34a = _mm_mullo_epi16(_mm_sub_epi16(_r3, _r1), _v2); + __m128i _tmp34b = _mm_sub_epi16(_r4, _r2); + + __m128i _tmp0 = _mm_add_epi16(_r4, _mm_sub_epi16(_mm_mullo_epi16(_r0, _v4), _mm_mullo_epi16(_r2, _v5))); + __m128i _tmp1 = _mm_add_epi16(_tmp12b, _tmp12a); + __m128i _tmp2 = _mm_sub_epi16(_tmp12b, _tmp12a); + __m128i _tmp3 = _mm_add_epi16(_tmp34b, _tmp34a); + __m128i _tmp4 = _mm_sub_epi16(_tmp34b, _tmp34a); + __m128i _tmp5 = _mm_add_epi16(_r5, _mm_sub_epi16(_mm_mullo_epi16(_r1, _v4), _mm_mullo_epi16(_r3, _v5))); + + _mm_store_si128((__m128i*)p0, _tmp0); + _mm_store_si128((__m128i*)p1, _tmp1); + _mm_store_si128((__m128i*)p2, _tmp2); + _mm_store_si128((__m128i*)p3, _tmp3); + _mm_store_si128((__m128i*)p4, _tmp4); + _mm_store_si128((__m128i*)p5, _tmp5); + + p0 += max_jj * 6 * 8; + p1 += max_jj * 6 * 8; + p2 += max_jj * 6 * 8; + p3 += max_jj * 6 * 8; + p4 += max_jj * 6 * 8; + p5 += max_jj * 6 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __SSE2__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[6][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + signed char r40 = 0; + signed char r41 = 0; + signed char r50 = 0; + signed char r51 = 0; + + if (ti * 4 + m 
< h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 4 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 4 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 4 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + if (tj * 4 + 4 < w) + { + r40 = r0[4]; + r41 = r1[4]; + } + if (tj * 4 + 5 < w) + { + r50 = r0[5]; + r51 = r1[5]; + } + } + } + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + tmp[0][m][0] = r40 + r00 * 4 - r20 * 5; + tmp[0][m][1] = r41 + r01 * 4 - r21 * 5; + tmp[1][m][0] = tmp120b + tmp120a; + tmp[1][m][1] = tmp121b + tmp121a; + tmp[2][m][0] = tmp120b - tmp120a; + tmp[2][m][1] = tmp121b - tmp121a; + tmp[3][m][0] = tmp340b + tmp340a; + tmp[3][m][1] = tmp341b + tmp341a; + tmp[4][m][0] = tmp340b - tmp340a; + tmp[4][m][1] = tmp341b - tmp341a; + tmp[5][m][0] = r50 + r10 * 4 - r30 * 5; + tmp[5][m][1] = r51 + r11 * 4 - r31 * 5; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + short* p4 = p0 + max_jj * 2 * 4; + short* p5 = p0 + max_jj * 2 * 5; + + for (int m = 0; m < 6; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + short r40 = tmp[m][4][0]; + short r41 = tmp[m][4][1]; + short r50 = tmp[m][5][0]; + short r51 = tmp[m][5][1]; + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + p0[0] = r40 + r00 * 4 - r20 * 5; + p0[1] = r41 + r01 * 4 - r21 * 5; + p1[0] = tmp120b + tmp120a; + p1[1] = tmp121b + tmp121a; + p2[0] = tmp120b - tmp120a; + p2[1] = tmp121b - tmp121a; + p3[0] = tmp340b + tmp340a; + p3[1] = tmp341b + tmp341a; + p4[0] = tmp340b - tmp340a; + p4[1] = tmp341b - tmp341a; + p5[0] = r50 + r10 * 4 - r30 * 5; + p5[1] = r51 + r11 * 4 - r31 * 5; + + p0 += max_jj * 6 * 2; + p1 += max_jj * 6 * 2; + p2 += max_jj * 6 * 2; + p3 += max_jj * 6 * 2; + p4 += max_jj * 6 * 2; + p5 += max_jj * 6 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[6][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + signed char r4 = 0; + signed char r5 = 0; + + if (ti * 4 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 4 + 1 < w) r1 = r0123[1]; + if (tj * 4 + 2 < w) r2 = r0123[2]; + if (tj * 4 + 3 < w) r3 = r0123[3]; + if (tj * 4 + 4 < w) r4 = r0123[4]; + if (tj * 4 + 5 < w) r5 = r0123[5]; + } + } + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + tmp[0][m] = r4 + r0 * 4 - r2 * 5; + tmp[1][m] = tmp12b + tmp12a; + tmp[2][m] = tmp12b - tmp12a; + tmp[3][m] = tmp34b + tmp34a; + tmp[4][m] = 
tmp34b - tmp34a; + tmp[5][m] = r5 + r1 * 4 - r3 * 5; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + short* p4 = p0 + max_jj * 4; + short* p5 = p0 + max_jj * 5; + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + short r4 = tmp[m][4]; + short r5 = tmp[m][5]; + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + p0[0] = r4 + r0 * 4 - r2 * 5; + p1[0] = tmp12b + tmp12a; + p2[0] = tmp12b - tmp12a; + p3[0] = tmp34b + tmp34a; + p4[0] = tmp34b - tmp34a; + p5[0] = r5 + r1 * 4 - r3 * 5; + + p0 += max_jj * 6; + p1 += max_jj * 6; + p2 += max_jj * 6; + p3 += max_jj * 6; + p4 += max_jj * 6; + p5 += max_jj * 6; + } + } + } +} + +static inline void conv3x3s1_winograd43_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[4][6] = { + // {1, 1, 1, 1, 1, 0}, + // {0, 1, -1, 2, -2, 0}, + // {0, 1, 1, 4, 4, 0}, + // {0, 1, -1, 8, -8, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 3) / 4; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + int tmp[4][6][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 16; + const int* r1 = r0 + max_jj * 16; + const int* r2 = r0 + max_jj * 16 * 2; + const int* r3 = r0 + max_jj * 16 * 3; + const int* r4 = r0 + max_jj * 16 * 4; + const int* r5 = r0 + max_jj * 16 * 5; + + for (int m = 0; m < 5; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + __m512i _r4 = _mm512_load_si512((const __m512i*)r4); + __m512i _r5 = _mm512_load_si512((const __m512i*)r5); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _mm512_slli_epi32(_r5, 2)); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + _mm512_store_si512((__m512i*)tmp[2][m], _tmp2); + _mm512_store_si512((__m512i*)tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 16; + r1 += max_jj * 6 * 16; + r2 += max_jj * 6 * 16; + r3 += max_jj * 6 * 16; + r4 += max_jj * 6 * 16; + r5 += max_jj * 6 * 16; + } + for (int m = 5; m < 6; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + __m512i _r4 = _mm512_load_si512((const __m512i*)r4); + __m512i _r5 = 
_mm512_load_si512((const __m512i*)r5); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _mm512_slli_epi32(_r5, 2)); + + _tmp0 = _mm512_slli_epi32(_tmp0, 2); + _tmp1 = _mm512_slli_epi32(_tmp1, 2); + _tmp2 = _mm512_slli_epi32(_tmp2, 2); + _tmp3 = _mm512_slli_epi32(_tmp3, 2); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + _mm512_store_si512((__m512i*)tmp[2][m], _tmp2); + _mm512_store_si512((__m512i*)tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 16; + r1 += max_jj * 6 * 16; + r2 += max_jj * 6 * 16; + r3 += max_jj * 6 * 16; + r4 += max_jj * 6 * 16; + r5 += max_jj * 6 * 16; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + __m512i _r0 = _mm512_load_si512((const __m512i*)tmp[m][0]); + __m512i _r1 = _mm512_load_si512((const __m512i*)tmp[m][1]); + __m512i _r2 = _mm512_load_si512((const __m512i*)tmp[m][2]); + __m512i _r3 = _mm512_load_si512((const __m512i*)tmp[m][3]); + __m512i _r4 = _mm512_load_si512((const __m512i*)tmp[m][4]); + __m512i _r5 = _mm512_load_si512((const __m512i*)tmp[m][5]); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m512 _v576 = _mm512_set1_ps(1.0 / 576); + _tmp0 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 16) + { + _mm512_store_si512((__m512i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm512_store_si512((__m512i*)(outptr0 + 16), _tmp1); + if (tj * 4 + 2 < outw) _mm512_store_si512((__m512i*)(outptr0 + 32), _tmp2); + if (tj * 4 + 3 < outw) _mm512_store_si512((__m512i*)(outptr0 + 48), _tmp3); + } + if (out_elempack == 8) + { + int* outptr1 = outptr0 + N; + + _mm256_store_si256((__m256i*)outptr0, _mm512_extracti32x8_epi32(_tmp0, 0)); + _mm256_store_si256((__m256i*)outptr1, _mm512_extracti32x8_epi32(_tmp0, 1)); + if (tj * 4 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _mm512_extracti32x8_epi32(_tmp1, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 8), _mm512_extracti32x8_epi32(_tmp1, 1)); + } + if (tj * 4 + 2 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 16), _mm512_extracti32x8_epi32(_tmp2, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 16), _mm512_extracti32x8_epi32(_tmp2, 1)); + } + if (tj * 4 + 3 < outw) + { 
+ _mm256_store_si256((__m256i*)(outptr0 + 24), _mm512_extracti32x8_epi32(_tmp3, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 24), _mm512_extracti32x8_epi32(_tmp3, 1)); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + _mm_store_si128((__m128i*)outptr0, _mm512_extracti32x4_epi32(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm512_extracti32x4_epi32(_tmp0, 1)); + _mm_store_si128((__m128i*)outptr2, _mm512_extracti32x4_epi32(_tmp0, 2)); + _mm_store_si128((__m128i*)outptr3, _mm512_extracti32x4_epi32(_tmp0, 3)); + if (tj * 4 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm512_extracti32x4_epi32(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm512_extracti32x4_epi32(_tmp1, 1)); + _mm_store_si128((__m128i*)(outptr2 + 4), _mm512_extracti32x4_epi32(_tmp1, 2)); + _mm_store_si128((__m128i*)(outptr3 + 4), _mm512_extracti32x4_epi32(_tmp1, 3)); + } + if (tj * 4 + 2 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 8), _mm512_extracti32x4_epi32(_tmp2, 0)); + _mm_store_si128((__m128i*)(outptr1 + 8), _mm512_extracti32x4_epi32(_tmp2, 1)); + _mm_store_si128((__m128i*)(outptr2 + 8), _mm512_extracti32x4_epi32(_tmp2, 2)); + _mm_store_si128((__m128i*)(outptr3 + 8), _mm512_extracti32x4_epi32(_tmp2, 3)); + } + if (tj * 4 + 3 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 12), _mm512_extracti32x4_epi32(_tmp3, 0)); + _mm_store_si128((__m128i*)(outptr1 + 12), _mm512_extracti32x4_epi32(_tmp3, 1)); + _mm_store_si128((__m128i*)(outptr2 + 12), _mm512_extracti32x4_epi32(_tmp3, 2)); + _mm_store_si128((__m128i*)(outptr3 + 12), _mm512_extracti32x4_epi32(_tmp3, 3)); + } + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _mm512_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm512_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm512_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm512_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + int tmp[4][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + const int* r4 = r0 + max_jj * 8 * 4; + const int* r5 = r0 + max_jj * 8 * 5; + + for (int m = 0; m < 5; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + __m256i _r4 = _mm256_load_si256((const __m256i*)r4); + __m256i _r5 = _mm256_load_si256((const __m256i*)r5); + + __m256i _tmp02a = _mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = 
_mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _mm256_slli_epi32(_r5, 2)); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_storeu_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_storeu_si256((__m256i*)tmp[3][m], _tmp3); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + for (int m = 5; m < 6; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + __m256i _r4 = _mm256_load_si256((const __m256i*)r4); + __m256i _r5 = _mm256_load_si256((const __m256i*)r5); + + __m256i _tmp02a = _mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = _mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _mm256_slli_epi32(_r5, 2)); + + _tmp0 = _mm256_slli_epi32(_tmp0, 2); + _tmp1 = _mm256_slli_epi32(_tmp1, 2); + _tmp2 = _mm256_slli_epi32(_tmp2, 2); + _tmp3 = _mm256_slli_epi32(_tmp3, 2); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_storeu_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_storeu_si256((__m256i*)tmp[3][m], _tmp3); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m256i _r0 = _mm256_loadu_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_loadu_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_loadu_si256((const __m256i*)tmp[m][5]); +#else + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_load_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_load_si256((const __m256i*)tmp[m][5]); +#endif + + __m256i _tmp02a = 
_mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = _mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m256 _v576 = _mm256_set1_ps(1.0 / 576); + _tmp0 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 8) + { + _mm256_store_si256((__m256i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm256_store_si256((__m256i*)(outptr0 + 8), _tmp1); + if (tj * 4 + 2 < outw) _mm256_store_si256((__m256i*)(outptr0 + 16), _tmp2); + if (tj * 4 + 3 < outw) _mm256_store_si256((__m256i*)(outptr0 + 24), _tmp3); + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + _mm_store_si128((__m128i*)(outptr0), _mm256_extracti128_si256(_tmp0, 0)); + _mm_store_si128((__m128i*)(outptr1), _mm256_extracti128_si256(_tmp0, 1)); + if (tj * 4 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm256_extracti128_si256(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm256_extracti128_si256(_tmp1, 1)); + } + if (tj * 4 + 2 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 8), _mm256_extracti128_si256(_tmp2, 0)); + _mm_store_si128((__m128i*)(outptr1 + 8), _mm256_extracti128_si256(_tmp2, 1)); + } + if (tj * 4 + 3 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 12), _mm256_extracti128_si256(_tmp3, 0)); + _mm_store_si128((__m128i*)(outptr1 + 12), _mm256_extracti128_si256(_tmp3, 1)); + } + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); + _mm256_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm256_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm256_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm256_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); +#else + int tmp0[8]; + int tmp1[8]; + int tmp2[8]; + int tmp3[8]; + _mm256_storeu_si256((__m256i*)tmp0, _tmp0); + _mm256_storeu_si256((__m256i*)tmp1, _tmp1); + _mm256_storeu_si256((__m256i*)tmp2, _tmp2); + _mm256_storeu_si256((__m256i*)tmp3, _tmp3); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + outptr4[0] = tmp0[4]; + outptr5[0] = tmp0[5]; + outptr6[0] = tmp0[6]; + outptr7[0] = tmp0[7]; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + outptr4[1] = tmp1[4]; + outptr5[1] = tmp1[5]; + outptr6[1] = tmp1[6]; + outptr7[1] = tmp1[7]; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp2[0]; + outptr1[2] = tmp2[1]; + outptr2[2] = tmp2[2]; + 
outptr3[2] = tmp2[3]; + outptr4[2] = tmp2[4]; + outptr5[2] = tmp2[5]; + outptr6[2] = tmp2[6]; + outptr7[2] = tmp2[7]; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp3[0]; + outptr1[3] = tmp3[1]; + outptr2[3] = tmp3[2]; + outptr3[3] = tmp3[3]; + outptr4[3] = tmp3[4]; + outptr5[3] = tmp3[5]; + outptr6[3] = tmp3[6]; + outptr7[3] = tmp3[7]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + int tmp[4][6][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + const int* r4 = r0 + max_jj * 4 * 4; + const int* r5 = r0 + max_jj * 4 * 5; + + for (int m = 0; m < 5; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + __m128i _r4 = _mm_load_si128((const __m128i*)r4); + __m128i _r5 = _mm_load_si128((const __m128i*)r5); + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _mm_slli_epi32(_r5, 2)); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + for (int m = 5; m < 6; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + __m128i _r4 = _mm_load_si128((const __m128i*)r4); + __m128i _r5 = _mm_load_si128((const __m128i*)r5); + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _mm_slli_epi32(_r5, 2)); + + _tmp0 = _mm_slli_epi32(_tmp0, 2); + _tmp1 = _mm_slli_epi32(_tmp1, 2); + _tmp2 = _mm_slli_epi32(_tmp2, 2); + _tmp3 = _mm_slli_epi32(_tmp3, 2); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + 
_mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_loadu_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_loadu_si128((const __m128i*)tmp[m][5]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_load_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_load_si128((const __m128i*)tmp[m][5]); +#endif + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m128 _v576 = _mm_set1_ps(1.0 / 576); + _tmp0 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 4) + { + _mm_store_si128((__m128i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm_store_si128((__m128i*)(outptr0 + 4), _tmp1); + if (tj * 4 + 2 < outw) _mm_store_si128((__m128i*)(outptr0 + 8), _tmp2); + if (tj * 4 + 3 < outw) _mm_store_si128((__m128i*)(outptr0 + 12), _tmp3); + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(N)); + _mm_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); +#else + int tmp0[4]; + int tmp1[4]; + int tmp2[4]; + int tmp3[4]; + _mm_storeu_si128((__m128i*)tmp0, _tmp0); + _mm_storeu_si128((__m128i*)tmp1, _tmp1); + _mm_storeu_si128((__m128i*)tmp2, _tmp2); + _mm_storeu_si128((__m128i*)tmp3, _tmp3); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + if (tj 
* 4 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp2[0]; + outptr1[2] = tmp2[1]; + outptr2[2] = tmp2[2]; + outptr3[2] = tmp2[3]; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp3[0]; + outptr1[3] = tmp3[1]; + outptr2[3] = tmp3[2]; + outptr3[3] = tmp3[3]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[4][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + const int* r4 = r0 + max_jj * 2 * 4; + const int* r5 = r0 + max_jj * 2 * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + for (int m = 5; m < 6; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp00 = tmp00 * 4; + tmp01 = tmp01 * 4; + tmp10 = tmp10 * 4; + tmp11 = tmp11 * 4; + tmp20 = tmp20 * 4; + tmp21 = tmp21 * 4; + tmp30 = tmp30 * 4; + tmp31 = tmp31 * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a0 = tmp[m][1][0] + tmp[m][2][0]; + int tmp02a1 = tmp[m][1][1] + tmp[m][2][1]; + int tmp02b0 = tmp[m][3][0] + tmp[m][4][0]; + int tmp02b1 = tmp[m][3][1] + tmp[m][4][1]; + int tmp13a0 = tmp[m][1][0] - tmp[m][2][0]; + int tmp13a1 = tmp[m][1][1] - tmp[m][2][1]; + int tmp13b0 = tmp[m][3][0] - tmp[m][4][0]; + int tmp13b1 = tmp[m][3][1] - tmp[m][4][1]; + + int tmp00 = tmp02a0 + tmp02b0 + tmp[m][0][0]; + int tmp01 = 
tmp02a1 + tmp02b1 + tmp[m][0][1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + tmp[m][5][0]; + int tmp31 = tmp13a1 + tmp13b1 * 8 + tmp[m][5][1]; + + tmp00 = tmp00 / 576; + tmp01 = tmp01 / 576; + tmp10 = tmp10 / 576; + tmp11 = tmp11 / 576; + tmp20 = tmp20 / 576; + tmp21 = tmp21 / 576; + tmp30 = tmp30 / 576; + tmp31 = tmp31 / 576; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp20; + outptr1[2] = tmp21; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp30; + outptr1[3] = tmp31; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[4][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + const int* r4 = r0 + max_jj * 4; + const int* r5 = r0 + max_jj * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + for (int m = 5; m < 6; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp0 = tmp0 * 4; + tmp1 = tmp1 * 4; + tmp2 = tmp2 * 4; + tmp3 = tmp3 * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a = tmp[m][1] + tmp[m][2]; + int tmp02b = tmp[m][3] + tmp[m][4]; + int tmp13a = tmp[m][1] - tmp[m][2]; + int tmp13b = tmp[m][3] - tmp[m][4]; + + int tmp0 = tmp02a + tmp02b + tmp[m][0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + tmp[m][5]; + + tmp0 = tmp0 / 576; + tmp1 = tmp1 / 576; + tmp2 = tmp2 / 576; + tmp3 = tmp3 / 576; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 4 + 1 < outw) outptr0[1] = tmp1; + if (tj * 4 + 2 < outw) outptr0[2] = tmp2; + if (tj * 4 + 3 < outw) outptr0[3] = tmp3; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + conv3x3s1_winograd43_int8_avx512vnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU 
&& NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + conv3x3s1_winograd43_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd43_int8_avx2(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ + if (ncnn::cpu_support_x86_xop()) + { + conv3x3s1_winograd43_int8_xop(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif +#endif + + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 4n+2, winograd F(4,3) + int w_tiles = (outw + 3) / 4; + int h_tiles = (outh + 3) / 4; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 36; + + // NCNN_LOGE("conv3x3s1_winograd43_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 4u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + bool k_end = k + TILE_K >= K; + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk, k_end); + } + + // 
transform output + conv3x3s1_winograd43_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h index 56a4aa4763af..da504677a68f 100644 --- a/src/layer/x86/convolution_im2col_gemm_int8.h +++ b/src/layer/x86/convolution_im2col_gemm_int8.h @@ -934,7 +934,6 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_CBAD); } - // TODO __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sumc, _MM_SHUFFLE(2, 0, 2, 0)); __m512i _tmp1 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); __m512i _tmp2 = _mm512_shuffle_i32x4(_sum8, _sum4, _MM_SHUFFLE(0, 2, 0, 2)); @@ -2547,12 +2546,12 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _tmp7 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); - _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 3, 1)); - _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(1, 3, 1, 3)); + _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); - _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); - _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); _mm512_storeu_si512((__m512i*)outptr0, _sum0); _mm512_storeu_si512((__m512i*)(outptr0 + 16), _sum1); @@ -6142,14 +6141,13 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo } } -static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +template +#if __AVX512F__ +void convolution_im2col_input_tile_int8_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#else // __AVX512F__ +void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#endif // __AVX512F__ { - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - convolution_im2col_input_tile_conv1x1s1d1_int8(bottom_blob, B, j, max_jj, k, max_kk); - return; - } - const int w = bottom_blob.w; // const int channels = bottom_blob.c; const int elempack = bottom_blob.elempack; @@ -6206,288 +6204,468 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dxe = (j + jj + 14) % outw; int dxf = (j + jj + 15) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dyf) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - 
int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int x08 = stride_w * dx8 + dilation_w * v0; - int x09 = stride_w * dx9 + dilation_w * v0; - int x0a = stride_w * dxa + dilation_w * v0; - int x0b = stride_w * dxb + dilation_w * v0; - int x0c = stride_w * dxc + dilation_w * v0; - int x0d = stride_w * dxd + dilation_w * v0; - int x0e = stride_w * dxe + dilation_w * v0; - int x0f = stride_w * dxf + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - int y08 = stride_h * dy8 + dilation_h * u0; - int y09 = stride_h * dy9 + dilation_h * u0; - int y0a = stride_h * dya + dilation_h * u0; - int y0b = stride_h * dyb + dilation_h * u0; - int y0c = stride_h * dyc + dilation_h * u0; - int y0d = stride_h * dyd + dilation_h * u0; - int y0e = stride_h * dye + dilation_h * u0; - int y0f = stride_h * dyf + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int x18 = stride_w * dx8 + dilation_w * v1; - int x19 = stride_w * dx9 + dilation_w * v1; - int x1a = stride_w * dxa + dilation_w * v1; - int x1b = stride_w * dxb + dilation_w * v1; - int x1c = stride_w * dxc + dilation_w * v1; - int x1d = stride_w * dxd + dilation_w * v1; - int x1e = stride_w * dxe + dilation_w * v1; - int x1f = stride_w * dxf + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - int y18 = stride_h * dy8 + dilation_h * u1; - int y19 = stride_h * dy9 + dilation_h * u1; - int y1a = stride_h * dya + dilation_h * u1; - int y1b = stride_h * dyb + dilation_h * u1; - int y1c = stride_h * dyc + dilation_h * u1; - int y1d = stride_h * dyd + dilation_h * u1; - int y1e = stride_h * dye + dilation_h * u1; - int y1f = stride_h * dyf + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - const signed char* sptr08 = img0.row(y08) + x08; - const signed char* sptr09 = img0.row(y09) + x09; - const signed char* sptr0a = img0.row(y0a) + x0a; - const signed char* sptr0b = img0.row(y0b) + x0b; - const signed char* sptr0c = img0.row(y0c) + x0c; - const signed char* sptr0d = 
img0.row(y0d) + x0d; - const signed char* sptr0e = img0.row(y0e) + x0e; - const signed char* sptr0f = img0.row(y0f) + x0f; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - const signed char* sptr18 = img1.row(y18) + x18; - const signed char* sptr19 = img1.row(y19) + x19; - const signed char* sptr1a = img1.row(y1a) + x1a; - const signed char* sptr1b = img1.row(y1b) + x1b; - const signed char* sptr1c = img1.row(y1c) + x1c; - const signed char* sptr1d = img1.row(y1d) + x1d; - const signed char* sptr1e = img1.row(y1e) + x1e; - const signed char* sptr1f = img1.row(y1f) + x1f; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp[16 + 0] = sptr08[0]; - pp[16 + 1] = sptr18[0]; - pp[16 + 2] = sptr09[0]; - pp[16 + 3] = sptr19[0]; - pp[16 + 4] = sptr0a[0]; - pp[16 + 5] = sptr1a[0]; - pp[16 + 6] = sptr0b[0]; - pp[16 + 7] = sptr1b[0]; - pp[16 + 8] = sptr0c[0]; - pp[16 + 9] = sptr1c[0]; - pp[16 + 10] = sptr0d[0]; - pp[16 + 11] = sptr1d[0]; - pp[16 + 12] = sptr0e[0]; - pp[16 + 13] = sptr1e[0]; - pp[16 + 14] = sptr0f[0]; - pp[16 + 15] = sptr1f[0]; - pp += 32; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _mm_store_si128((__m128i*)pp, _tmp0); + _mm_store_si128((__m128i*)(pp + 16), _tmp1); + pp += 32; + } + else if (stride_w == 2) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); + __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _mm256_storeu_si256((__m256i*)pp, _r01); + pp += 32; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 
2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp[16 + 0] = sptr0[stride_w * 8]; + pp[16 + 1] = sptr1[stride_w * 8]; + pp[16 + 2] = sptr0[stride_w * 9]; + pp[16 + 3] = sptr1[stride_w * 9]; + pp[16 + 4] = sptr0[stride_w * 10]; + pp[16 + 5] = sptr1[stride_w * 10]; + pp[16 + 6] = sptr0[stride_w * 11]; + pp[16 + 7] = sptr1[stride_w * 11]; + pp[16 + 8] = sptr0[stride_w * 12]; + pp[16 + 9] = sptr1[stride_w * 12]; + pp[16 + 10] = sptr0[stride_w * 13]; + pp[16 + 11] = sptr1[stride_w * 13]; + pp[16 + 12] = sptr0[stride_w * 14]; + pp[16 + 13] = sptr1[stride_w * 14]; + pp[16 + 14] = sptr0[stride_w * 15]; + pp[16 + 15] = sptr1[stride_w * 15]; + pp += 32; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int x8 = stride_w * dx8 + dilation_w * v; - int x9 = stride_w * dx9 + dilation_w * v; - int xa = stride_w * dxa + dilation_w * v; - int xb = stride_w * dxb + dilation_w * v; - int xc = stride_w * dxc + dilation_w * v; - int xd = stride_w * dxd + dilation_w * v; - int xe = stride_w * dxe + dilation_w * v; - int xf = stride_w * dxf + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - int y8 = stride_h * dy8 + dilation_h * u; - int y9 = stride_h * dy9 + dilation_h * u; - int ya = stride_h * dya + dilation_h * u; - int yb = stride_h * dyb + dilation_h * u; - int yc = stride_h * dyc + dilation_h * u; - int yd = stride_h * dyd + dilation_h * u; - int ye = stride_h * dye + dilation_h * u; - int yf = stride_h * dyf + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; - const signed char* sptr8 = img.row(y8) + x8 * elempack; - const signed char* sptr9 = img.row(y9) + x9 * elempack; - const signed char* sptra = img.row(ya) + xa * elempack; - const signed char* sptrb = img.row(yb) + xb * elempack; - const signed char* sptrc = img.row(yc) + xc * 
elempack; - const signed char* sptrd = img.row(yd) + xd * elempack; - const signed char* sptre = img.row(ye) + xe * elempack; - const signed char* sptrf = img.row(yf) + xf * elempack; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); - __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); - __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); - __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); - __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); - __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); - __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); - __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); - __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); - __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); - __m128i _ref = _mm_unpacklo_epi16(_re, _rf); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi32(_r89, _rab); - _r5 = _mm_unpackhi_epi32(_r89, _rab); - _r6 = _mm_unpacklo_epi32(_rcd, _ref); - _r7 = _mm_unpackhi_epi32(_rcd, _ref); - _r8 = _mm_unpacklo_epi64(_r0, _r2); - _r9 = _mm_unpacklo_epi64(_r4, _r6); - _ra = _mm_unpackhi_epi64(_r0, _r2); - _rb = _mm_unpackhi_epi64(_r4, _r6); - _rc = _mm_unpacklo_epi64(_r1, _r3); - _rd = _mm_unpacklo_epi64(_r5, _r7); - _re = _mm_unpackhi_epi64(_r1, _r3); - _rf = _mm_unpackhi_epi64(_r5, _r7); - _mm_storeu_si128((__m128i*)pp, _r8); - _mm_storeu_si128((__m128i*)(pp + 16), _r9); - _mm_storeu_si128((__m128i*)(pp + 32), _ra); - _mm_storeu_si128((__m128i*)(pp + 48), _rb); - _mm_storeu_si128((__m128i*)(pp + 64), _rc); - _mm_storeu_si128((__m128i*)(pp + 80), _rd); - _mm_storeu_si128((__m128i*)(pp + 96), _re); - _mm_storeu_si128((__m128i*)(pp + 112), _rf); - pp += 128; + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); + __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); + __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); + __m128i _rc = _mm_loadl_epi64((const 
__m128i*)(sptr + stride_w * 96)); + __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); + __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); + __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp[8] = sptr[stride_w * 8]; + pp[9] = sptr[stride_w * 9]; + pp[10] = sptr[stride_w * 10]; + pp[11] = sptr[stride_w * 11]; + pp[12] = sptr[stride_w * 12]; + pp[13] = sptr[stride_w * 13]; + pp[14] = sptr[stride_w * 14]; + pp[15] = sptr[stride_w * 15]; + pp += 16; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr8[0]; - pp[9] = sptr9[0]; - pp[10] = sptra[0]; - pp[11] = sptrb[0]; - pp[12] = sptrc[0]; - pp[13] = sptrd[0]; - pp[14] = sptre[0]; - pp[15] = sptrf[0]; - pp += 16; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int x08 = stride_w * dx8 + dilation_w * v0; + int x09 = stride_w * dx9 + dilation_w * v0; + int x0a = stride_w * dxa + dilation_w * v0; + int x0b = stride_w * dxb + dilation_w * v0; + int x0c = stride_w * dxc + dilation_w * v0; + int x0d = stride_w * dxd + dilation_w * v0; + int x0e = 
stride_w * dxe + dilation_w * v0; + int x0f = stride_w * dxf + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + int y08 = stride_h * dy8 + dilation_h * u0; + int y09 = stride_h * dy9 + dilation_h * u0; + int y0a = stride_h * dya + dilation_h * u0; + int y0b = stride_h * dyb + dilation_h * u0; + int y0c = stride_h * dyc + dilation_h * u0; + int y0d = stride_h * dyd + dilation_h * u0; + int y0e = stride_h * dye + dilation_h * u0; + int y0f = stride_h * dyf + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int x18 = stride_w * dx8 + dilation_w * v1; + int x19 = stride_w * dx9 + dilation_w * v1; + int x1a = stride_w * dxa + dilation_w * v1; + int x1b = stride_w * dxb + dilation_w * v1; + int x1c = stride_w * dxc + dilation_w * v1; + int x1d = stride_w * dxd + dilation_w * v1; + int x1e = stride_w * dxe + dilation_w * v1; + int x1f = stride_w * dxf + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + int y18 = stride_h * dy8 + dilation_h * u1; + int y19 = stride_h * dy9 + dilation_h * u1; + int y1a = stride_h * dya + dilation_h * u1; + int y1b = stride_h * dyb + dilation_h * u1; + int y1c = stride_h * dyc + dilation_h * u1; + int y1d = stride_h * dyd + dilation_h * u1; + int y1e = stride_h * dye + dilation_h * u1; + int y1f = stride_h * dyf + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + const signed char* sptr08 = img0.row(y08) + x08; + const signed char* sptr09 = img0.row(y09) + x09; + const signed char* sptr0a = img0.row(y0a) + x0a; + const signed char* sptr0b = img0.row(y0b) + x0b; + const signed char* sptr0c = img0.row(y0c) + x0c; + const signed char* sptr0d = img0.row(y0d) + x0d; + const signed char* sptr0e = img0.row(y0e) + x0e; + const signed char* sptr0f = img0.row(y0f) + x0f; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + const signed char* sptr18 = 
img1.row(y18) + x18; + const signed char* sptr19 = img1.row(y19) + x19; + const signed char* sptr1a = img1.row(y1a) + x1a; + const signed char* sptr1b = img1.row(y1b) + x1b; + const signed char* sptr1c = img1.row(y1c) + x1c; + const signed char* sptr1d = img1.row(y1d) + x1d; + const signed char* sptr1e = img1.row(y1e) + x1e; + const signed char* sptr1f = img1.row(y1f) + x1f; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp[16 + 0] = sptr08[0]; + pp[16 + 1] = sptr18[0]; + pp[16 + 2] = sptr09[0]; + pp[16 + 3] = sptr19[0]; + pp[16 + 4] = sptr0a[0]; + pp[16 + 5] = sptr1a[0]; + pp[16 + 6] = sptr0b[0]; + pp[16 + 7] = sptr1b[0]; + pp[16 + 8] = sptr0c[0]; + pp[16 + 9] = sptr1c[0]; + pp[16 + 10] = sptr0d[0]; + pp[16 + 11] = sptr1d[0]; + pp[16 + 12] = sptr0e[0]; + pp[16 + 13] = sptr1e[0]; + pp[16 + 14] = sptr0f[0]; + pp[16 + 15] = sptr1f[0]; + pp += 32; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int x8 = stride_w * dx8 + dilation_w * v; + int x9 = stride_w * dx9 + dilation_w * v; + int xa = stride_w * dxa + dilation_w * v; + int xb = stride_w * dxb + dilation_w * v; + int xc = stride_w * dxc + dilation_w * v; + int xd = stride_w * dxd + dilation_w * v; + int xe = stride_w * dxe + dilation_w * v; + int xf = stride_w * dxf + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + int y8 = stride_h * dy8 + dilation_h * u; + int y9 = stride_h * dy9 + dilation_h * u; + int ya = stride_h * dya + dilation_h * u; + int yb = stride_h * dyb + dilation_h * u; + int yc = stride_h * dyc + dilation_h * u; + int yd = stride_h * dyd + dilation_h * u; + int ye = stride_h * dye + dilation_h * u; + int yf = stride_h * dyf + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + const signed char* sptr8 = img.row(y8) + x8 * elempack; + const signed char* sptr9 = img.row(y9) + x9 * elempack; + const signed char* sptra = img.row(ya) + xa * elempack; + const signed char* sptrb = img.row(yb) + xb * elempack; + const signed char* sptrc = img.row(yc) + xc * elempack; + const 
signed char* sptrd = img.row(yd) + xd * elempack; + const signed char* sptre = img.row(ye) + xe * elempack; + const signed char* sptrf = img.row(yf) + xf * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); + __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); + __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); + __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); + __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); + __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); + __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp[8] = sptr8[0]; + pp[9] = sptr9[0]; + pp[10] = sptra[0]; + pp[11] = sptrb[0]; + pp[12] = sptrc[0]; + pp[13] = sptrd[0]; + pp[14] = sptre[0]; + pp[15] = sptrf[0]; + pp += 16; + } } } } @@ -6511,168 +6689,298 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx6 = (j + jj + 6) % outw; int dx7 = (j + jj + 7) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy7) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = 
stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp += 16; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = 
_mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp += 16; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 
= _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi64(_r0, _r2); - _r5 = _mm_unpackhi_epi64(_r0, _r2); - _r6 = _mm_unpacklo_epi64(_r1, _r3); - _r7 = _mm_unpackhi_epi64(_r1, _r3); - _mm_storeu_si128((__m128i*)pp, _r4); - _mm_storeu_si128((__m128i*)(pp + 16), _r5); - _mm_storeu_si128((__m128i*)(pp + 32), _r6); - _mm_storeu_si128((__m128i*)(pp + 48), _r7); - pp += 64; + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp += 8; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp += 8; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + 
dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp += 16; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const 
__m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp += 8; + } } } } @@ -6688,106 +6996,206 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx2 = (j + jj + 2) % outw; int dx3 = (j + jj + 3) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy3) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp += 8; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = 
img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0)); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp += 8; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr = img.row(y0) + x0 * elempack; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _mm_storeu_si128((__m128i*)pp, _r0); - _mm_storeu_si128((__m128i*)(pp + 16), _r1); - pp += 32; + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp += 4; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp += 4; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / 
maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp += 8; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp += 4; + } } } } @@ -6799,44 +7207,154 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx0 = (j + jj) % outw; int dx1 = (j + jj + 1) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy1) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; 
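// The same bookkeeping drives every tile width in this loader: for a packed K index
// (k + kk), p = (k + kk) / maxk picks the input channel, uv = (k + kk) % maxk is the
// offset inside the kernel window, u = uv / kernel_w is the kernel row and
// v = uv % kernel_w the kernel column, so the sampled pixel is
//   (y, x) = (stride_h * dy + dilation_h * u, stride_w * dx + dilation_w * v).
// The added dy0 == dyf / dy7 / dy3 / dy1 tests ask whether all output positions of the
// tile fall on one output row; when they do, dx increases by 1 per position, so the
// inputs for a fixed (p, u, v) sit on one input row at x0, x0 + stride_w,
// x0 + 2 * stride_w, ..., which is what permits the contiguous (stride_w == 1) and
// de-interleaving (stride_w == 2) SIMD loads used by the wider tiles above instead of
// one pointer per output position. A minimal scalar sketch of that fast path for a pair
// of K entries, where tile_w only stands for the tile width (2, 4, 8 or 16) and is not a
// variable in the code (illustration only):
//
//   for (int n = 0; n < tile_w; n++)
//   {
//       pp[2 * n + 0] = sptr0[stride_w * n]; // channel p0, kernel tap (u0, v0)
//       pp[2 * n + 1] = sptr1[stride_w * n]; // channel p1, kernel tap (u1, v1)
//   }
//   pp += 2 * tile_w;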
- int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp += 4; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp += 2; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % 
kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp += 2; + } } } + } + for (; jj < max_jj; jj++) + { + int dy = (j + jj) / outw; + int dx = (j + jj) % outw; + + int kk = 0; for (; kk < max_kk / elempack; kk++) { int p = (k / elempack + kk) / maxk; @@ -6846,29 +7364,1309 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; + int x = stride_w * dx + dilation_w * v; + int y = stride_h * dy + dilation_h * u; - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr = img.row(y) + x * elempack; #if __SSE2__ if (elempack == 8) { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)sptr)); + pp += 8; } #endif // __SSE2__ if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp += 2; + pp[0] = sptr[0]; + pp += 1; + } + } + } +} + +#if __AVX512F__ +template void convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#else // __AVX512F__ +template void convolution_im2col_input_tile_int8<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); 
+template void convolution_im2col_input_tile_int8<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#endif // __AVX512F__ + +static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +{ + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_input_tile_conv1x1s1d1_int8(bottom_blob, B, j, max_jj, k, max_kk); + return; + } + + if (kernel_w == 1 && kernel_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + const int w = bottom_blob.w; + // const int channels = bottom_blob.c; + const int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + const int outw = (w - kernel_extent_w) / stride_w + 1; + + // j max_jj outw*outh split w and h + + // k max_kk pa*maxk*(inch/pa) split inch + + // k/max_kk shall be multiple of maxk + + const int maxk = kernel_w * kernel_h; + + signed char* pp = B; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj 
+ 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dy8 = (j + jj + 8) / outw; + int dy9 = (j + jj + 9) / outw; + int dya = (j + jj + 10) / outw; + int dyb = (j + jj + 11) / outw; + int dyc = (j + jj + 12) / outw; + int dyd = (j + jj + 13) / outw; + int dye = (j + jj + 14) / outw; + int dyf = (j + jj + 15) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + int dx8 = (j + jj + 8) % outw; + int dx9 = (j + jj + 9) % outw; + int dxa = (j + jj + 10) % outw; + int dxb = (j + jj + 11) % outw; + int dxc = (j + jj + 12) % outw; + int dxd = (j + jj + 13) % outw; + int dxe = (j + jj + 14) % outw; + int dxf = (j + jj + 15) % outw; + + if (dy0 == dyf) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _mm_store_si128((__m128i*)pp, _tmp0); + _mm_store_si128((__m128i*)(pp + 16), _tmp1); + pp += 32; + } + else if (stride_w == 2) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); + __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _mm256_storeu_si256((__m256i*)pp, _r01); + pp += 32; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp[16 + 0] = sptr0[stride_w * 8]; + pp[16 + 1] = sptr1[stride_w * 8]; + pp[16 + 2] = sptr0[stride_w * 9]; + pp[16 + 3] = sptr1[stride_w * 9]; + pp[16 + 4] = sptr0[stride_w * 10]; + pp[16 + 5] = sptr1[stride_w * 10]; + pp[16 + 6] = sptr0[stride_w * 11]; + pp[16 + 7] = sptr1[stride_w * 11]; + pp[16 + 8] = sptr0[stride_w * 12]; + pp[16 + 9] = sptr1[stride_w * 
12]; + pp[16 + 10] = sptr0[stride_w * 13]; + pp[16 + 11] = sptr1[stride_w * 13]; + pp[16 + 12] = sptr0[stride_w * 14]; + pp[16 + 13] = sptr1[stride_w * 14]; + pp[16 + 14] = sptr0[stride_w * 15]; + pp[16 + 15] = sptr1[stride_w * 15]; + pp += 32; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); + __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); + __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); + __m128i _rc = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 96)); + __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); + __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); + __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp[8] = sptr[stride_w * 8]; + pp[9] = sptr[stride_w * 9]; + pp[10] = sptr[stride_w * 10]; + pp[11] = sptr[stride_w * 11]; + pp[12] = 
sptr[stride_w * 12]; + pp[13] = sptr[stride_w * 13]; + pp[14] = sptr[stride_w * 14]; + pp[15] = sptr[stride_w * 15]; + pp += 16; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int x08 = stride_w * dx8 + dilation_w * v0; + int x09 = stride_w * dx9 + dilation_w * v0; + int x0a = stride_w * dxa + dilation_w * v0; + int x0b = stride_w * dxb + dilation_w * v0; + int x0c = stride_w * dxc + dilation_w * v0; + int x0d = stride_w * dxd + dilation_w * v0; + int x0e = stride_w * dxe + dilation_w * v0; + int x0f = stride_w * dxf + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + int y08 = stride_h * dy8 + dilation_h * u0; + int y09 = stride_h * dy9 + dilation_h * u0; + int y0a = stride_h * dya + dilation_h * u0; + int y0b = stride_h * dyb + dilation_h * u0; + int y0c = stride_h * dyc + dilation_h * u0; + int y0d = stride_h * dyd + dilation_h * u0; + int y0e = stride_h * dye + dilation_h * u0; + int y0f = stride_h * dyf + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int x18 = stride_w * dx8 + dilation_w * v1; + int x19 = stride_w * dx9 + dilation_w * v1; + int x1a = stride_w * dxa + dilation_w * v1; + int x1b = stride_w * dxb + dilation_w * v1; + int x1c = stride_w * dxc + dilation_w * v1; + int x1d = stride_w * dxd + dilation_w * v1; + int x1e = stride_w * dxe + dilation_w * v1; + int x1f = stride_w * dxf + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + int y18 = stride_h * dy8 + dilation_h * u1; + int y19 = stride_h * dy9 + dilation_h * u1; + int y1a = stride_h * dya + dilation_h * u1; + int y1b = stride_h * dyb + dilation_h * u1; + int y1c = stride_h * dyc + dilation_h * u1; + int y1d = stride_h * dyd + dilation_h * u1; + int y1e = stride_h * dye + dilation_h * u1; + int y1f = stride_h * dyf + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + 
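+ // per-position source pointers into input channels p0 and p1; each of the 16 positions resolves its own row here because this tile crosses an output-row boundary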
const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + const signed char* sptr08 = img0.row(y08) + x08; + const signed char* sptr09 = img0.row(y09) + x09; + const signed char* sptr0a = img0.row(y0a) + x0a; + const signed char* sptr0b = img0.row(y0b) + x0b; + const signed char* sptr0c = img0.row(y0c) + x0c; + const signed char* sptr0d = img0.row(y0d) + x0d; + const signed char* sptr0e = img0.row(y0e) + x0e; + const signed char* sptr0f = img0.row(y0f) + x0f; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + const signed char* sptr18 = img1.row(y18) + x18; + const signed char* sptr19 = img1.row(y19) + x19; + const signed char* sptr1a = img1.row(y1a) + x1a; + const signed char* sptr1b = img1.row(y1b) + x1b; + const signed char* sptr1c = img1.row(y1c) + x1c; + const signed char* sptr1d = img1.row(y1d) + x1d; + const signed char* sptr1e = img1.row(y1e) + x1e; + const signed char* sptr1f = img1.row(y1f) + x1f; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp[16 + 0] = sptr08[0]; + pp[16 + 1] = sptr18[0]; + pp[16 + 2] = sptr09[0]; + pp[16 + 3] = sptr19[0]; + pp[16 + 4] = sptr0a[0]; + pp[16 + 5] = sptr1a[0]; + pp[16 + 6] = sptr0b[0]; + pp[16 + 7] = sptr1b[0]; + pp[16 + 8] = sptr0c[0]; + pp[16 + 9] = sptr1c[0]; + pp[16 + 10] = sptr0d[0]; + pp[16 + 11] = sptr1d[0]; + pp[16 + 12] = sptr0e[0]; + pp[16 + 13] = sptr1e[0]; + pp[16 + 14] = sptr0f[0]; + pp[16 + 15] = sptr1f[0]; + pp += 32; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int x8 = stride_w * dx8 + dilation_w * v; + int x9 = stride_w * dx9 + dilation_w * v; + int xa = stride_w * dxa + dilation_w * v; + int xb = stride_w * dxb + dilation_w * v; + int xc = stride_w * dxc + dilation_w * v; + int xd = stride_w * dxd + dilation_w * v; + int xe = stride_w * dxe + dilation_w * v; + int xf = stride_w * dxf + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + 
int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + int y8 = stride_h * dy8 + dilation_h * u; + int y9 = stride_h * dy9 + dilation_h * u; + int ya = stride_h * dya + dilation_h * u; + int yb = stride_h * dyb + dilation_h * u; + int yc = stride_h * dyc + dilation_h * u; + int yd = stride_h * dyd + dilation_h * u; + int ye = stride_h * dye + dilation_h * u; + int yf = stride_h * dyf + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + const signed char* sptr8 = img.row(y8) + x8 * elempack; + const signed char* sptr9 = img.row(y9) + x9 * elempack; + const signed char* sptra = img.row(ya) + xa * elempack; + const signed char* sptrb = img.row(yb) + xb * elempack; + const signed char* sptrc = img.row(yc) + xc * elempack; + const signed char* sptrd = img.row(yd) + xd * elempack; + const signed char* sptre = img.row(ye) + xe * elempack; + const signed char* sptrf = img.row(yf) + xf * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); + __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); + __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); + __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); + __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); + __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); + __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); 
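+ // the two stores below finish the repacked 16-pixel x 8-channel int8 block, keeping each channel pair contiguous across the 16 pixels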
+ _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp[8] = sptr8[0]; + pp[9] = sptr9[0]; + pp[10] = sptra[0]; + pp[11] = sptrb[0]; + pp[12] = sptrc[0]; + pp[13] = sptrd[0]; + pp[14] = sptre[0]; + pp[15] = sptrf[0]; + pp += 16; + } + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + + if (dy0 == dy7) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp += 16; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 
= stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp += 8; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + + const signed 
char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp += 16; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + 
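+ // the unpack cascade above transposes the 8-pixel x 8-channel block, grouping the 8 pixels by channel pair (16 bytes per pair); the four stores below write the 64 repacked bytes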
_mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp += 8; + } + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + + if (dy0 == dy3) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0)); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp += 8; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp += 4; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + 
kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp += 8; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp += 4; + } + } + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + + if (dy0 == dy1) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 
/ kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp += 2; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp += 2; + } } } } diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp index f870a8847462..09008985f121 100644 --- a/src/layer/x86/convolution_x86.cpp +++ b/src/layer/x86/convolution_x86.cpp @@ -46,16 +46,13 @@ namespace ncnn { #include "convolution_packed_int8.h" #include 
"convolution_im2col_gemm_int8.h" + +#include "convolution_3x3_winograd_int8.h" #endif // NCNN_INT8 #if __SSE2__ #include "convolution_3x3_pack1to4.h" -#if NCNN_INT8 -#include "convolution_3x3_pack8to4_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#endif // NCNN_INT8 - #if __AVX__ #include "convolution_3x3_pack1to8.h" #include "convolution_3x3_pack8to1.h" @@ -1231,32 +1228,14 @@ int Convolution_x86::create_pipeline_int8_x86(const Option& opt) const int maxk = kernel_w * kernel_h; const int num_input = weight_data_size / maxk / num_output; - int elempack = 1; - int out_elempack_int32 = 1; -#if __SSE2__ - if (opt.use_packing_layout) - { - elempack = num_input % 8 == 0 ? 8 : 1; - out_elempack_int32 = num_output % 4 == 0 ? 4 : 1; - } -#endif // __SSE2__ + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input > 8 || num_output > 8); - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __SSE2__ - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __SSE2__ - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd23_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - conv3x3s1_winograd23_transform_kernel_int8_sse(weight_data, weight_winograd23_data, num_input, num_output, opt); - // conv3x3s1_winograd43_transform_kernel_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); + if (opt.use_winograd43_convolution) + conv3x3s1_winograd43_transform_kernel_int8(weight_data, weight_winograd43_data, num_input, num_output, opt); + else + conv3x3s1_winograd23_transform_kernel_int8(weight_data, weight_winograd23_data, num_input, num_output, opt); } else if (opt.use_sgemm_convolution) { @@ -1352,6 +1331,8 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con if (top_blob_int32.empty()) return -100; + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input > 8 || num_output > 8); + int _nT = nT ? 
nT : opt.num_threads; if (nT != 0 && opt.num_threads != nT) { @@ -1360,22 +1341,12 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con NCNN_LOGE("opt.num_threads %d changed, convolution gemm will use load-time value %d", opt.num_threads, nT); } - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __SSE2__ - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __SSE2__ - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd23_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - conv3x3s1_winograd23_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, opt); - // conv3x3s1_winograd43_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); + if (opt.use_winograd43_convolution && !weight_winograd43_data.empty()) + conv3x3s1_winograd43_int8(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, _nT, opt); + else + conv3x3s1_winograd23_int8(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, _nT, opt); } else if (opt.use_sgemm_convolution) { diff --git a/src/layer/x86/convolution_x86_avx2.cpp b/src/layer/x86/convolution_x86_avx2.cpp index 38f107ee0865..49cded702137 100644 --- a/src/layer/x86/convolution_x86_avx2.cpp +++ b/src/layer/x86/convolution_x86_avx2.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_transform_kernel_packed_int8_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) @@ -46,24 +45,24 @@ void convolution_im2col_gemm_int8_avx2(const Mat& bottom_blob, Mat& top_blob, co } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_transform_kernel_int8(kernel, AT, inch, outch, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd23_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void 
conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd43_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd43_transform_kernel_int8(kernel, AT, inch, outch, opt); } -void conv3x3s1_winograd43_pack8to4_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_avx512vnni.cpp b/src/layer/x86/convolution_x86_avx512vnni.cpp index f0ac51bbf856..8e34bb61309f 100644 --- a/src/layer/x86/convolution_x86_avx512vnni.cpp +++ b/src/layer/x86/convolution_x86_avx512vnni.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_avx512vnni(const Mat& bottom_blob, Mat& top_bl } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_avxvnni.cpp b/src/layer/x86/convolution_x86_avxvnni.cpp index a8ef75bb968b..aa1ba401856c 100644 --- a/src/layer/x86/convolution_x86_avxvnni.cpp +++ b/src/layer/x86/convolution_x86_avxvnni.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_avxvnni(const Mat& 
bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_xop.cpp b/src/layer/x86/convolution_x86_xop.cpp index d954f5545655..cacba8f07cdd 100644 --- a/src/layer/x86/convolution_x86_xop.cpp +++ b/src/layer/x86/convolution_x86_xop.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_xop(const Mat& bottom_blob, Mat& top_blob, con } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - 
conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index c9551330ff67..e75e78c0c255 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -42,6 +42,83 @@ static NCNN_FORCEINLINE signed char float2int8(float v) } #if __SSE2__ +static NCNN_FORCEINLINE void transpose4x8_epi32(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) +{ + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3); + __m128i _tmp4 = _mm_unpacklo_epi32(_r4, _r5); + __m128i _tmp5 = _mm_unpackhi_epi32(_r4, _r5); + __m128i _tmp6 = _mm_unpacklo_epi32(_r6, _r7); + __m128i _tmp7 = _mm_unpackhi_epi32(_r6, _r7); + + _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _r1 = _mm_unpacklo_epi64(_tmp4, _tmp6); + _r2 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _r3 = _mm_unpackhi_epi64(_tmp4, _tmp6); + _r4 = _mm_unpacklo_epi64(_tmp1, _tmp3); + _r5 = _mm_unpacklo_epi64(_tmp5, _tmp7); + _r6 = _mm_unpackhi_epi64(_tmp1, _tmp3); + _r7 = _mm_unpackhi_epi64(_tmp5, _tmp7); +} + +static NCNN_FORCEINLINE void transpose4x4_epi32(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3); + + _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi64(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi64(_tmp1, _tmp3); +} + +static NCNN_FORCEINLINE void transpose8x8_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) +{ + __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); + __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5); + __m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7); + + __m128i _tmp8 = _mm_unpacklo_epi32(_tmp0, _tmp2); + __m128i _tmp9 = _mm_unpackhi_epi32(_tmp0, _tmp2); + __m128i _tmpa = _mm_unpacklo_epi32(_tmp1, _tmp3); + __m128i _tmpb = _mm_unpackhi_epi32(_tmp1, _tmp3); + __m128i _tmpc = _mm_unpacklo_epi32(_tmp4, _tmp6); + __m128i _tmpd = _mm_unpackhi_epi32(_tmp4, _tmp6); + __m128i _tmpe = _mm_unpacklo_epi32(_tmp5, _tmp7); + __m128i _tmpf = _mm_unpackhi_epi32(_tmp5, _tmp7); + + _r0 = _mm_unpacklo_epi64(_tmp8, _tmpc); + _r1 = _mm_unpackhi_epi64(_tmp8, _tmpc); + _r2 = _mm_unpacklo_epi64(_tmp9, _tmpd); + _r3 = _mm_unpackhi_epi64(_tmp9, _tmpd); + _r4 = _mm_unpacklo_epi64(_tmpa, _tmpe); + _r5 = _mm_unpackhi_epi64(_tmpa, _tmpe); + _r6 = _mm_unpacklo_epi64(_tmpb, _tmpf); + _r7 = _mm_unpackhi_epi64(_tmpb, _tmpf); +} + +static NCNN_FORCEINLINE void transpose8x4_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); + + _r0 = _mm_unpacklo_epi32(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi32(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi32(_tmp1, _tmp3); + _r3 = 
_mm_unpackhi_epi32(_tmp1, _tmp3); +} + static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128) { const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); @@ -341,36 +418,6 @@ static NCNN_FORCEINLINE void transpose8x2_ps(__m256& _r0, __m256& _r1) _r1 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); } -static NCNN_FORCEINLINE void transpose8x8_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) -{ - __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); - __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); - __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5); - __m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7); - - __m128i _tmp8 = _mm_unpacklo_epi32(_tmp0, _tmp2); - __m128i _tmp9 = _mm_unpackhi_epi32(_tmp0, _tmp2); - __m128i _tmpa = _mm_unpacklo_epi32(_tmp1, _tmp3); - __m128i _tmpb = _mm_unpackhi_epi32(_tmp1, _tmp3); - __m128i _tmpc = _mm_unpacklo_epi32(_tmp4, _tmp6); - __m128i _tmpd = _mm_unpackhi_epi32(_tmp4, _tmp6); - __m128i _tmpe = _mm_unpacklo_epi32(_tmp5, _tmp7); - __m128i _tmpf = _mm_unpackhi_epi32(_tmp5, _tmp7); - - _r0 = _mm_unpacklo_epi64(_tmp8, _tmpc); - _r1 = _mm_unpackhi_epi64(_tmp8, _tmpc); - _r2 = _mm_unpacklo_epi64(_tmp9, _tmpd); - _r3 = _mm_unpackhi_epi64(_tmp9, _tmpd); - _r4 = _mm_unpacklo_epi64(_tmpa, _tmpe); - _r5 = _mm_unpackhi_epi64(_tmpa, _tmpe); - _r6 = _mm_unpacklo_epi64(_tmpb, _tmpf); - _r7 = _mm_unpackhi_epi64(_tmpb, _tmpf); -} - static NCNN_FORCEINLINE __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7) { const __m256 s01 = _mm256_hadd_ps(v0, v1); @@ -598,6 +645,55 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256& return _v; } +#if __AVX2__ +static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1) +{ + __m256i _tmp0 = _mm256_unpacklo_epi32(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_r0, _r1); + + _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m256i& _r2, __m256i& _r3, __m256i& _r4, __m256i& _r5, __m256i& _r6, __m256i& _r7) +{ + __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); + __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); + __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3); + __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5); + __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5); + __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7); + __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7); + + __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2); + __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2); + __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3); + __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3); + __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6); + __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6); + __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7); + __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7); + + _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk); + _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk); + _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl); + _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl); + _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm); + _tmp5 = 
@@ -928,45 +1024,6 @@ static NCNN_FORCEINLINE void transpose16x16_epi16(__m256i& _r0, __m256i& _r1, __
     _rf = _mm256_permute2x128_si256(_tmp7, _tmpf, _MM_SHUFFLE(0, 3, 0, 1));
 }
 
-static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m256i& _r2, __m256i& _r3, __m256i& _r4, __m256i& _r5, __m256i& _r6, __m256i& _r7)
-{
-    __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1);
-    __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1);
-    __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3);
-    __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3);
-    __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5);
-    __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5);
-    __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7);
-    __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7);
-
-    __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2);
-    __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2);
-    __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3);
-    __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3);
-    __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6);
-    __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6);
-    __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7);
-    __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7);
-
-    _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk);
-    _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk);
-    _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl);
-    _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl);
-    _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm);
-    _tmp5 = _mm256_unpackhi_epi64(_tmpi, _tmpm);
-    _tmp6 = _mm256_unpacklo_epi64(_tmpj, _tmpn);
-    _tmp7 = _mm256_unpackhi_epi64(_tmpj, _tmpn);
-
-    _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0));
-    _r1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0));
-    _r2 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0));
-    _r3 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0));
-    _r4 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1));
-    _r5 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1));
-    _r6 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1));
-    _r7 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1));
-}
-
 static NCNN_FORCEINLINE void transpose8x16_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7, __m128i& _r8, __m128i& _r9, __m128i& _ra, __m128i& _rb, __m128i& _rc, __m128i& _rd, __m128i& _re, __m128i& _rf)
 {
     __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1);
@@ -1088,6 +1145,7 @@ static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m5
 }
 #endif // __AVX512F__
+#endif // __AVX2__
 #endif // __AVX__
 #endif // __SSE2__
diff --git a/tests/test_convolution_3.cpp b/tests/test_convolution_3.cpp
index fa358d0670cc..1d0f8f079b62 100644
--- a/tests/test_convolution_3.cpp
+++ b/tests/test_convolution_3.cpp
@@ -190,6 +190,30 @@ static int test_convolution_int8(int w, int h, int c, int outch, int kernel, int
         return ret;
     }
 
+    if (kernel == 3 && dilation == 1 && stride == 1)
+    {
+        ncnn::Option opt;
+        opt.num_threads = 1;
+        opt.use_packing_layout = true;
+        opt.use_fp16_packed = false;
+        opt.use_fp16_storage = false;
+        opt.use_fp16_arithmetic = false;
+        opt.use_bf16_storage = false;
+        opt.use_shader_pack8 = false;
+        opt.use_image_storage = false;
+        opt.use_sgemm_convolution = false;
+        opt.use_winograd_convolution = true;
+        opt.use_winograd23_convolution = true;
+        opt.use_winograd43_convolution = false;
+
+        ret = test_layer_opt("Convolution", pd, weights, opt, a, requant ? 1.0f : 0.001f, 0, flag);
+        if (ret != 0)
+        {
+            fprintf(stderr, "test_convolution_int8 failed w=%d h=%d c=%d outch=%d kernel=%d dilation=%d stride=%d pad=%d bias=%d requant=%d act=%d actparams=[%f,%f]\n", w, h, c, outch, kernel, dilation, stride, pad, bias, requant, activation_type, activation_params[0], activation_params[1]);
+            return ret;
+        }
+    }
+
     {
         ncnn::Option opt;
         opt.num_threads = 1;
@@ -310,6 +334,7 @@ static int test_convolution_1()
            || test_convolution_int8(4, 20, 16, 24, 3, 1, 1, 1, 0)
            || test_convolution_int8(6, 7, 64, 64, 3, 1, 2, 0, 1)
            || test_convolution_int8(25, 33, 16, 15, 3, 1, 1, 1, 0)
+           || test_convolution_int8(25, 33, 31, 31, 3, 1, 1, 1, 0)
            || test_convolution_int8(7, 7, 15, 12, 3, 1, 1, 1, 0)
            || test_convolution_int8(5, 6, 31, 9, 5, 1, 1, 0, 1)
            || test_convolution_int8(5, 7, 32, 8, 5, 1, 2, 0, 1)
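Editor's note (not part of the patch): the new test block pins the 3x3s1 int8 case to the winograd23 path by enabling use_winograd23_convolution while disabling sgemm and winograd43. A complementary sketch, using only Option fields that already appear in the diff, would flip the two winograd toggles to exercise the F(4,3) int8 kernels instead; the helper name make_winograd43_int8_opt is made up for illustration.

// sketch: Option setup that would select the winograd43 int8 path
#include "option.h" // ncnn::Option, as included inside the ncnn source tree

static ncnn::Option make_winograd43_int8_opt()
{
    ncnn::Option opt;
    opt.num_threads = 1;
    opt.use_packing_layout = true;
    opt.use_sgemm_convolution = false;     // keep im2col+gemm out of the way
    opt.use_winograd_convolution = true;
    opt.use_winograd23_convolution = false;
    opt.use_winograd43_convolution = true; // exercise the F(4,3) int8 kernels
    return opt;
}

Dropping such a block next to the one added above, with the same test_layer_opt call, would mirror how the existing test isolates each winograd variant.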