diff --git a/src/layer/arm/convolution_3x3_int8.h b/src/layer/arm/convolution_3x3_int8.h index 826ed8a82e03..1868b5d6855d 100644 --- a/src/layer/arm/convolution_3x3_int8.h +++ b/src/layer/arm/convolution_3x3_int8.h @@ -12,235 +12,6 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv3x3s1_winograd43_transform_kernel_int8_neon(const Mat& kernel, Mat& kernel_tm_packed, int inch, int outch, const Option& opt) -{ - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 8a-8b-inch/8a-36-outch/8b -#if __ARM_NEON - if (outch >= 8) - { - kernel_tm_packed.create(inch, 36, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)2u * 8, 8); - } - else if (outch >= 4) - { - kernel_tm_packed.create(inch, 36, outch / 4 + outch % 4, (size_t)2u * 4, 4); - } -#else // __ARM_NEON - if (outch >= 2) - { - kernel_tm_packed.create(inch, 36, outch / 2 + outch % 2, (size_t)2u * 2, 2); - } -#endif // __ARM_NEON - else - { - kernel_tm_packed.create(inch, 36, outch, (size_t)2u, 1); - } - - int p = 0; -#if __ARM_NEON - for (; p + 7 < outch; p += 8) - { - Mat g0 = kernel_tm_packed.channel(p / 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q < inch; q++) - { - for (int i = 0; i < 8; i++) - { - g00[0] = kernel_tm.channel(p + i).row(q)[k]; - g00++; - } - } - } - } - for (; p + 3 < outch; p += 4) - { - const Mat k0 = kernel_tm.channel(p); - const Mat k1 = kernel_tm.channel(p + 1); - const Mat k2 = kernel_tm.channel(p + 2); - const Mat k3 = kernel_tm.channel(p + 3); - - Mat g0 = kernel_tm_packed.channel(p / 8 + (p % 8) / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q < inch; q++) - { - g00[0] = k0.row(q)[k]; - g00[1] = k1.row(q)[k]; - g00[2] = k2.row(q)[k]; - g00[3] = k3.row(q)[k]; - g00 += 4; - } - } - } -#else // __ARM_NEON - for (; p + 1 < outch; p += 2) - { - const Mat k0 = kernel_tm.channel(p); - const Mat k1 = kernel_tm.channel(p + 1); - - Mat g0 = kernel_tm_packed.channel(p / 2); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - int q = 0; -#if __ARM_FEATURE_SIMD32 - for (; q + 1 < inch; q += 2) - { - g00[0] = k0.row(q)[k]; - g00[2] = k1.row(q)[k]; - g00[1] = k0.row(q + 1)[k]; - g00[3] = k1.row(q + 1)[k]; - g00 += 4; - } -#endif // __ARM_FEATURE_SIMD32 - for (; q < inch; q++) - { - g00[0] = k0.row(q)[k]; - g00[1] = 
k1.row(q)[k]; - g00 += 2; - } - } - } -#endif // __ARM_NEON - for (; p < outch; p++) - { - const Mat k0 = kernel_tm.channel(p); - -#if __ARM_NEON - Mat g0 = kernel_tm_packed.channel(p / 8 + (p % 8) / 4 + p % 4); -#else - Mat g0 = kernel_tm_packed.channel(p / 2 + p % 2); -#endif - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q < inch; q++) - { - g00[0] = k0.row(q)[k]; - g00 += 1; - } - } - } -} - -static void conv3x3s1_winograd43_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tiles = outw / 4; - int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - conv3x3s1_winograd43_transform_input_int8_neon(bottom_blob_bordered, bottom_blob_tm, opt); - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - convolution_winograd_dot_int8_neon(bottom_blob_tm, outch, kernel_tm, top_blob_tm, opt); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); - } - { - conv3x3s1_winograd43_transform_output_int8_neon(top_blob_tm, top_blob_bordered, opt); - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - static void conv3x3s2_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch) { kernel_tm.create(8 * 9, inch, outch / 8 + outch % 8, (size_t)1u); diff --git a/src/layer/arm/convolution_3x3_pack8to1_int8.h b/src/layer/arm/convolution_3x3_pack8to1_int8.h deleted file mode 100644 index 5af9f5938e1b..000000000000 --- a/src/layer/arm/convolution_3x3_pack8to1_int8.h +++ /dev/null @@ -1,185 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -static void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_neon(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt) -{ - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 8a-inch/8a-36-outch - kernel_tm_pack8to1.create(8 * inch / 8, 36, outch / 8 + outch % 8, (size_t)2u * 8, 8); - - int p = 0; - for (; p + 7 < outch; p += 8) - { - const Mat k0 = kernel_tm.channel(p); - const Mat k1 = kernel_tm.channel(p + 1); - const Mat k2 = kernel_tm.channel(p + 2); - const Mat k3 = kernel_tm.channel(p + 3); - const Mat k4 = kernel_tm.channel(p + 4); - const Mat k5 = kernel_tm.channel(p + 5); - const Mat k6 = kernel_tm.channel(p + 6); - const Mat k7 = kernel_tm.channel(p + 7); - - Mat g0 = kernel_tm_pack8to1.channel(p / 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 8; i++) - { - g00[0] = k0.row(q + i)[k]; - g00[1] = k1.row(q + i)[k]; - g00[2] = k2.row(q + i)[k]; - g00[3] = k3.row(q + i)[k]; - g00[4] = k4.row(q + i)[k]; - g00[5] = k5.row(q + i)[k]; - g00[6] = k6.row(q + i)[k]; - g00[7] = k7.row(q + i)[k]; - - g00 += 8; - } - } - } - } - for (; p < outch; p++) - { - const Mat k0 = kernel_tm.channel(p); - - Mat g0 = kernel_tm_pack8to1.channel(p / 8 + p % 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 8; i++) - { - g00[0] = k0.row(q + i)[k]; - - g00 += 1; - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to1_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tiles = outw / 4; - int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - 
conv3x3s1_winograd43_transform_input_pack8_int8_neon(bottom_blob_bordered, bottom_blob_tm, opt); - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - convolution_winograd_dot_pack8to1_int8_neon(bottom_blob_tm, outch, kernel_tm, top_blob_tm, opt); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); - } - { - conv3x3s1_winograd43_transform_output_int8_neon(top_blob_tm, top_blob_bordered, opt); - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/arm/convolution_3x3_pack8to4_int8.h b/src/layer/arm/convolution_3x3_pack8to4_int8.h deleted file mode 100644 index ee67ba61ef73..000000000000 --- a/src/layer/arm/convolution_3x3_pack8to4_int8.h +++ /dev/null @@ -1,205 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -static void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_neon(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt) -{ - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8.create(inch / 8, 36, outch / 8 + (outch % 8) / 4, (size_t)2u * 64, 64); - - int q = 0; - for (; q + 7 < outch; q += 8) - { - const Mat k0 = kernel_tm.channel(q); - const Mat k1 = kernel_tm.channel(q + 1); - const Mat k2 = kernel_tm.channel(q + 2); - const Mat k3 = kernel_tm.channel(q + 3); - const Mat k4 = kernel_tm.channel(q + 4); - const Mat k5 = kernel_tm.channel(q + 5); - const Mat k6 = kernel_tm.channel(q + 6); - const Mat k7 = kernel_tm.channel(q + 7); - - Mat kernel_tm = kernel_tm_pack8.channel(q / 8); - - for (int k = 0; k < 36; k++) - { - short* g00 = kernel_tm.row(k); - - for (int p = 0; p + 7 < inch; p += 8) - { - for (int i = 0; i < 8; i++) - { - const short* k00 = k0.row(p + i); - const short* k10 = k1.row(p + i); - const short* k20 = k2.row(p + i); - const short* k30 = k3.row(p + i); - const short* k40 = k4.row(p + i); - const short* k50 = k5.row(p + i); - const short* k60 = k6.row(p + i); - const short* k70 = k7.row(p + i); - - g00[0] = k00[k]; - g00[1] = k10[k]; - g00[2] = k20[k]; - g00[3] = k30[k]; - g00[4] = k40[k]; - g00[5] = k50[k]; - g00[6] = k60[k]; - g00[7] = k70[k]; - - g00 += 8; - } - } - } - } - for (; q + 3 < outch; q += 4) - { - const Mat k0 = kernel_tm.channel(q); - const Mat k1 = kernel_tm.channel(q + 1); - const Mat k2 = kernel_tm.channel(q + 2); - const Mat k3 = kernel_tm.channel(q + 3); - - Mat kernel_tm = kernel_tm_pack8.channel(q / 8 + (q % 8) / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = kernel_tm.row(k); - - for (int p = 0; p + 7 < inch; p += 8) - { - for (int i = 0; i < 8; i++) - { - const short* k00 = k0.row(p + i); - const short* k10 = k1.row(p + i); - const short* k20 = k2.row(p + i); - const short* k30 = k3.row(p + i); - - g00[0] = k00[k]; - g00[1] = k10[k]; - g00[2] = k20[k]; - g00[3] = k30[k]; - - g00 += 4; - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // 
pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tiles = outw / 4; - int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - conv3x3s1_winograd43_transform_input_pack8_int8_neon(bottom_blob_bordered, bottom_blob_tm, opt); - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - convolution_winograd_dot_pack8to4_int8_neon(bottom_blob_tm, outch, kernel_tm, top_blob_tm, opt); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u * 4, 4, opt.workspace_allocator); - } - { - conv3x3s1_winograd43_transform_output_pack4_int8_neon(top_blob_tm, top_blob_bordered, opt); - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/arm/convolution_3x3_winograd_int8.h b/src/layer/arm/convolution_3x3_winograd_int8.h new file mode 100644 index 000000000000..ab108b3f089e --- /dev/null +++ b/src/layer/arm/convolution_3x3_winograd_int8.h @@ -0,0 +1,5719 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +static void pack_A_tile_int8(const Mat& A, Mat& AT, int batch, int max_ii, int max_kk) +{ + const int N = max_kk * batch; + + for (int b = 0; b < batch; b++) + { + short* pp = AT.row(b); + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + p0 += batch; + pp += 8; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + p0 += batch; + pp += 4; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[batch + N]; + p0 += batch * 2; + pp += 4; + } +#endif + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + p0 += batch; + pp += 2; + } + } + for (; ii < max_ii; ii++) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += batch; + pp += 1; + } + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int batch, int max_jj, int max_kk, int nT) +{ + // NCNN_LOGE("transpose_pack_B_tile_int8 %d %d", max_jj, max_kk); + + #pragma omp parallel for num_threads(nT) + for (int b = 0; b < batch; b++) + { + short* pp = BT.row(b); + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + // transpose 8x12 +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%0, #512] \n" + "prfm pldl1keep, [%0, #1024] \n" + "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" + "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n" + "ld4 {v16.8h, v17.8h, v18.8h, v19.8h}, [%0] \n" + "sub %0, %0, #128 \n" + "uzp1 v20.8h, v0.8h, v4.8h \n" + "uzp2 v26.8h, v0.8h, v4.8h \n" + "uzp1 v23.8h, v2.8h, v6.8h \n" + "uzp2 v29.8h, v2.8h, v6.8h \n" + "uzp1 v21.8h, v16.8h, v1.8h \n" + "uzp2 v27.8h, v16.8h, v1.8h \n" + "uzp1 v22.8h, v5.8h, v17.8h \n" + "uzp2 v28.8h, v5.8h, v17.8h \n" + "uzp1 v24.8h, v18.8h, v3.8h \n" + "uzp2 v30.8h, v18.8h, v3.8h \n" + "uzp1 v25.8h, v7.8h, v19.8h \n" + "uzp2 v31.8h, v7.8h, v19.8h \n" + "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" + "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [%1], #64 \n" + "st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [%1], #64 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + p0 += max_jj * batch * 8; +#else // NCNN_GNU_INLINE_ASM + int16x8x4_t _r0 = vld4q_s16(p0); + int16x8x4_t _r1 = vld4q_s16(p0 + 32); + int16x8x4_t _r2 = vld4q_s16(p0 + 64); + int16x8x2_t _t0 = vuzpq_s16(_r0.val[0], _r1.val[0]); + int16x8x2_t _t1 = vuzpq_s16(_r2.val[0], _r0.val[1]); + int16x8x2_t _t2 = vuzpq_s16(_r1.val[1], _r2.val[1]); + int16x8x2_t _t3 = vuzpq_s16(_r0.val[2], _r1.val[2]); + int16x8x2_t _t4 = vuzpq_s16(_r2.val[2], _r0.val[3]); + int16x8x2_t _t5 = vuzpq_s16(_r1.val[3], _r2.val[3]); + vst1q_s16(pp, 
_t0.val[0]); + vst1q_s16(pp + 8, _t1.val[0]); + vst1q_s16(pp + 16, _t2.val[0]); + vst1q_s16(pp + 24, _t3.val[0]); + vst1q_s16(pp + 32, _t4.val[0]); + vst1q_s16(pp + 40, _t5.val[0]); + vst1q_s16(pp + 48, _t0.val[1]); + vst1q_s16(pp + 56, _t1.val[1]); + vst1q_s16(pp + 64, _t2.val[1]); + vst1q_s16(pp + 72, _t3.val[1]); + vst1q_s16(pp + 80, _t4.val[1]); + vst1q_s16(pp + 88, _t5.val[1]); + p0 += max_jj * batch * 8; + pp += 96; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x8x2_t _r01 = vld2q_s16(p0); + int16x4x2_t _r2 = vld2_s16(p0 + 16); + vst1q_s16(pp, _r01.val[0]); + vst1_s16(pp + 8, _r2.val[0]); + vst1q_s16(pp + 12, _r01.val[1]); + vst1_s16(pp + 20, _r2.val[1]); + p0 += max_jj * batch * 2; + pp += 24; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x8_t _r0 = vld1q_s16(p0); + int16x4_t _r1 = vld1_s16(p0 + 8); + vst1q_s16(pp, _r0); + vst1_s16(pp + 8, _r1); + p0 += max_jj * batch; + pp += 12; + } + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + // transpose 8x8 +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%0, #512] \n" + "prfm pldl1keep, [%0, #1024] \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" + "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n" + "sub %0, %0, #64 \n" + "zip1 v16.8h, v0.8h, v4.8h \n" + "zip2 v20.8h, v0.8h, v4.8h \n" + "zip1 v17.8h, v1.8h, v5.8h \n" + "zip2 v21.8h, v1.8h, v5.8h \n" + "zip1 v18.8h, v2.8h, v6.8h \n" + "zip2 v22.8h, v2.8h, v6.8h \n" + "zip1 v19.8h, v3.8h, v7.8h \n" + "zip2 v23.8h, v3.8h, v7.8h \n" + "st4 {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n" + "st4 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + p0 += max_jj * batch * 8; +#else // NCNN_GNU_INLINE_ASM + int16x8_t _r0 = vld1q_s16(p0); + int16x8_t _r1 = vld1q_s16(p0 + 8); + int16x8_t _r2 = vld1q_s16(p0 + 16); + int16x8_t _r3 = vld1q_s16(p0 + 24); + int16x8_t _r4 = vld1q_s16(p0 + 32); + int16x8_t _r5 = vld1q_s16(p0 + 40); + int16x8_t _r6 = vld1q_s16(p0 + 48); + int16x8_t _r7 = vld1q_s16(p0 + 56); + int16x8x2_t _r04 = vzipq_s16(_r0, _r4); + int16x8x2_t _r15 = vzipq_s16(_r1, _r5); + int16x8x2_t _r26 = vzipq_s16(_r2, _r6); + int16x8x2_t _r37 = vzipq_s16(_r3, _r7); + int16x8x4_t _r0123; + _r0123.val[0] = _r04.val[0]; + _r0123.val[1] = _r15.val[0]; + _r0123.val[2] = _r26.val[0]; + _r0123.val[3] = _r37.val[0]; + int16x8x4_t _r4567; + _r4567.val[0] = _r04.val[1]; + _r4567.val[1] = _r15.val[1]; + _r4567.val[2] = _r26.val[1]; + _r4567.val[3] = _r37.val[1]; + vst4q_s16(pp, _r0123); + vst4q_s16(pp + 32, _r4567); + p0 += max_jj * batch * 8; + pp += 64; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x8x2_t _r01 = vld2q_s16(p0); + vst1q_s16(pp, _r01.val[0]); + vst1q_s16(pp + 8, _r01.val[1]); + p0 += max_jj * batch * 2; + pp += 16; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x8_t _r0 = vld1q_s16(p0); + vst1q_s16(pp, _r0); + p0 += max_jj * batch; + pp += 8; + } + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 
8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #768] \n" + "ld1 {v0.8h, v1.8h, v2.8h}, [%0], #48 \n" + "ld1 {v3.8h, v4.8h, v5.8h}, [%0] \n" + "sub %0, %0, #48 \n" + "zip1 v16.8h, v0.8h, v3.8h \n" + "zip2 v20.8h, v0.8h, v3.8h \n" + "zip1 v17.8h, v1.8h, v4.8h \n" + "zip2 v21.8h, v1.8h, v4.8h \n" + "zip1 v18.8h, v2.8h, v5.8h \n" + "zip2 v22.8h, v2.8h, v5.8h \n" + "st3 {v16.8h, v17.8h, v18.8h}, [%1], #48 \n" + "st3 {v20.8h, v21.8h, v22.8h}, [%1], #48 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v16", "v17", "v18", "v20", "v21", "v22"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #768] \n" + "vldm %0, {d0-d11} \n" + "vzip.16 q0, q3 \n" + "vzip.16 q1, q4 \n" + "vzip.16 q2, q5 \n" + "vst3.s16 {d0,d2,d4}, [%1]! \n" + "vst3.s16 {d1,d3,d5}, [%1]! \n" + "vst3.s16 {d6,d8,d10}, [%1]! \n" + "vst3.s16 {d7,d9,d11}, [%1]! \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8_t _r0 = vld1q_s16(p0); + int16x8_t _r1 = vld1q_s16(p0 + 8); + int16x8_t _r2 = vld1q_s16(p0 + 16); + int16x8_t _r3 = vld1q_s16(p0 + 24); + int16x8_t _r4 = vld1q_s16(p0 + 32); + int16x8_t _r5 = vld1q_s16(p0 + 40); + int16x8x2_t _r03 = vzipq_s16(_r0, _r3); + int16x8x2_t _r14 = vzipq_s16(_r1, _r4); + int16x8x2_t _r25 = vzipq_s16(_r2, _r5); + int16x8x3_t _r012; + _r012.val[0] = _r03.val[0]; + _r012.val[1] = _r14.val[0]; + _r012.val[2] = _r25.val[0]; + int16x8x3_t _r345; + _r345.val[0] = _r03.val[1]; + _r345.val[1] = _r14.val[1]; + _r345.val[2] = _r25.val[1]; + vst3q_s16(pp, _r012); + vst3q_s16(pp + 24, _r345); + p0 += max_jj * batch * 8; + pp += 48; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x8x2_t _r01 = vld2q_s16(p0); + int32x4x2_t _r01x = vtrnq_s32(vreinterpretq_s32_s16(_r01.val[0]), vreinterpretq_s32_s16(_r01.val[1])); + int32x2x3_t _r012; + _r012.val[0] = vget_low_s32(_r01x.val[0]); + _r012.val[1] = vget_low_s32(_r01x.val[1]); + _r012.val[2] = vget_high_s32(_r01x.val[0]); + vst3_s32((int*)pp, _r012); + p0 += max_jj * batch * 2; + pp += 12; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x4_t _r0 = vld1_s16(p0); + vst1_s16(pp, _r0); + pp[4] = p0[4]; + pp[5] = p0[5]; + p0 += max_jj * batch; + pp += 6; + } + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #512] \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n" + "st4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1", "v2", "v3"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #512] \n" + "vldm %0, {d0-d7} \n" + "vst4.s16 {d0,d2,d4,d6}, [%1]! \n" + "vst4.s16 {d1,d3,d5,d7}, [%1]! 
\n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0", "q1", "q2", "q3"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8x4_t _r0123; + _r0123.val[0] = vld1q_s16(p0); + _r0123.val[1] = vld1q_s16(p0 + 8); + _r0123.val[2] = vld1q_s16(p0 + 16); + _r0123.val[3] = vld1q_s16(p0 + 24); + vst4q_s16(pp, _r0123); + p0 += max_jj * batch * 8; + pp += 32; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + int16x4x2_t _r01 = vld2_s16(p0); + vst1_s16(pp, _r01.val[0]); + vst1_s16(pp + 4, _r01.val[1]); + p0 += max_jj * batch * 2; + pp += 8; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + int16x4_t _r0 = vld1_s16(p0); + vst1_s16(pp, _r0); + p0 += max_jj * batch; + pp += 4; + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const short* p0 = B; + + int kk = 0; +#if __ARM_NEON + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #256] \n" + "ld1 {v0.8h, v1.8h}, [%0] \n" + "st2 {v0.8h, v1.8h}, [%1], #32 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0", "v1"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #256] \n" + "vld1.s16 {d0-d3}, [%0] \n" + "vst2.s16 {d0-d3}, [%1]! \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0", "q1"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8x2_t _r01; + _r01.val[0] = vld1q_s16(p0); + _r01.val[1] = vld1q_s16(p0 + 8); + vst2q_s16(pp, _r01); + p0 += max_jj * batch * 8; + pp += 16; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; +#endif // __ARM_NEON + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; +#else + pp[0] = p0[0]; + pp[1] = p0[2]; + pp[2] = p0[1]; + pp[3] = p0[3]; +#endif + p0 += max_jj * batch * 2; + pp += 4; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch; + pp += 2; + } + } + for (; jj < max_jj; jj++) + { + const short* p0 = B; + + int kk = 0; +#if __ARM_NEON + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #128] \n" + "ld1 {v0.8h}, [%0] \n" + "st1 {v0.8h}, [%1], #16 \n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "v0"); + p0 += max_jj * batch * 8; +#else // __aarch64__ + asm volatile( + "pld [%0, #128] \n" + "vld1.s16 {d0-d1}, [%0] \n" + "vst1.s16 {d0-d1}, [%1]! 
\n" + : "=r"(p0), // %0 + "=r"(pp) // %1 + : "0"(p0), + "1"(pp) + : "memory", "q0"); + p0 += max_jj * batch * 8; +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int16x8_t _r0 = vld1q_s16(p0); + vst1q_s16(pp, _r0); + p0 += max_jj * batch * 8; + pp += 8; +#endif // NCNN_GNU_INLINE_ASM + } + p0 -= (b * max_jj + jj) * 8; +#endif // __ARM_NEON + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch * 2; + pp += 2; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += max_jj * batch; + pp += 1; + } + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& top_blob, int batch, int max_ii, int max_jj, int k, int max_kk) +{ + // return; + // NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d", max_ii, max_jj, max_kk); + + int* outptr = top_blob; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #512] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0], #64 \n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%0], #64 \n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #320 \n" + "b 1f \n" + + "0: \n" + "eor v8.16b, v8.16b, v8.16b \n" + "eor v9.16b, v9.16b, v9.16b \n" + "eor v10.16b, v10.16b, v10.16b \n" + "eor v11.16b, v11.16b, v11.16b \n" + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v8.4s, v4.4h, v0.h[0] \n" + "smlal v10.4s, v4.4h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal2 v9.4s, v4.8h, v0.h[0] \n" + "smlal2 v11.4s, v4.8h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal v12.4s, v4.4h, v0.h[2] \n" + "smlal v14.4s, v4.4h, v0.h[3] \n" + "smlal2 v13.4s, v4.8h, v0.h[2] \n" + "smlal2 v15.4s, v4.8h, v0.h[3] \n" + "smlal v16.4s, v4.4h, v0.h[4] \n" + "smlal v18.4s, v4.4h, v0.h[5] \n" + "smlal2 v17.4s, v4.8h, v0.h[4] \n" + "smlal2 v19.4s, v4.8h, v0.h[5] \n" + "smlal v20.4s, v4.4h, v0.h[6] \n" + "smlal v22.4s, v4.4h, v0.h[7] \n" + "smlal2 v21.4s, v4.8h, v0.h[6] \n" + "smlal2 v23.4s, v4.8h, v0.h[7] \n" + "smlal v24.4s, v4.4h, v1.h[0] \n" + "smlal v26.4s, v4.4h, v1.h[1] \n" + "smlal2 
v25.4s, v4.8h, v1.h[0] \n" + "smlal2 v27.4s, v4.8h, v1.h[1] \n" + "smlal v28.4s, v4.4h, v1.h[2] \n" + "smlal v30.4s, v4.4h, v1.h[3] \n" + "smlal2 v29.4s, v4.8h, v1.h[2] \n" + "smlal2 v31.4s, v4.8h, v1.h[3] \n" + "smlal v8.4s, v5.4h, v1.h[4] \n" + "smlal v10.4s, v5.4h, v1.h[5] \n" + "smlal2 v9.4s, v5.8h, v1.h[4] \n" + "smlal2 v11.4s, v5.8h, v1.h[5] \n" + "smlal v12.4s, v5.4h, v1.h[6] \n" + "smlal v14.4s, v5.4h, v1.h[7] \n" + "smlal2 v13.4s, v5.8h, v1.h[6] \n" + "smlal2 v15.4s, v5.8h, v1.h[7] \n" + "smlal v16.4s, v5.4h, v2.h[0] \n" + "smlal v18.4s, v5.4h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v17.4s, v5.8h, v2.h[0] \n" + "smlal2 v19.4s, v5.8h, v2.h[1] \n" + "smlal v20.4s, v5.4h, v2.h[2] \n" + "smlal v22.4s, v5.4h, v2.h[3] \n" + "smlal2 v21.4s, v5.8h, v2.h[2] \n" + "smlal2 v23.4s, v5.8h, v2.h[3] \n" + "smlal v24.4s, v5.4h, v2.h[4] \n" + "smlal v26.4s, v5.4h, v2.h[5] \n" + "smlal2 v25.4s, v5.8h, v2.h[4] \n" + "smlal2 v27.4s, v5.8h, v2.h[5] \n" + "smlal v28.4s, v5.4h, v2.h[6] \n" + "smlal v30.4s, v5.4h, v2.h[7] \n" + "smlal2 v29.4s, v5.8h, v2.h[6] \n" + "smlal2 v31.4s, v5.8h, v2.h[7] \n" + "smlal v8.4s, v6.4h, v3.h[0] \n" + "smlal v10.4s, v6.4h, v3.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v9.4s, v6.8h, v3.h[0] \n" + "smlal2 v11.4s, v6.8h, v3.h[1] \n" + "smlal v12.4s, v6.4h, v3.h[2] \n" + "smlal v14.4s, v6.4h, v3.h[3] \n" + "smlal2 v13.4s, v6.8h, v3.h[2] \n" + "smlal2 v15.4s, v6.8h, v3.h[3] \n" + "smlal v16.4s, v6.4h, v3.h[4] \n" + "smlal v18.4s, v6.4h, v3.h[5] \n" + "smlal2 v17.4s, v6.8h, v3.h[4] \n" + "smlal2 v19.4s, v6.8h, v3.h[5] \n" + "smlal v20.4s, v6.4h, v3.h[6] \n" + "smlal v22.4s, v6.4h, v3.h[7] \n" + "smlal2 v21.4s, v6.8h, v3.h[6] \n" + "smlal2 v23.4s, v6.8h, v3.h[7] \n" + "smlal v24.4s, v6.4h, v0.h[0] \n" + "smlal v26.4s, v6.4h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal2 v25.4s, v6.8h, v0.h[0] \n" + "smlal2 v27.4s, v6.8h, v0.h[1] \n" + "smlal v28.4s, v6.4h, v0.h[2] \n" + "smlal v30.4s, v6.4h, v0.h[3] \n" + "smlal2 v29.4s, v6.8h, v0.h[2] \n" + "smlal2 v31.4s, v6.8h, v0.h[3] \n" + "smlal v8.4s, v7.4h, v0.h[4] \n" + "smlal v10.4s, v7.4h, v0.h[5] \n" + "smlal2 v9.4s, v7.8h, v0.h[4] \n" + "smlal2 v11.4s, v7.8h, v0.h[5] \n" + "smlal v12.4s, v7.4h, v0.h[6] \n" + "smlal v14.4s, v7.4h, v0.h[7] \n" + "smlal2 v13.4s, v7.8h, v0.h[6] \n" + "smlal2 v15.4s, v7.8h, v0.h[7] \n" + "smlal v16.4s, v7.4h, v1.h[0] \n" + "smlal v18.4s, v7.4h, v1.h[1] \n" + "smlal2 v17.4s, v7.8h, v1.h[0] \n" + "smlal2 v19.4s, v7.8h, v1.h[1] \n" + "smlal v20.4s, v7.4h, v1.h[2] \n" + "smlal v22.4s, v7.4h, v1.h[3] \n" + "smlal2 v21.4s, v7.8h, v1.h[2] \n" + "smlal2 v23.4s, v7.8h, v1.h[3] \n" + "smlal v24.4s, v7.4h, v1.h[4] \n" + "smlal v26.4s, v7.4h, v1.h[5] \n" + "smlal2 v25.4s, v7.8h, v1.h[4] \n" + "smlal2 v27.4s, v7.8h, v1.h[5] \n" + "smlal v28.4s, v7.4h, v1.h[6] \n" + "smlal v30.4s, v7.4h, v1.h[7] \n" + "smlal2 v29.4s, v7.8h, v1.h[6] \n" + "smlal2 v31.4s, v7.8h, v1.h[7] \n" + "smlal v8.4s, v4.4h, v2.h[0] \n" + "smlal v10.4s, v4.4h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v9.4s, v4.8h, v2.h[0] \n" + "smlal2 v11.4s, v4.8h, v2.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal v12.4s, v4.4h, v2.h[2] \n" + "smlal v14.4s, v4.4h, v2.h[3] \n" + "smlal2 v13.4s, v4.8h, v2.h[2] \n" + "smlal2 v15.4s, v4.8h, v2.h[3] \n" + "smlal v16.4s, v4.4h, v2.h[4] \n" + "smlal v18.4s, v4.4h, v2.h[5] \n" + "smlal2 v17.4s, v4.8h, v2.h[4] \n" + "smlal2 v19.4s, v4.8h, v2.h[5] \n" 
+ "smlal v20.4s, v4.4h, v2.h[6] \n" + "smlal v22.4s, v4.4h, v2.h[7] \n" + "smlal2 v21.4s, v4.8h, v2.h[6] \n" + "smlal2 v23.4s, v4.8h, v2.h[7] \n" + "smlal v24.4s, v4.4h, v3.h[0] \n" + "smlal v26.4s, v4.4h, v3.h[1] \n" + "smlal2 v25.4s, v4.8h, v3.h[0] \n" + "smlal2 v27.4s, v4.8h, v3.h[1] \n" + "smlal v28.4s, v4.4h, v3.h[2] \n" + "smlal v30.4s, v4.4h, v3.h[3] \n" + "smlal2 v29.4s, v4.8h, v3.h[2] \n" + "smlal2 v31.4s, v4.8h, v3.h[3] \n" + "smlal v8.4s, v5.4h, v3.h[4] \n" + "smlal v10.4s, v5.4h, v3.h[5] \n" + "smlal2 v9.4s, v5.8h, v3.h[4] \n" + "smlal2 v11.4s, v5.8h, v3.h[5] \n" + "smlal v12.4s, v5.4h, v3.h[6] \n" + "smlal v14.4s, v5.4h, v3.h[7] \n" + "smlal2 v13.4s, v5.8h, v3.h[6] \n" + "smlal2 v15.4s, v5.8h, v3.h[7] \n" + "smlal v16.4s, v5.4h, v0.h[0] \n" + "smlal v18.4s, v5.4h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal2 v17.4s, v5.8h, v0.h[0] \n" + "smlal2 v19.4s, v5.8h, v0.h[1] \n" + "smlal v20.4s, v5.4h, v0.h[2] \n" + "smlal v22.4s, v5.4h, v0.h[3] \n" + "smlal2 v21.4s, v5.8h, v0.h[2] \n" + "smlal2 v23.4s, v5.8h, v0.h[3] \n" + "smlal v24.4s, v5.4h, v0.h[4] \n" + "smlal v26.4s, v5.4h, v0.h[5] \n" + "smlal2 v25.4s, v5.8h, v0.h[4] \n" + "smlal2 v27.4s, v5.8h, v0.h[5] \n" + "smlal v28.4s, v5.4h, v0.h[6] \n" + "smlal v30.4s, v5.4h, v0.h[7] \n" + "smlal2 v29.4s, v5.8h, v0.h[6] \n" + "smlal2 v31.4s, v5.8h, v0.h[7] \n" + "smlal v8.4s, v6.4h, v1.h[0] \n" + "smlal v10.4s, v6.4h, v1.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v9.4s, v6.8h, v1.h[0] \n" + "smlal2 v11.4s, v6.8h, v1.h[1] \n" + "smlal v12.4s, v6.4h, v1.h[2] \n" + "smlal v14.4s, v6.4h, v1.h[3] \n" + "smlal2 v13.4s, v6.8h, v1.h[2] \n" + "smlal2 v15.4s, v6.8h, v1.h[3] \n" + "smlal v16.4s, v6.4h, v1.h[4] \n" + "smlal v18.4s, v6.4h, v1.h[5] \n" + "smlal2 v17.4s, v6.8h, v1.h[4] \n" + "smlal2 v19.4s, v6.8h, v1.h[5] \n" + "smlal v20.4s, v6.4h, v1.h[6] \n" + "smlal v22.4s, v6.4h, v1.h[7] \n" + "smlal2 v21.4s, v6.8h, v1.h[6] \n" + "smlal2 v23.4s, v6.8h, v1.h[7] \n" + "smlal v24.4s, v6.4h, v2.h[0] \n" + "smlal v26.4s, v6.4h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v25.4s, v6.8h, v2.h[0] \n" + "smlal2 v27.4s, v6.8h, v2.h[1] \n" + "smlal v28.4s, v6.4h, v2.h[2] \n" + "smlal v30.4s, v6.4h, v2.h[3] \n" + "smlal2 v29.4s, v6.8h, v2.h[2] \n" + "smlal2 v31.4s, v6.8h, v2.h[3] \n" + "smlal v8.4s, v7.4h, v2.h[4] \n" + "smlal v10.4s, v7.4h, v2.h[5] \n" + "smlal2 v9.4s, v7.8h, v2.h[4] \n" + "smlal2 v11.4s, v7.8h, v2.h[5] \n" + "smlal v12.4s, v7.4h, v2.h[6] \n" + "smlal v14.4s, v7.4h, v2.h[7] \n" + "smlal2 v13.4s, v7.8h, v2.h[6] \n" + "smlal2 v15.4s, v7.8h, v2.h[7] \n" + "smlal v16.4s, v7.4h, v3.h[0] \n" + "smlal v18.4s, v7.4h, v3.h[1] \n" + "smlal2 v17.4s, v7.8h, v3.h[0] \n" + "smlal2 v19.4s, v7.8h, v3.h[1] \n" + "smlal v20.4s, v7.4h, v3.h[2] \n" + "smlal v22.4s, v7.4h, v3.h[3] \n" + "smlal2 v21.4s, v7.8h, v3.h[2] \n" + "smlal2 v23.4s, v7.8h, v3.h[3] \n" + "smlal v24.4s, v7.4h, v3.h[4] \n" + "smlal v26.4s, v7.4h, v3.h[5] \n" + "smlal2 v25.4s, v7.8h, v3.h[4] \n" + "smlal2 v27.4s, v7.8h, v3.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v7.4h, v3.h[6] \n" + "smlal v30.4s, v7.4h, v3.h[7] \n" + "smlal2 v29.4s, v7.8h, v3.h[6] \n" + "smlal2 v31.4s, v7.8h, v3.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.4h, v1.4h, v2.4h}, [%2], #24 \n" + "smlal v8.4s, v4.4h, v0.h[0] \n" + "smlal v10.4s, 
v4.4h, v0.h[1] \n" + "smlal2 v9.4s, v4.8h, v0.h[0] \n" + "smlal2 v11.4s, v4.8h, v0.h[1] \n" + "smlal v12.4s, v4.4h, v0.h[2] \n" + "smlal v14.4s, v4.4h, v0.h[3] \n" + "smlal2 v13.4s, v4.8h, v0.h[2] \n" + "smlal2 v15.4s, v4.8h, v0.h[3] \n" + "smlal v16.4s, v4.4h, v1.h[0] \n" + "smlal v18.4s, v4.4h, v1.h[1] \n" + "smlal2 v17.4s, v4.8h, v1.h[0] \n" + "smlal2 v19.4s, v4.8h, v1.h[1] \n" + "smlal v20.4s, v4.4h, v1.h[2] \n" + "smlal v22.4s, v4.4h, v1.h[3] \n" + "smlal2 v21.4s, v4.8h, v1.h[2] \n" + "smlal2 v23.4s, v4.8h, v1.h[3] \n" + "smlal v24.4s, v4.4h, v2.h[0] \n" + "smlal v26.4s, v4.4h, v2.h[1] \n" + "smlal2 v25.4s, v4.8h, v2.h[0] \n" + "smlal2 v27.4s, v4.8h, v2.h[1] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v4.4h, v2.h[2] \n" + "smlal v30.4s, v4.4h, v2.h[3] \n" + "smlal2 v29.4s, v4.8h, v2.h[2] \n" + "smlal2 v31.4s, v4.8h, v2.h[3] \n" + "bne 4b \n" + + "5: \n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%0], #64 \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + int32x4_t _sumc; + int32x4_t _sumd; + int32x4_t _sume; + int32x4_t _sumf; + int32x4_t _sumg; + int32x4_t _sumh; + int32x4_t _sumi; + int32x4_t _sumj; + int32x4_t _sumk; + int32x4_t _suml; + int32x4_t _summ; + int32x4_t _sumn; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + _sumg = vdupq_n_s32(0); + _sumh = vdupq_n_s32(0); + _sumi = vdupq_n_s32(0); + _sumj = vdupq_n_s32(0); + _sumk = vdupq_n_s32(0); + _suml = vdupq_n_s32(0); + _summ = vdupq_n_s32(0); + _sumn = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + _sumg = vld1q_s32(outptr + 64); + _sumh = vld1q_s32(outptr + 68); + _sumi = vld1q_s32(outptr + 72); + _sumj = vld1q_s32(outptr + 76); + _sumk = vld1q_s32(outptr + 80); + _suml = vld1q_s32(outptr + 84); + _summ = vld1q_s32(outptr + 88); + _sumn = vld1q_s32(outptr + 92); + } + + int kk = 0; + 
for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_pA), _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, vget_high_s16(_pA), _pB, 0); + _sum2 = vmlal_laneq_s16(_sum2, vget_low_s16(_pA), _pB, 1); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_pA), _pB, 1); + _sum4 = vmlal_laneq_s16(_sum4, vget_low_s16(_pA), _pB, 2); + _sum5 = vmlal_laneq_s16(_sum5, vget_high_s16(_pA), _pB, 2); + _sum6 = vmlal_laneq_s16(_sum6, vget_low_s16(_pA), _pB, 3); + _sum7 = vmlal_laneq_s16(_sum7, vget_high_s16(_pA), _pB, 3); + _sum8 = vmlal_laneq_s16(_sum8, vget_low_s16(_pA), _pB, 4); + _sum9 = vmlal_laneq_s16(_sum9, vget_high_s16(_pA), _pB, 4); + _suma = vmlal_laneq_s16(_suma, vget_low_s16(_pA), _pB, 5); + _sumb = vmlal_laneq_s16(_sumb, vget_high_s16(_pA), _pB, 5); + _sumc = vmlal_laneq_s16(_sumc, vget_low_s16(_pA), _pB, 6); + _sumd = vmlal_laneq_s16(_sumd, vget_high_s16(_pA), _pB, 6); + _sume = vmlal_laneq_s16(_sume, vget_low_s16(_pA), _pB, 7); + _sumf = vmlal_laneq_s16(_sumf, vget_high_s16(_pA), _pB, 7); + _sumg = vmlal_lane_s16(_sumg, vget_low_s16(_pA), _pB2, 0); + _sumh = vmlal_lane_s16(_sumh, vget_high_s16(_pA), _pB2, 0); + _sumi = vmlal_lane_s16(_sumi, vget_low_s16(_pA), _pB2, 1); + _sumj = vmlal_lane_s16(_sumj, vget_high_s16(_pA), _pB2, 1); + _sumk = vmlal_lane_s16(_sumk, vget_low_s16(_pA), _pB2, 2); + _suml = vmlal_lane_s16(_suml, vget_high_s16(_pA), _pB2, 2); + _summ = vmlal_lane_s16(_summ, vget_low_s16(_pA), _pB2, 3); + _sumn = vmlal_lane_s16(_sumn, vget_high_s16(_pA), _pB2, 3); + pA += 8; + pB += 12; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + vst1q_s32(outptr + 48, _sumc); + vst1q_s32(outptr + 52, _sumd); + vst1q_s32(outptr + 56, _sume); + vst1q_s32(outptr + 60, _sumf); + vst1q_s32(outptr + 64, _sumg); + vst1q_s32(outptr + 68, _sumh); + vst1q_s32(outptr + 72, _sumi); + vst1q_s32(outptr + 76, _sumj); + vst1q_s32(outptr + 80, _sumk); + vst1q_s32(outptr + 84, _suml); + vst1q_s32(outptr + 88, _summ); + vst1q_s32(outptr + 92, _sumn); + outptr += 96; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #512] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + 
"lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v16.4s, v4.4h, v0.h[0] \n" + "smlal v18.4s, v4.4h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v17.4s, v4.8h, v0.h[0] \n" + "smlal2 v19.4s, v4.8h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal v20.4s, v4.4h, v0.h[2] \n" + "smlal v22.4s, v4.4h, v0.h[3] \n" + "smlal2 v21.4s, v4.8h, v0.h[2] \n" + "smlal2 v23.4s, v4.8h, v0.h[3] \n" + "smlal v24.4s, v4.4h, v0.h[4] \n" + "smlal v26.4s, v4.4h, v0.h[5] \n" + "smlal2 v25.4s, v4.8h, v0.h[4] \n" + "smlal2 v27.4s, v4.8h, v0.h[5] \n" + "smlal v28.4s, v4.4h, v0.h[6] \n" + "smlal v30.4s, v4.4h, v0.h[7] \n" + "smlal2 v29.4s, v4.8h, v0.h[6] \n" + "smlal2 v31.4s, v4.8h, v0.h[7] \n" + "smlal v16.4s, v5.4h, v1.h[0] \n" + "smlal v18.4s, v5.4h, v1.h[1] \n" + "smlal2 v17.4s, v5.8h, v1.h[0] \n" + "smlal2 v19.4s, v5.8h, v1.h[1] \n" + "smlal v20.4s, v5.4h, v1.h[2] \n" + "smlal v22.4s, v5.4h, v1.h[3] \n" + "smlal2 v21.4s, v5.8h, v1.h[2] \n" + "smlal2 v23.4s, v5.8h, v1.h[3] \n" + "smlal v24.4s, v5.4h, v1.h[4] \n" + "smlal v26.4s, v5.4h, v1.h[5] \n" + "smlal2 v25.4s, v5.8h, v1.h[4] \n" + "smlal2 v27.4s, v5.8h, v1.h[5] \n" + "smlal v28.4s, v5.4h, v1.h[6] \n" + "smlal v30.4s, v5.4h, v1.h[7] \n" + "smlal2 v29.4s, v5.8h, v1.h[6] \n" + "smlal2 v31.4s, v5.8h, v1.h[7] \n" + "smlal v16.4s, v6.4h, v2.h[0] \n" + "smlal v18.4s, v6.4h, v2.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v17.4s, v6.8h, v2.h[0] \n" + "smlal2 v19.4s, v6.8h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal v20.4s, v6.4h, v2.h[2] \n" + "smlal v22.4s, v6.4h, v2.h[3] \n" + "smlal2 v21.4s, v6.8h, v2.h[2] \n" + "smlal2 v23.4s, v6.8h, v2.h[3] \n" + "smlal v24.4s, v6.4h, v2.h[4] \n" + "smlal v26.4s, v6.4h, v2.h[5] \n" + "smlal2 v25.4s, v6.8h, v2.h[4] \n" + "smlal2 v27.4s, v6.8h, v2.h[5] \n" + "smlal v28.4s, v6.4h, v2.h[6] \n" + "smlal v30.4s, v6.4h, v2.h[7] \n" + "smlal2 v29.4s, v6.8h, v2.h[6] \n" + "smlal2 v31.4s, v6.8h, v2.h[7] \n" + "smlal v16.4s, v7.4h, v3.h[0] \n" + "smlal v18.4s, v7.4h, v3.h[1] \n" + "smlal2 v17.4s, v7.8h, v3.h[0] \n" + "smlal2 v19.4s, v7.8h, v3.h[1] \n" + "smlal v20.4s, v7.4h, v3.h[2] \n" + "smlal v22.4s, v7.4h, v3.h[3] \n" + "smlal2 v21.4s, v7.8h, v3.h[2] \n" + "smlal2 v23.4s, v7.8h, v3.h[3] \n" + "subs w4, w4, #1 \n" + "smlal v24.4s, v7.4h, v3.h[4] \n" + "smlal v26.4s, v7.4h, v3.h[5] \n" + "smlal2 v25.4s, v7.8h, v3.h[4] \n" + "smlal2 v27.4s, v7.8h, v3.h[5] \n" + "smlal v28.4s, v7.4h, v3.h[6] \n" + "smlal v30.4s, v7.4h, v3.h[7] \n" + "smlal2 v29.4s, v7.8h, v3.h[6] \n" + "smlal2 v31.4s, v7.8h, v3.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #3 \n" // w4 = remain = max_kk & 3 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.8h}, [%2], #16 \n" + "smlal v16.4s, v4.4h, v0.h[0] \n" + "smlal v18.4s, v4.4h, v0.h[1] \n" + "smlal2 v17.4s, v4.8h, v0.h[0] \n" + "smlal2 v19.4s, v4.8h, v0.h[1] \n" + "smlal v20.4s, v4.4h, v0.h[2] \n" + "smlal v22.4s, v4.4h, v0.h[3] \n" + "smlal2 v21.4s, v4.8h, v0.h[2] \n" + "smlal2 v23.4s, v4.8h, v0.h[3] \n" + "subs w4, w4, #1 \n" + "smlal v24.4s, v4.4h, v0.h[4] \n" + "smlal v26.4s, v4.4h, v0.h[5] \n" + "smlal2 v25.4s, v4.8h, v0.h[4] \n" + "smlal2 v27.4s, v4.8h, v0.h[5] \n" + "smlal v28.4s, v4.4h, v0.h[6] \n" + "smlal v30.4s, v4.4h, v0.h[7] \n" + "smlal2 v29.4s, v4.8h, v0.h[6] \n" + 
"smlal2 v31.4s, v4.8h, v0.h[7] \n" + "bne 4b \n" + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + int32x4_t _sumc; + int32x4_t _sumd; + int32x4_t _sume; + int32x4_t _sumf; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_pA), _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, vget_high_s16(_pA), _pB, 0); + _sum2 = vmlal_laneq_s16(_sum2, vget_low_s16(_pA), _pB, 1); + _sum3 = vmlal_laneq_s16(_sum3, vget_high_s16(_pA), _pB, 1); + _sum4 = vmlal_laneq_s16(_sum4, vget_low_s16(_pA), _pB, 2); + _sum5 = vmlal_laneq_s16(_sum5, vget_high_s16(_pA), _pB, 2); + _sum6 = vmlal_laneq_s16(_sum6, vget_low_s16(_pA), _pB, 3); + _sum7 = vmlal_laneq_s16(_sum7, vget_high_s16(_pA), _pB, 3); + _sum8 = vmlal_laneq_s16(_sum8, vget_low_s16(_pA), _pB, 4); + _sum9 = vmlal_laneq_s16(_sum9, vget_high_s16(_pA), _pB, 4); + _suma = vmlal_laneq_s16(_suma, vget_low_s16(_pA), _pB, 5); + _sumb = vmlal_laneq_s16(_sumb, vget_high_s16(_pA), _pB, 5); + _sumc = vmlal_laneq_s16(_sumc, vget_low_s16(_pA), _pB, 6); + _sumd = vmlal_laneq_s16(_sumd, vget_high_s16(_pA), _pB, 6); + _sume = vmlal_laneq_s16(_sume, vget_low_s16(_pA), _pB, 7); + _sumf = vmlal_laneq_s16(_sumf, vget_high_s16(_pA), _pB, 7); + pA += 8; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + vst1q_s32(outptr + 48, _sumc); + vst1q_s32(outptr + 
52, _sumd); + vst1q_s32(outptr + 56, _sume); + vst1q_s32(outptr + 60, _sumf); + outptr += 64; +#endif // NCNN_GNU_INLINE_ASM + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #384] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #128 \n" + "b 1f \n" + + "0: \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v20.4s, v6.4h, v0.h[0] \n" + "smlal v22.4s, v6.4h, v0.h[1] \n" + "ld1 {v8.8h, v9.8h}, [%1], #32 \n" + "smlal2 v21.4s, v6.8h, v0.h[0] \n" + "smlal2 v23.4s, v6.8h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal v24.4s, v6.4h, v0.h[2] \n" + "smlal v26.4s, v6.4h, v0.h[3] \n" + "smlal2 v25.4s, v6.8h, v0.h[2] \n" + "smlal2 v27.4s, v6.8h, v0.h[3] \n" + "smlal v28.4s, v6.4h, v0.h[4] \n" + "smlal v30.4s, v6.4h, v0.h[5] \n" + "smlal2 v29.4s, v6.8h, v0.h[4] \n" + "smlal2 v31.4s, v6.8h, v0.h[5] \n" + "smlal v20.4s, v7.4h, v0.h[6] \n" + "smlal v22.4s, v7.4h, v0.h[7] \n" + "smlal2 v21.4s, v7.8h, v0.h[6] \n" + "smlal2 v23.4s, v7.8h, v0.h[7] \n" + "smlal v24.4s, v7.4h, v1.h[0] \n" + "smlal v26.4s, v7.4h, v1.h[1] \n" + "smlal2 v25.4s, v7.8h, v1.h[0] \n" + "smlal2 v27.4s, v7.8h, v1.h[1] \n" + "smlal v28.4s, v7.4h, v1.h[2] \n" + "smlal v30.4s, v7.4h, v1.h[3] \n" + "smlal2 v29.4s, v7.8h, v1.h[2] \n" + "smlal2 v31.4s, v7.8h, v1.h[3] \n" + "smlal v20.4s, v8.4h, v1.h[4] \n" + "smlal v22.4s, v8.4h, v1.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v21.4s, v8.8h, v1.h[4] \n" + "smlal2 v23.4s, v8.8h, v1.h[5] \n" + "smlal v24.4s, v8.4h, v1.h[6] \n" + "smlal v26.4s, v8.4h, v1.h[7] \n" + "smlal2 v25.4s, v8.8h, v1.h[6] \n" + "smlal2 v27.4s, v8.8h, v1.h[7] \n" + "smlal v28.4s, v8.4h, v2.h[0] \n" + "smlal v30.4s, v8.4h, v2.h[1] \n" + "ld1 {v4.8h, v5.8h}, [%2], #32 \n" + "smlal2 v29.4s, v8.8h, v2.h[0] \n" + "smlal2 v31.4s, v8.8h, v2.h[1] \n" + "smlal v20.4s, v9.4h, v2.h[2] \n" + "smlal v22.4s, v9.4h, v2.h[3] \n" + "smlal2 v21.4s, v9.8h, v2.h[2] \n" + "smlal2 v23.4s, v9.8h, v2.h[3] \n" + "smlal v24.4s, v9.4h, v2.h[4] \n" + "smlal v26.4s, v9.4h, v2.h[5] \n" + "smlal2 v25.4s, v9.8h, v2.h[4] \n" + "smlal2 v27.4s, v9.8h, v2.h[5] \n" + "smlal v28.4s, v9.4h, v2.h[6] \n" + "smlal v30.4s, v9.4h, v2.h[7] \n" + "smlal2 v29.4s, v9.8h, v2.h[6] \n" + "smlal2 v31.4s, v9.8h, v2.h[7] \n" + "smlal v20.4s, v6.4h, v3.h[0] \n" + "smlal v22.4s, v6.4h, v3.h[1] \n" + "ld1 {v8.8h, v9.8h}, [%1], #32 \n" + "smlal2 v21.4s, v6.8h, v3.h[0] \n" + "smlal2 v23.4s, v6.8h, v3.h[1] \n" + "smlal v24.4s, v6.4h, v3.h[2] \n" + "smlal v26.4s, v6.4h, v3.h[3] \n" + "smlal2 v25.4s, v6.8h, v3.h[2] \n" + "smlal2 v27.4s, v6.8h, v3.h[3] \n" + "smlal v28.4s, v6.4h, v3.h[4] \n" + "smlal v30.4s, v6.4h, v3.h[5] \n" 
+ "smlal2 v29.4s, v6.8h, v3.h[4] \n" + "smlal2 v31.4s, v6.8h, v3.h[5] \n" + "smlal v20.4s, v7.4h, v3.h[6] \n" + "smlal v22.4s, v7.4h, v3.h[7] \n" + "smlal2 v21.4s, v7.8h, v3.h[6] \n" + "smlal2 v23.4s, v7.8h, v3.h[7] \n" + "smlal v24.4s, v7.4h, v4.h[0] \n" + "smlal v26.4s, v7.4h, v4.h[1] \n" + "prfm pldl1keep, [%2, #384] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal2 v25.4s, v7.8h, v4.h[0] \n" + "smlal2 v27.4s, v7.8h, v4.h[1] \n" + "smlal v28.4s, v7.4h, v4.h[2] \n" + "smlal v30.4s, v7.4h, v4.h[3] \n" + "smlal2 v29.4s, v7.8h, v4.h[2] \n" + "smlal2 v31.4s, v7.8h, v4.h[3] \n" + "smlal v20.4s, v8.4h, v4.h[4] \n" + "smlal v22.4s, v8.4h, v4.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v21.4s, v8.8h, v4.h[4] \n" + "smlal2 v23.4s, v8.8h, v4.h[5] \n" + "smlal v24.4s, v8.4h, v4.h[6] \n" + "smlal v26.4s, v8.4h, v4.h[7] \n" + "smlal2 v25.4s, v8.8h, v4.h[6] \n" + "smlal2 v27.4s, v8.8h, v4.h[7] \n" + "smlal v28.4s, v8.4h, v5.h[0] \n" + "smlal v30.4s, v8.4h, v5.h[1] \n" + "smlal2 v29.4s, v8.8h, v5.h[0] \n" + "smlal2 v31.4s, v8.8h, v5.h[1] \n" + "smlal v20.4s, v9.4h, v5.h[2] \n" + "smlal v22.4s, v9.4h, v5.h[3] \n" + "smlal2 v21.4s, v9.8h, v5.h[2] \n" + "smlal2 v23.4s, v9.8h, v5.h[3] \n" + "smlal v24.4s, v9.4h, v5.h[4] \n" + "smlal v26.4s, v9.4h, v5.h[5] \n" + "smlal2 v25.4s, v9.8h, v5.h[4] \n" + "smlal2 v27.4s, v9.8h, v5.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v9.4h, v5.h[6] \n" + "smlal v30.4s, v9.4h, v5.h[7] \n" + "smlal2 v29.4s, v9.8h, v5.h[6] \n" + "smlal2 v31.4s, v9.8h, v5.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.8h}, [%2] \n" + "add %2, %2, #12 \n" + "smlal v20.4s, v4.4h, v0.h[0] \n" + "smlal v22.4s, v4.4h, v0.h[1] \n" + "smlal2 v21.4s, v4.8h, v0.h[0] \n" + "smlal2 v23.4s, v4.8h, v0.h[1] \n" + "smlal v24.4s, v4.4h, v0.h[2] \n" + "smlal v26.4s, v4.4h, v0.h[3] \n" + "smlal2 v25.4s, v4.8h, v0.h[2] \n" + "smlal2 v27.4s, v4.8h, v0.h[3] \n" + "smlal v28.4s, v4.4h, v0.h[4] \n" + "smlal v30.4s, v4.4h, v0.h[5] \n" + "smlal2 v29.4s, v4.8h, v0.h[4] \n" + "smlal2 v31.4s, v4.8h, v0.h[5] \n" + "subs w4, w4, #1 \n" + "bne 4b \n" + + "5: \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #384] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0!, {d8-d15} \n" + "vldm %0, {d16-d31} \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "veor q4, q4 \n" + "veor q5, q5 \n" + "veor q6, q6 \n" + "veor q7, q7 \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #3 \n" // r4 = max_kk >> 3 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d4-d5}, [%1]! \n" + "vld1.s16 {d0-d1}, [%2]! \n" + ".align 4 \n" + "2: \n" + "vmlal.s16 q4, d4, d0[0] \n" + "vld1.s16 {d6-d7}, [%1]! 
\n" + "vmlal.s16 q6, d4, d0[1] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q8, d4, d0[2] \n" + "vmlal.s16 q10, d4, d0[3] \n" + "vmlal.s16 q5, d5, d0[0] \n" + "vmlal.s16 q7, d5, d0[1] \n" + "vmlal.s16 q9, d5, d0[2] \n" + "vmlal.s16 q11, d5, d0[3] \n" + "vmlal.s16 q12, d4, d1[0] \n" + "vmlal.s16 q14, d4, d1[1] \n" + "vmlal.s16 q13, d5, d1[0] \n" + "vmlal.s16 q15, d5, d1[1] \n" + "vmlal.s16 q4, d6, d1[2] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q6, d6, d1[3] \n" + "vmlal.s16 q5, d7, d1[2] \n" + "vmlal.s16 q7, d7, d1[3] \n" + "vmlal.s16 q8, d6, d2[0] \n" + "pld [%2, #384] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q10, d6, d2[1] \n" + "vmlal.s16 q12, d6, d2[2] \n" + "vmlal.s16 q14, d6, d2[3] \n" + "vmlal.s16 q9, d7, d2[0] \n" + "vmlal.s16 q11, d7, d2[1] \n" + "vmlal.s16 q13, d7, d2[2] \n" + "vmlal.s16 q15, d7, d2[3] \n" + "vmlal.s16 q4, d4, d3[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q6, d4, d3[1] \n" + "vmlal.s16 q8, d4, d3[2] \n" + "vmlal.s16 q10, d4, d3[3] \n" + "vmlal.s16 q5, d5, d3[0] \n" + "vmlal.s16 q7, d5, d3[1] \n" + "vmlal.s16 q9, d5, d3[2] \n" + "vmlal.s16 q11, d5, d3[3] \n" + "vmlal.s16 q12, d4, d0[0] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q14, d4, d0[1] \n" + "vmlal.s16 q13, d5, d0[0] \n" + "vmlal.s16 q15, d5, d0[1] \n" + "vmlal.s16 q4, d6, d0[2] \n" + "pld [%1, #512] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q6, d6, d0[3] \n" + "vmlal.s16 q5, d7, d0[2] \n" + "vmlal.s16 q7, d7, d0[3] \n" + "vmlal.s16 q8, d6, d1[0] \n" + "vmlal.s16 q10, d6, d1[1] \n" + "vmlal.s16 q12, d6, d1[2] \n" + "vmlal.s16 q14, d6, d1[3] \n" + "vmlal.s16 q9, d7, d1[0] \n" + "vmlal.s16 q11, d7, d1[1] \n" + "vmlal.s16 q13, d7, d1[2] \n" + "vmlal.s16 q15, d7, d1[3] \n" + "vmlal.s16 q4, d4, d2[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q6, d4, d2[1] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q8, d4, d2[2] \n" + "vmlal.s16 q10, d4, d2[3] \n" + "vmlal.s16 q5, d5, d2[0] \n" + "vmlal.s16 q7, d5, d2[1] \n" + "vmlal.s16 q9, d5, d2[2] \n" + "vmlal.s16 q11, d5, d2[3] \n" + "vmlal.s16 q12, d4, d3[0] \n" + "vmlal.s16 q14, d4, d3[1] \n" + "vmlal.s16 q13, d5, d3[0] \n" + "vmlal.s16 q15, d5, d3[1] \n" + "vmlal.s16 q4, d6, d3[2] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q6, d6, d3[3] \n" + "vmlal.s16 q5, d7, d3[2] \n" + "vmlal.s16 q7, d7, d3[3] \n" + "vmlal.s16 q8, d6, d0[0] \n" + "pld [%2, #384] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q10, d6, d0[1] \n" + "vmlal.s16 q12, d6, d0[2] \n" + "vmlal.s16 q14, d6, d0[3] \n" + "vmlal.s16 q9, d7, d0[0] \n" + "vmlal.s16 q11, d7, d0[1] \n" + "vmlal.s16 q13, d7, d0[2] \n" + "vmlal.s16 q15, d7, d0[3] \n" + "vmlal.s16 q4, d4, d1[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q6, d4, d1[1] \n" + "vmlal.s16 q8, d4, d1[2] \n" + "vmlal.s16 q10, d4, d1[3] \n" + "vmlal.s16 q5, d5, d1[0] \n" + "vmlal.s16 q7, d5, d1[1] \n" + "vmlal.s16 q9, d5, d1[2] \n" + "vmlal.s16 q11, d5, d1[3] \n" + "vmlal.s16 q12, d4, d2[0] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q14, d4, d2[1] \n" + "vmlal.s16 q13, d5, d2[0] \n" + "vmlal.s16 q15, d5, d2[1] \n" + "vmlal.s16 q4, d6, d2[2] \n" + "pld [%1, #512] \n" + "vld1.s16 {d4-d5}, [%1]! 
\n" + "vmlal.s16 q6, d6, d2[3] \n" + "vmlal.s16 q5, d7, d2[2] \n" + "vmlal.s16 q7, d7, d2[3] \n" + "vmlal.s16 q8, d6, d3[0] \n" + "vmlal.s16 q10, d6, d3[1] \n" + "vmlal.s16 q12, d6, d3[2] \n" + "vmlal.s16 q14, d6, d3[3] \n" + "vmlal.s16 q9, d7, d3[0] \n" + "vmlal.s16 q11, d7, d3[1] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d7, d3[2] \n" + "vmlal.s16 q15, d7, d3[3] \n" + "bne 2b \n" + "sub %1, %1, #16 \n" + "sub %2, %2, #16 \n" + + "3: \n" + "and r4, %6, #7 \n" // w4 = remain = max_kk & 7 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! \n" + "vld1.s16 {d2-d3}, [%2] \n" + "add %2, %2, #12 \n" + "vmlal.s16 q4, d0, d2[0] \n" + "vmlal.s16 q6, d0, d2[1] \n" + "vmlal.s16 q8, d0, d2[2] \n" + "vmlal.s16 q10, d0, d2[3] \n" + "vmlal.s16 q5, d1, d2[0] \n" + "vmlal.s16 q7, d1, d2[1] \n" + "vmlal.s16 q9, d1, d2[2] \n" + "vmlal.s16 q11, d1, d2[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q12, d0, d3[0] \n" + "vmlal.s16 q14, d0, d3[1] \n" + "vmlal.s16 q13, d1, d3[0] \n" + "vmlal.s16 q15, d1, d3[1] \n" + "bne 4b \n" + + "5: \n" + "vstm %0!, {d8-d15} \n" + "vstm %0!, {d16-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_pA), vget_low_s16(_pB), 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_pA), vget_low_s16(_pB), 0); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_pA), vget_low_s16(_pB), 1); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_pA), vget_low_s16(_pB), 1); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_pA), vget_low_s16(_pB), 2); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_pA), vget_low_s16(_pB), 2); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_pA), vget_low_s16(_pB), 3); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_pA), vget_low_s16(_pB), 3); + _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_pA), vget_high_s16(_pB), 0); + _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_pA), vget_high_s16(_pB), 0); + _suma = vmlal_lane_s16(_suma, vget_low_s16(_pA), vget_high_s16(_pB), 1); + _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_pA), vget_high_s16(_pB), 1); + pA += 8; + pB += 6; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 
12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + outptr += 48; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #512] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + ".align 4 \n" + "2: \n" + "smlal v24.4s, v4.4h, v0.h[0] \n" + "smlal v26.4s, v4.4h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v25.4s, v4.8h, v0.h[0] \n" + "smlal2 v27.4s, v4.8h, v0.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%2], #32 \n" + "smlal v28.4s, v4.4h, v0.h[2] \n" + "smlal v30.4s, v4.4h, v0.h[3] \n" + "smlal2 v29.4s, v4.8h, v0.h[2] \n" + "smlal2 v31.4s, v4.8h, v0.h[3] \n" + "smlal v24.4s, v5.4h, v0.h[4] \n" + "smlal v26.4s, v5.4h, v0.h[5] \n" + "smlal2 v25.4s, v5.8h, v0.h[4] \n" + "smlal2 v27.4s, v5.8h, v0.h[5] \n" + "smlal v28.4s, v5.4h, v0.h[6] \n" + "smlal v30.4s, v5.4h, v0.h[7] \n" + "smlal2 v29.4s, v5.8h, v0.h[6] \n" + "smlal2 v31.4s, v5.8h, v0.h[7] \n" + "smlal v24.4s, v6.4h, v1.h[0] \n" + "smlal v26.4s, v6.4h, v1.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v25.4s, v6.8h, v1.h[0] \n" + "smlal2 v27.4s, v6.8h, v1.h[1] \n" + "smlal v28.4s, v6.4h, v1.h[2] \n" + "smlal v30.4s, v6.4h, v1.h[3] \n" + "smlal2 v29.4s, v6.8h, v1.h[2] \n" + "smlal2 v31.4s, v6.8h, v1.h[3] \n" + "smlal v24.4s, v7.4h, v1.h[4] \n" + "smlal v26.4s, v7.4h, v1.h[5] \n" + "smlal2 v25.4s, v7.8h, v1.h[4] \n" + "smlal2 v27.4s, v7.8h, v1.h[5] \n" + "smlal v28.4s, v7.4h, v1.h[6] \n" + "smlal v30.4s, v7.4h, v1.h[7] \n" + "smlal2 v29.4s, v7.8h, v1.h[6] \n" + "smlal2 v31.4s, v7.8h, v1.h[7] \n" + "smlal v24.4s, v4.4h, v2.h[0] \n" + "smlal v26.4s, v4.4h, v2.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v25.4s, v4.8h, v2.h[0] \n" + "smlal2 v27.4s, v4.8h, v2.h[1] \n" + "prfm pldl1keep, [%2, #512] \n" + "ld1 {v0.8h, v1.8h}, [%2], #32 \n" + "smlal v28.4s, v4.4h, v2.h[2] \n" + "smlal v30.4s, v4.4h, v2.h[3] \n" + "smlal2 v29.4s, v4.8h, v2.h[2] \n" + "smlal2 v31.4s, v4.8h, v2.h[3] \n" + "smlal v24.4s, v5.4h, v2.h[4] \n" + "smlal v26.4s, v5.4h, v2.h[5] \n" + "smlal2 v25.4s, v5.8h, v2.h[4] \n" + "smlal2 v27.4s, v5.8h, v2.h[5] \n" + "smlal v28.4s, v5.4h, v2.h[6] \n" + "smlal v30.4s, v5.4h, v2.h[7] \n" + "smlal2 v29.4s, v5.8h, v2.h[6] \n" + "smlal2 v31.4s, v5.8h, v2.h[7] \n" + "smlal v24.4s, v6.4h, v3.h[0] \n" + "smlal v26.4s, v6.4h, v3.h[1] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v25.4s, v6.8h, v3.h[0] \n" + "smlal2 v27.4s, v6.8h, v3.h[1] \n" + "smlal v28.4s, v6.4h, v3.h[2] \n" + "smlal v30.4s, v6.4h, v3.h[3] \n" + "smlal2 v29.4s, v6.8h, v3.h[2] \n" + "smlal2 v31.4s, v6.8h, v3.h[3] 
\n" + "smlal v24.4s, v7.4h, v3.h[4] \n" + "smlal v26.4s, v7.4h, v3.h[5] \n" + "smlal2 v25.4s, v7.8h, v3.h[4] \n" + "smlal2 v27.4s, v7.8h, v3.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v7.4h, v3.h[6] \n" + "smlal v30.4s, v7.4h, v3.h[7] \n" + "smlal2 v29.4s, v7.8h, v3.h[6] \n" + "smlal2 v31.4s, v7.8h, v3.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #32 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.4h}, [%2], #8 \n" + "smlal v24.4s, v4.4h, v0.h[0] \n" + "smlal v26.4s, v4.4h, v0.h[1] \n" + "smlal2 v25.4s, v4.8h, v0.h[0] \n" + "smlal2 v27.4s, v4.8h, v0.h[1] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v4.4h, v0.h[2] \n" + "smlal v30.4s, v4.4h, v0.h[3] \n" + "smlal2 v29.4s, v4.8h, v0.h[2] \n" + "smlal2 v31.4s, v4.8h, v0.h[3] \n" + "bne 4b \n" + + "5: \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #256] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0, {d16-d31} \n" + "b 1f \n" + + "0: \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d4-d5}, [%1]! \n" + "vld1.s16 {d0-d1}, [%2]! \n" + ".align 4 \n" + "2: \n" + "vmlal.s16 q8, d4, d0[0] \n" + "vld1.s16 {d6-d7}, [%1]! \n" + "vmlal.s16 q10, d4, d0[1] \n" + "vmlal.s16 q12, d4, d0[2] \n" + "vmlal.s16 q14, d4, d0[3] \n" + "vmlal.s16 q9, d5, d0[0] \n" + "vld1.s16 {d8-d9}, [%1]! \n" + "vmlal.s16 q11, d5, d0[1] \n" + "vld1.s16 {d2-d3}, [%2]! \n" + "vmlal.s16 q13, d5, d0[2] \n" + "vmlal.s16 q15, d5, d0[3] \n" + "vmlal.s16 q8, d6, d1[0] \n" + "vmlal.s16 q10, d6, d1[1] \n" + "vmlal.s16 q12, d6, d1[2] \n" + "vmlal.s16 q14, d6, d1[3] \n" + "vmlal.s16 q9, d7, d1[0] \n" + "vld1.s16 {d10-d11}, [%1]! \n" + "vmlal.s16 q11, d7, d1[1] \n" + "vmlal.s16 q13, d7, d1[2] \n" + "vmlal.s16 q15, d7, d1[3] \n" + "vmlal.s16 q8, d8, d2[0] \n" + "vmlal.s16 q10, d8, d2[1] \n" + "vmlal.s16 q12, d8, d2[2] \n" + "vmlal.s16 q14, d8, d2[3] \n" + "vmlal.s16 q9, d9, d2[0] \n" + "pld [%1, #512] \n" + "vld1.s16 {d4-d5}, [%1]! \n" + "vmlal.s16 q11, d9, d2[1] \n" + "pld [%2, #256] \n" + "vld1.s16 {d0-d1}, [%2]! \n" + "vmlal.s16 q13, d9, d2[2] \n" + "vmlal.s16 q15, d9, d2[3] \n" + "vmlal.s16 q8, d10, d3[0] \n" + "vmlal.s16 q10, d10, d3[1] \n" + "vmlal.s16 q12, d10, d3[2] \n" + "vmlal.s16 q14, d10, d3[3] \n" + "vmlal.s16 q9, d11, d3[0] \n" + "vmlal.s16 q11, d11, d3[1] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d11, d3[2] \n" + "vmlal.s16 q15, d11, d3[3] \n" + "bne 2b \n" + "sub %1, %1, #16 \n" + "sub %2, %2, #16 \n" + + "3: \n" + "and r4, %6, #3 \n" // w4 = remain = max_kk & 3 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! \n" + "vld1.s16 {d2}, [%2]! 
\n" + "vmlal.s16 q8, d0, d2[0] \n" + "vmlal.s16 q10, d0, d2[1] \n" + "vmlal.s16 q12, d0, d2[2] \n" + "vmlal.s16 q14, d0, d2[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q9, d1, d2[0] \n" + "vmlal.s16 q11, d1, d2[1] \n" + "vmlal.s16 q13, d1, d2[2] \n" + "vmlal.s16 q15, d1, d2[3] \n" + "bne 4b \n" + + "5: \n" + "vstm %0!, {d16-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_pA), _pB, 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_pA), _pB, 0); + _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_pA), _pB, 1); + _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_pA), _pB, 1); + _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_pA), _pB, 2); + _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_pA), _pB, 2); + _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_pA), _pB, 3); + _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_pA), _pB, 3); + pA += 8; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + outptr += 32; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #256] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "b 1f \n" + + "0: \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v0.8h}, [%2], #16 \n" + ".align 4 \n" + "2: \n" + "smlal v28.4s, v4.4h, v0.h[0] \n" + "smlal v30.4s, v4.4h, v0.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v0.h[0] \n" + "smlal2 v31.4s, v4.8h, v0.h[1] \n" + "ld1 {v1.8h}, [%2], #16 \n" + "smlal v28.4s, v5.4h, v0.h[2] \n" + "smlal v30.4s, v5.4h, v0.h[3] \n" + "smlal2 v29.4s, v5.8h, v0.h[2] \n" + "smlal2 v31.4s, v5.8h, v0.h[3] \n" + "smlal v28.4s, v6.4h, v0.h[4] \n" + "smlal v30.4s, v6.4h, v0.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v0.h[4] \n" + "smlal2 v31.4s, v6.8h, v0.h[5] \n" + "smlal v28.4s, v7.4h, v0.h[6] \n" + "smlal v30.4s, v7.4h, v0.h[7] \n" + "smlal2 v29.4s, v7.8h, 
v0.h[6] \n" + "smlal2 v31.4s, v7.8h, v0.h[7] \n" + "smlal v28.4s, v4.4h, v1.h[0] \n" + "smlal v30.4s, v4.4h, v1.h[1] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v1.h[0] \n" + "smlal2 v31.4s, v4.8h, v1.h[1] \n" + "prfm pldl1keep, [%2, #256] \n" + "ld1 {v0.8h}, [%2], #16 \n" + "smlal v28.4s, v5.4h, v1.h[2] \n" + "smlal v30.4s, v5.4h, v1.h[3] \n" + "smlal2 v29.4s, v5.8h, v1.h[2] \n" + "smlal2 v31.4s, v5.8h, v1.h[3] \n" + "smlal v28.4s, v6.4h, v1.h[4] \n" + "smlal v30.4s, v6.4h, v1.h[5] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v1.h[4] \n" + "smlal2 v31.4s, v6.8h, v1.h[5] \n" + "subs w4, w4, #1 \n" + "smlal v28.4s, v7.4h, v1.h[6] \n" + "smlal v30.4s, v7.4h, v1.h[7] \n" + "smlal2 v29.4s, v7.8h, v1.h[6] \n" + "smlal2 v31.4s, v7.8h, v1.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #16 \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1 {v0.4h}, [%2] \n" + "add %2, %2, #4 \n" + "smlal v28.4s, v4.4h, v0.h[0] \n" + "smlal v30.4s, v4.4h, v0.h[1] \n" + "subs w4, w4, #1 \n" + "smlal2 v29.4s, v4.8h, v0.h[0] \n" + "smlal2 v31.4s, v4.8h, v0.h[1] \n" + "bne 4b \n" + + "5: \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #128] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0, {d24-d31} \n" + "b 1f \n" + + "0: \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d2-d5}, [%1]! \n" + "vld1.s16 {d0}, [%2]! \n" + ".align 4 \n" + "2: \n" + "vmlal.s16 q12, d2, d0[0] \n" + "vld1.s16 {d6-d9}, [%1]! \n" + "vmlal.s16 q14, d2, d0[1] \n" + "vld1.s16 {d1}, [%2]! \n" + "vmlal.s16 q13, d3, d0[0] \n" + "vmlal.s16 q15, d3, d0[1] \n" + "vmlal.s16 q12, d4, d0[2] \n" + "vmlal.s16 q14, d4, d0[3] \n" + "vmlal.s16 q13, d5, d0[2] \n" + "vmlal.s16 q15, d5, d0[3] \n" + "vmlal.s16 q12, d6, d1[0] \n" + "pld [%1, #512] \n" + "vld1.s16 {d2-d5}, [%1]! \n" + "vmlal.s16 q14, d6, d1[1] \n" + "pld [%2, #128] \n" + "vld1.s16 {d0}, [%2]! \n" + "vmlal.s16 q13, d7, d1[0] \n" + "vmlal.s16 q15, d7, d1[1] \n" + "vmlal.s16 q12, d8, d1[2] \n" + "vmlal.s16 q14, d8, d1[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d9, d1[2] \n" + "vmlal.s16 q15, d9, d1[3] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #8 \n" + + "3: \n" + "and r4, %6, #3 \n" // w4 = remain = max_kk & 3 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! 
\n" + "vld1.s16 {d2}, [%2] \n" + "add %2, %2, #4 \n" + "vmlal.s16 q12, d0, d2[0] \n" + "vmlal.s16 q14, d0, d2[1] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q13, d1, d2[0] \n" + "vmlal.s16 q15, d1, d2[1] \n" + "bne 4b \n" + + "5: \n" + "vstm %0!, {d24-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x4_t _pB0 = vdup_n_s16(pB[0]); + int16x4_t _pB1 = vdup_n_s16(pB[1]); + _sum0 = vmlal_s16(_sum0, vget_low_s16(_pA), _pB0); + _sum1 = vmlal_s16(_sum1, vget_high_s16(_pA), _pB0); + _sum2 = vmlal_s16(_sum2, vget_low_s16(_pA), _pB1); + _sum3 = vmlal_s16(_sum3, vget_high_s16(_pA), _pB1); + pA += 8; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + outptr += 16; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%1, #512] \n" + "prfm pldl1keep, [%2, #128] \n" + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v30.4s, v31.4s}, [%0] \n" + "b 1f \n" + + "0: \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 3f \n" + + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "ld1 {v1.8h}, [%2], #16 \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + ".align 4 \n" + "2: \n" + "mov v0.16b, v1.16b \n" + "smlal v28.4s, v4.4h, v0.h[0] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v0.h[0] \n" + "prfm pldl1keep, [%2, #128] \n" + "ld1 {v1.8h}, [%2], #16 \n" + "smlal v30.4s, v5.4h, v0.h[1] \n" + "smlal2 v31.4s, v5.8h, v0.h[1] \n" + "smlal v28.4s, v6.4h, v0.h[2] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v0.h[2] \n" + "smlal v30.4s, v7.4h, v0.h[3] \n" + "smlal2 v31.4s, v7.8h, v0.h[3] \n" + "smlal v28.4s, v4.4h, v0.h[4] \n" + "ld1 {v6.8h, v7.8h}, [%1], #32 \n" + "smlal2 v29.4s, v4.8h, v0.h[4] \n" + "smlal v30.4s, v5.4h, v0.h[5] \n" + "smlal2 v31.4s, v5.8h, v0.h[5] \n" + "smlal v28.4s, v6.4h, v0.h[6] \n" + "prfm pldl1keep, [%1, #512] \n" + "ld1 {v4.8h, v5.8h}, [%1], #32 \n" + "smlal2 v29.4s, v6.8h, v0.h[6] \n" + "subs w4, w4, #1 \n" + "smlal v30.4s, v7.4h, v0.h[7] \n" + "smlal2 v31.4s, v7.8h, v0.h[7] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + "sub %2, %2, #16 \n" + "add v30.4s, v30.4s, v28.4s \n" + "add v31.4s, v31.4s, v29.4s \n" + + "3: \n" + "and w4, %w6, #7 \n" // w4 = remain = max_kk & 7 + "cmp w4, #0 \n" + "beq 5f \n" + + "4: \n" + "ld1 {v4.8h}, [%1], #16 \n" + "ld1r {v0.4h}, [%2], #2 \n" + "subs w4, w4, #1 \n" + "smlal v30.4s, v4.4h, v0.h[0] \n" + "smlal2 v31.4s, v4.8h, v0.h[0] \n" + "bne 4b \n" + + "5: \n" + "st1 {v30.4s, v31.4s}, [%0], #32 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + 
"1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "pld [%1, #512] \n" + "pld [%2, #64] \n" + "cmp %7, #0 \n" + "beq 0f \n" + + "vld1.s32 {d28-d31}, [%0] \n" + "b 1f \n" + + "0: \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + "vld1.s16 {d2-d5}, [%1]! \n" + ".align 4 \n" + "2: \n" + "pld [%2, #64] \n" + "vld1.s16 {d0}, [%2]! \n" + "vmlal.s16 q14, d2, d0[0] \n" + "vld1.s16 {d6-d9}, [%1]! \n" + "vmlal.s16 q15, d3, d0[0] \n" + "vmlal.s16 q14, d4, d0[1] \n" + "vmlal.s16 q15, d5, d0[1] \n" + "vmlal.s16 q14, d6, d0[2] \n" + "pld [%1, #512] \n" + "vld1.s16 {d2-d5}, [%1]! \n" + "vmlal.s16 q15, d7, d0[2] \n" + "vmlal.s16 q14, d8, d0[3] \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q15, d9, d0[3] \n" + "bne 2b \n" + "sub %1, %1, #32 \n" + + "3: \n" + "and r4, %6, #3 \n" // w4 = remain = max_kk & 3 + "cmp r4, #0 \n" + "beq 5f \n" + + "4: \n" + "vld1.s16 {d0-d1}, [%1]! \n" + "vld1.s16 {d2[]}, [%2]! \n" + "subs r4, r4, #1 \n" + "vmlal.s16 q14, d0, d2[0] \n" + "vmlal.s16 q15, d1, d2[0] \n" + "bne 4b \n" + + "5: \n" + "vst1.s32 {d28-d31}, [%0]! \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x8_t _pA = vld1q_s16(pA); + int16x4_t _pB = vld1_dup_s16(pB); + _sum0 = vmlal_s16(_sum0, vget_low_s16(_pA), _pB); + _sum1 = vmlal_s16(_sum1, vget_high_s16(_pA), _pB); + pA += 8; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + outptr += 8; +#endif // NCNN_GNU_INLINE_ASM + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + } + + int kk = 0; 
+ for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_laneq_s16(_sum0, _pA, _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, _pA, _pB, 1); + _sum2 = vmlal_laneq_s16(_sum2, _pA, _pB, 2); + _sum3 = vmlal_laneq_s16(_sum3, _pA, _pB, 3); + _sum4 = vmlal_laneq_s16(_sum4, _pA, _pB, 4); + _sum5 = vmlal_laneq_s16(_sum5, _pA, _pB, 5); + _sum6 = vmlal_laneq_s16(_sum6, _pA, _pB, 6); + _sum7 = vmlal_laneq_s16(_sum7, _pA, _pB, 7); + _sum8 = vmlal_lane_s16(_sum8, _pA, _pB2, 0); + _sum9 = vmlal_lane_s16(_sum9, _pA, _pB2, 1); + _suma = vmlal_lane_s16(_suma, _pA, _pB2, 2); + _sumb = vmlal_lane_s16(_sumb, _pA, _pB2, 3); + pA += 4; + pB += 12; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + outptr += 48; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_laneq_s16(_sum0, _pA, _pB, 0); + _sum1 = vmlal_laneq_s16(_sum1, _pA, _pB, 1); + _sum2 = vmlal_laneq_s16(_sum2, _pA, _pB, 2); + _sum3 = vmlal_laneq_s16(_sum3, _pA, _pB, 3); + _sum4 = vmlal_laneq_s16(_sum4, _pA, _pB, 4); + _sum5 = vmlal_laneq_s16(_sum5, _pA, _pB, 5); + _sum6 = vmlal_laneq_s16(_sum6, _pA, _pB, 6); + _sum7 = vmlal_laneq_s16(_sum7, _pA, _pB, 7); + pA += 4; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + outptr += 32; + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, _pA, vget_low_s16(_pB), 0); + _sum1 = vmlal_lane_s16(_sum1, _pA, vget_low_s16(_pB), 1); + _sum2 = vmlal_lane_s16(_sum2, _pA, vget_low_s16(_pB), 2); + 
_sum3 = vmlal_lane_s16(_sum3, _pA, vget_low_s16(_pB), 3); + _sum4 = vmlal_lane_s16(_sum4, _pA, vget_high_s16(_pB), 0); + _sum5 = vmlal_lane_s16(_sum5, _pA, vget_high_s16(_pB), 1); + pA += 4; + pB += 6; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + outptr += 24; + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_lane_s16(_sum0, _pA, _pB, 0); + _sum1 = vmlal_lane_s16(_sum1, _pA, _pB, 1); + _sum2 = vmlal_lane_s16(_sum2, _pA, _pB, 2); + _sum3 = vmlal_lane_s16(_sum3, _pA, _pB, 3); + pA += 4; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x4_t _pB0 = vdup_n_s16(pB[0]); + int16x4_t _pB1 = vdup_n_s16(pB[1]); + _sum0 = vmlal_s16(_sum0, _pA, _pB0); + _sum1 = vmlal_s16(_sum1, _pA, _pB1); + pA += 4; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + outptr += 8; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_s16(pA); + int16x4_t _pB = vld1_dup_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, _pB); + pA += 4; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + outptr += 4; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + int32x4x2_t _s23 = vld2q_s32(outptr + 8); + int32x4x2_t _s45 = vld2q_s32(outptr + 16); + _sum0 = _s01.val[0]; + _sum3 = _s01.val[1]; + _sum1 = _s23.val[0]; + _sum4 = _s23.val[1]; + _sum2 = _s45.val[0]; + _sum5 = _s45.val[1]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA0 = vdup_n_s16(pA[0]); + int16x4_t _pA1 = vdup_n_s16(pA[1]); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_s16(_sum0, _pA0, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA0, vget_high_s16(_pB)); + _sum2 = vmlal_s16(_sum2, _pA0, _pB2); + _sum3 = 
vmlal_s16(_sum3, _pA1, vget_low_s16(_pB)); + _sum4 = vmlal_s16(_sum4, _pA1, vget_high_s16(_pB)); + _sum5 = vmlal_s16(_sum5, _pA1, _pB2); + pA += 2; + pB += 12; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum3; + int32x4x2_t _s23; + _s23.val[0] = _sum1; + _s23.val[1] = _sum4; + int32x4x2_t _s45; + _s45.val[0] = _sum2; + _s45.val[1] = _sum5; + vst2q_s32(outptr, _s01); + vst2q_s32(outptr + 8, _s23); + vst2q_s32(outptr + 16, _s45); + outptr += 24; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + int32x4x2_t _s23 = vld2q_s32(outptr + 8); + _sum0 = _s01.val[0]; + _sum2 = _s01.val[1]; + _sum1 = _s23.val[0]; + _sum3 = _s23.val[1]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA0 = vdup_n_s16(pA[0]); + int16x4_t _pA1 = vdup_n_s16(pA[1]); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA0, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA0, vget_high_s16(_pB)); + _sum2 = vmlal_s16(_sum2, _pA1, vget_low_s16(_pB)); + _sum3 = vmlal_s16(_sum3, _pA1, vget_high_s16(_pB)); + pA += 2; + pB += 8; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum2; + int32x4x2_t _s23; + _s23.val[0] = _sum1; + _s23.val[1] = _sum3; + vst2q_s32(outptr, _s01); + vst2q_s32(outptr + 8, _s23); + outptr += 16; + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + _sum0 = _s01.val[0]; + _sum1 = _s01.val[1]; + _sum2 = vld1q_s32(outptr + 8); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vreinterpret_s16_s32(vld1_dup_s32((const int*)pA)); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vzip_s16(vget_high_s16(_pB), vget_high_s16(_pB)).val[0]; + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_pB), _pA, 0); + _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_pB), _pA, 1); + _sum2 = vmlal_s16(_sum2, _pA, _pB2); + pA += 2; + pB += 6; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum1; + vst2q_s32(outptr, _s01); + vst1q_s32(outptr + 8, _sum2); + outptr += 12; + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + int32x4x2_t _s01 = vld2q_s32(outptr); + _sum0 = _s01.val[0]; + _sum1 = _s01.val[1]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA0 = vdup_n_s16(pA[0]); + int16x4_t _pA1 = vdup_n_s16(pA[1]); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA0, _pB); + _sum1 = vmlal_s16(_sum1, _pA1, _pB); + pA += 2; + pB += 4; + } + + int32x4x2_t _s01; + _s01.val[0] = _sum0; + _s01.val[1] = _sum1; + vst2q_s32(outptr, _s01); + outptr += 8; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum00 = 0; + int sum01 = 0; + int sum10 = 0; + int sum11 = 0; + + if (k == 0) + { + sum00 = 0; + sum01 = 0; + sum10 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + 
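                // __ARM_FEATURE_SIMD32 path: each SMLAD performs two signed 16x16 multiplies and accumulates both products into a 32-bit sum, so one iteration consumes two kk of the 2x2 output block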
for (; kk + 1 < max_kk; kk += 2) + { + // fomit-frame-pointer implied in optimized flag spare one register + // let us stay away from error: ‘asm’ operand has impossible constraints --- nihui +#if __OPTIMIZE__ + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA0 = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%0], #4 \n" // int16x2_t _pA1 = *((int16x2_t*)pA); pA += 2; + "ldr r4, [%1], #4 \n" // int16x2_t _pB0 = *((int16x2_t*)pB); pB += 2; + "ldr r5, [%1], #4 \n" // int16x2_t _pB1 = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_pA0, _pB0, sum00); + "smlad %3, r3, r4, %3 \n" // sum01 = __smlad(_pA1, _pB0, sum01); + "smlad %4, r2, r5, %4 \n" // sum10 = __smlad(_pA0, _pB1, sum10); + "smlad %5, r3, r5, %5 \n" // sum11 = __smlad(_pA1, _pB1, sum11); + : "=r"(pA), + "=r"(pB), + "=r"(sum00), + "=r"(sum01), + "=r"(sum10), + "=r"(sum11) + : "0"(pA), + "1"(pB), + "2"(sum00), + "3"(sum01), + "4"(sum10), + "5"(sum11) + : "memory", "r2", "r3", "r4", "r5"); +#else + int _pA0 = *((int*)pA); + int _pA1 = *((int*)(pA + 2)); + int _pB0 = *((int*)pB); + int _pB1 = *((int*)(pB + 2)); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum00) + : "0"(sum00), "r"(_pA0), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum01) + : "0"(sum01), "r"(_pA1), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum10) + : "0"(sum10), "r"(_pA0), "r"(_pB1) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum11) + : "0"(sum11), "r"(_pA1), "r"(_pB1) + :); + pA += 4; + pB += 4; +#endif + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + outptr += 2 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA0 = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%0], #4 \n" // int16x2_t _pA1 = *((int16x2_t*)pA); pA += 2; + "ldr r4, [%1], #4 \n" // int16x2_t _pB = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r4, %2 \n" // sum0 = __smlad(_pA0, _pB, sum0); + "smlad %3, r3, r4, %3 \n" // sum1 = __smlad(_pA1, _pB, sum1); + : "=r"(pA), + "=r"(pB), + "=r"(sum0), + "=r"(sum1) + : "0"(pA), + "1"(pB), + "2"(sum0), + "3"(sum1) + : "memory", "r2", "r3", "r4"); + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + } + } + for (; ii < max_ii; ii++) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 11 < max_jj; jj += 12) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = 
vld1_dup_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + int16x4_t _pB2 = vld1_s16(pB + 8); + _sum0 = vmlal_s16(_sum0, _pA, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA, vget_high_s16(_pB)); + _sum2 = vmlal_s16(_sum2, _pA, _pB2); + pA += 1; + pB += 12; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + outptr += 12; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_dup_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA, vget_high_s16(_pB)); + pA += 1; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + outptr += 8; + } +#endif // __aarch64__ + for (; jj + 5 < max_jj; jj += 6) + { + const short* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_dup_s16(pA); + int16x8_t _pB = vld1q_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, vget_low_s16(_pB)); + _sum1 = vmlal_s16(_sum1, _pA, vget_high_s16(_pB)); + pA += 1; + pB += 6; + } + + vst1q_s32(outptr, _sum0); + vst1_s32(outptr + 4, vget_low_s32(_sum1)); + outptr += 6; + } + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + int16x4_t _pA = vld1_dup_s16(pA); + int16x4_t _pB = vld1_s16(pB); + _sum0 = vmlal_s16(_sum0, _pA, _pB); + pA += 1; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + outptr += 4; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%1], #4 \n" // int16x2_t _pB0 = *((int16x2_t*)pB); pB += 2; + "ldr r4, [%1], #4 \n" // int16x2_t _pB1 = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r3, %2 \n" // sum0 = __smlad(_pA, _pB0, sum0); + "smlad %3, r2, r4, %3 \n" // sum1 = __smlad(_pA, _pB1, sum1); + : "=r"(pA), + "=r"(pB), + "=r"(sum0), + "=r"(sum1) + : "0"(pA), + "1"(pB), + "2"(sum0), + "3"(sum1) + : "memory", "r2", "r3", "r4"); + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum = 0; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + int kk = 0; +#if !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + asm volatile( + "ldr r2, [%0], #4 \n" // int16x2_t _pA = *((int16x2_t*)pA); pA += 2; + "ldr r3, [%1], #4 \n" // int16x2_t _pB = *((int16x2_t*)pB); pB += 2; + "smlad %2, r2, r3, %2 \n" // sum = __smlad(_pA, _pB, sum); + : 
"=r"(pA), + "=r"(pB), + "=r"(sum) + : "0"(pA), + "1"(pB), + "2"(sum) + : "memory", "r2", "r3"); + } +#endif // !__ARM_NEON && __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk++) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + outptr += 1; + } + } + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const int l2_cache_size_int8 = (int)(get_cpu_level2_cache_size() / sizeof(short)); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + // we shall take B into account for batched gemm, but that will be slower on arm in practice, why ? + // (void)B; + + // solve K + { + // try not to split K +#if __aarch64__ + int tile_size = (l2_cache_size_int8 - 32) / 12; +#elif __ARM_NEON + int tile_size = (l2_cache_size_int8 - 32) / 6; +#else + int tile_size = (l2_cache_size_int8 - 2) / 3; +#endif + +#if __aarch64__ + TILE_K = std::max(8, tile_size / 8 * 8); +#elif __ARM_NEON + TILE_K = std::max(4, tile_size / 4 * 4); +#else + TILE_K = std::max(2, tile_size / 2 * 2); +#endif + + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __aarch64__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __ARM_NEON + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + } + + // solve M + { +#if __ARM_NEON + TILE_M = 8; +#else + TILE_M = 2; +#endif + } + + { + TILE_M *= std::min(nT, get_physical_cpu_count()); + + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __ARM_NEON + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + + if (nT > 1) + { +#if __ARM_NEON + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + +#if __ARM_NEON + TILE_M = std::max(8, TILE_M); +#else + TILE_M = std::max(2, TILE_M); +#endif + } + + if (N > 0) + { + int tile_size; + if (TILE_K >= K) + { + tile_size = (l2_cache_size_int8 - TILE_M * TILE_K) / TILE_K; + } + else + { + tile_size = (l2_cache_size_int8 - TILE_M * TILE_K) / (TILE_M * 2 + TILE_K); + } + +#if __aarch64__ + TILE_N = std::max(4, tile_size / 4 * 4); +#elif __ARM_NEON + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_N = std::max(1, tile_size); +#endif + + int nn_N = (N + TILE_N - 1) / TILE_N; + +#if __aarch64__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#elif __ARM_NEON + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + +#if __aarch64__ + TILE_N = std::max(4, TILE_N); +#elif __ARM_NEON + TILE_N = std::max(4, TILE_N); +#else + TILE_N = std::max(1, TILE_N); +#endif + } +} + +static inline void conv3x3s1_winograd23_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const signed char ktm[4][3] = { + // {2, 0, 0}, + // {1, 1, 1}, + // {1, -1, 1}, + // {0, 0, 2} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[4][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 2; + tmp[1][m] = r0 + r1 + r2; + 
tmp[2][m] = r0 - r1 + r2; + tmp[3][m] = r2 * 2; + + k0 += 3; + } + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 2; + short z1 = r0 + r1 + r2; + short z2 = r0 - r1 + r2; + short z3 = r2 * 2; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp += 4; + } + } + } +} + +static void conv3x3s1_winograd23_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ + const int M = outch; + const int K = inch; + const int B = 16; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 2u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd23_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd23_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const signed char itm[4][4] = { + // {1, 0, -1, 0}, + // {0, 1, 1, 0}, + // {0, -1, 1, 0}, + // {0, -1, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w - 1) / 2; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __ARM_NEON + nn_max_kk = max_kk / 8; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + + short tmp[4][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + int8x8_t _r0 = vdup_n_s8(0); + int8x8_t _r1 = vdup_n_s8(0); + int8x8_t _r2 = vdup_n_s8(0); + int8x8_t _r3 = vdup_n_s8(0); + + if (ti * 2 + m < h) + { + if (elempack == 8) + { + _r0 = vld1_s8(r0); + if (tj * 2 + 1 < w) _r1 = vld1_s8(r0 + 8); + if (tj * 2 + 2 < w) _r2 = vld1_s8(r0 + 16); + if (tj * 2 + 3 < w) _r3 = vld1_s8(r0 + 24); + } + if (elempack == 1) + { + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + int8x8_t _t0 = vld1_s8(r0); + int8x8_t _t1 = vld1_s8(r1); + int8x8_t _t2 = vld1_s8(r2); + int8x8_t _t3 = vld1_s8(r3); + int8x8_t _t4 = vld1_s8(r4); + int8x8_t _t5 = vld1_s8(r5); + int8x8_t _t6 = vld1_s8(r6); + int8x8_t _t7 = vld1_s8(r7); + + int8x8_t _t01 = vzip_s8(_t0, _t1).val[0]; + int8x8_t _t23 = vzip_s8(_t2, _t3).val[0]; + int8x8_t _t45 = vzip_s8(_t4, _t5).val[0]; + int8x8_t _t67 = vzip_s8(_t6, _t7).val[0]; + int16x4x2_t _t0123 = vzip_s16(vreinterpret_s16_s8(_t01), 
vreinterpret_s16_s8(_t23)); + int16x4x2_t _t4567 = vzip_s16(vreinterpret_s16_s8(_t45), vreinterpret_s16_s8(_t67)); + int16x8_t _ta = vcombine_s16(_t0123.val[0], _t0123.val[1]); + int16x8_t _tb = vcombine_s16(_t4567.val[0], _t4567.val[1]); + int32x4x2_t _tab = vzipq_s32(vreinterpretq_s32_s16(_ta), vreinterpretq_s32_s16(_tb)); + + _r0 = vreinterpret_s8_s32(vget_low_s32(_tab.val[0])); + if (tj * 2 + 1 < w) _r1 = vreinterpret_s8_s32(vget_high_s32(_tab.val[0])); + if (tj * 2 + 2 < w) _r2 = vreinterpret_s8_s32(vget_low_s32(_tab.val[1])); + if (tj * 2 + 3 < w) _r3 = vreinterpret_s8_s32(vget_high_s32(_tab.val[1])); + } + } + + int16x8_t _tmp0 = vsubl_s8(_r0, _r2); + int16x8_t _tmp1 = vaddl_s8(_r1, _r2); + int16x8_t _tmp2 = vsubl_s8(_r2, _r1); + int16x8_t _tmp3 = vsubl_s8(_r3, _r1); + + vst1q_s16(tmp[0][m], _tmp0); + vst1q_s16(tmp[1][m], _tmp1); + vst1q_s16(tmp[2][m], _tmp2); + vst1q_s16(tmp[3][m], _tmp3); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { + int16x8_t _r0 = vld1q_s16(tmp[m][0]); + int16x8_t _r1 = vld1q_s16(tmp[m][1]); + int16x8_t _r2 = vld1q_s16(tmp[m][2]); + int16x8_t _r3 = vld1q_s16(tmp[m][3]); + + int16x8_t _tmp0 = vsubq_s16(_r0, _r2); + int16x8_t _tmp1 = vaddq_s16(_r1, _r2); + int16x8_t _tmp2 = vsubq_s16(_r2, _r1); + int16x8_t _tmp3 = vsubq_s16(_r3, _r1); + + vst1q_s16(p0, _tmp0); + vst1q_s16(p1, _tmp1); + vst1q_s16(p2, _tmp2); + vst1q_s16(p3, _tmp3); + + p0 += max_jj * 4 * 8; + p1 += max_jj * 4 * 8; + p2 += max_jj * 4 * 8; + p3 += max_jj * 4 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __ARM_NEON + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __ARM_NEON + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[4][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 2 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 2 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 2 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + } + } + + tmp[0][m][0] = r00 - r20; + tmp[0][m][1] = r01 - r21; + tmp[1][m][0] = r10 + r20; + tmp[1][m][1] = r11 + r21; + tmp[2][m][0] = r20 - r10; + tmp[2][m][1] = r21 - r11; + tmp[3][m][0] = r30 - r10; + tmp[3][m][1] = r31 - r11; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + + p0[0] = r00 - r20; + p0[1] = r01 - r21; + p1[0] = r10 + r20; + p1[1] = r11 + r21; + p2[0] = r20 - r10; + p2[1] = r21 - r11; + p3[0] = r30 - r10; + p3[1] = 
r31 - r11; + + p0 += max_jj * 4 * 2; + p1 += max_jj * 4 * 2; + p2 += max_jj * 4 * 2; + p3 += max_jj * 4 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 2 + 1 < w) r1 = r0123[1]; + if (tj * 2 + 2 < w) r2 = r0123[2]; + if (tj * 2 + 3 < w) r3 = r0123[3]; + } + } + + tmp[0][m] = r0 - r2; + tmp[1][m] = r1 + r2; + tmp[2][m] = r2 - r1; + tmp[3][m] = r3 - r1; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + + p0[0] = r0 - r2; + p1[0] = r1 + r2; + p2[0] = r2 - r1; + p3[0] = r3 - r1; + + p0 += max_jj * 4; + p1 += max_jj * 4; + p2 += max_jj * 4; + p3 += max_jj * 4; + } + } + } +} + +static inline void conv3x3s1_winograd23_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[2][4] = { + // {1, 1, 1, 0}, + // {0, 1, -1, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 1) / 2; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + int tmp[2][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { + int32x4_t _r00 = vld1q_s32(r0); + int32x4_t _r01 = vld1q_s32(r0 + 4); + int32x4_t _r10 = vld1q_s32(r1); + int32x4_t _r11 = vld1q_s32(r1 + 4); + int32x4_t _r20 = vld1q_s32(r2); + int32x4_t _r21 = vld1q_s32(r2 + 4); + int32x4_t _r30 = vld1q_s32(r3); + int32x4_t _r31 = vld1q_s32(r3 + 4); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_r00, _r10), _r20); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_r01, _r11), _r21); + int32x4_t _tmp10 = vaddq_s32(vsubq_s32(_r10, _r20), _r30); + int32x4_t _tmp11 = vaddq_s32(vsubq_s32(_r11, _r21), _r31); + + vst1q_s32(tmp[0][m], _tmp00); + vst1q_s32(tmp[0][m] + 4, _tmp01); + vst1q_s32(tmp[1][m], _tmp10); + vst1q_s32(tmp[1][m] + 4, _tmp11); + + r0 += max_jj * 4 * 8; + r1 += max_jj * 4 * 8; + r2 += max_jj * 4 * 8; + r3 += max_jj * 4 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int32x4_t _r00 = vld1q_s32(tmp[m][0]); + int32x4_t _r01 = vld1q_s32(tmp[m][0] + 4); + int32x4_t _r10 = vld1q_s32(tmp[m][1]); + int32x4_t _r11 = vld1q_s32(tmp[m][1] + 4); + int32x4_t _r20 = vld1q_s32(tmp[m][2]); + int32x4_t _r21 = vld1q_s32(tmp[m][2] + 4); + int32x4_t _r30 = vld1q_s32(tmp[m][3]); + int32x4_t _r31 = vld1q_s32(tmp[m][3] + 4); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_r00, _r10), _r20); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_r01, 
_r11), _r21); + int32x4_t _tmp10 = vaddq_s32(vsubq_s32(_r10, _r20), _r30); + int32x4_t _tmp11 = vaddq_s32(vsubq_s32(_r11, _r21), _r31); + + _tmp00 = vshrq_n_s32(_tmp00, 2); + _tmp01 = vshrq_n_s32(_tmp01, 2); + _tmp10 = vshrq_n_s32(_tmp10, 2); + _tmp11 = vshrq_n_s32(_tmp11, 2); + + if (out_elempack == 8) + { + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr0 + 4, _tmp01); + if (tj * 2 + 1 < outw) + { + vst1q_s32(outptr0 + 8, _tmp10); + vst1q_s32(outptr0 + 12, _tmp11); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr1, _tmp01); + if (tj * 2 + 1 < outw) + { + vst1q_s32(outptr0 + 4, _tmp10); + vst1q_s32(outptr1 + 4, _tmp11); + } + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = vgetq_lane_s32(_tmp00, 0); + outptr1[0] = vgetq_lane_s32(_tmp00, 1); + outptr2[0] = vgetq_lane_s32(_tmp00, 2); + outptr3[0] = vgetq_lane_s32(_tmp00, 3); + outptr4[0] = vgetq_lane_s32(_tmp01, 0); + outptr5[0] = vgetq_lane_s32(_tmp01, 1); + outptr6[0] = vgetq_lane_s32(_tmp01, 2); + outptr7[0] = vgetq_lane_s32(_tmp01, 3); + + if (tj * 2 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp10, 0); + outptr1[1] = vgetq_lane_s32(_tmp10, 1); + outptr2[1] = vgetq_lane_s32(_tmp10, 2); + outptr3[1] = vgetq_lane_s32(_tmp10, 3); + outptr4[1] = vgetq_lane_s32(_tmp11, 0); + outptr5[1] = vgetq_lane_s32(_tmp11, 1); + outptr6[1] = vgetq_lane_s32(_tmp11, 2); + outptr7[1] = vgetq_lane_s32(_tmp11, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + int tmp[2][4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + + for (int m = 0; m < 4; m++) + { + int32x4_t _r0 = vld1q_s32(r0); + int32x4_t _r1 = vld1q_s32(r1); + int32x4_t _r2 = vld1q_s32(r2); + int32x4_t _r3 = vld1q_s32(r3); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_r0, _r1), _r2); + int32x4_t _tmp1 = vaddq_s32(vsubq_s32(_r1, _r2), _r3); + + vst1q_s32(tmp[0][m], _tmp0); + vst1q_s32(tmp[1][m], _tmp1); + + r0 += max_jj * 4 * 4; + r1 += max_jj * 4 * 4; + r2 += max_jj * 4 * 4; + r3 += max_jj * 4 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int32x4_t _r0 = vld1q_s32(tmp[m][0]); + int32x4_t _r1 = vld1q_s32(tmp[m][1]); + int32x4_t _r2 = vld1q_s32(tmp[m][2]); + int32x4_t _r3 = vld1q_s32(tmp[m][3]); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_r0, _r1), _r2); + int32x4_t _tmp1 = vaddq_s32(vsubq_s32(_r1, _r2), _r3); + + _tmp0 = vshrq_n_s32(_tmp0, 2); + _tmp1 = vshrq_n_s32(_tmp1, 2); + + if (out_elempack == 4) + { + vst1q_s32(outptr0, _tmp0); + if (tj * 2 + 1 < outw) vst1q_s32(outptr0 + 4, _tmp1); + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = vgetq_lane_s32(_tmp0, 0); + outptr1[0] = vgetq_lane_s32(_tmp0, 1); + outptr2[0] = vgetq_lane_s32(_tmp0, 2); + outptr3[0] = vgetq_lane_s32(_tmp0, 3); + + if (tj * 2 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp1, 0); + outptr1[1] = 
vgetq_lane_s32(_tmp1, 1); + outptr2[1] = vgetq_lane_s32(_tmp1, 2); + outptr3[1] = vgetq_lane_s32(_tmp1, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[2][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m][0] = r0[0] + r1[0] + r2[0]; + tmp[0][m][1] = r0[1] + r1[1] + r2[1]; + tmp[1][m][0] = r1[0] - r2[0] + r3[0]; + tmp[1][m][1] = r1[1] - r2[1] + r3[1]; + + r0 += max_jj * 4 * 2; + r1 += max_jj * 4 * 2; + r2 += max_jj * 4 * 2; + r3 += max_jj * 4 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp00 = tmp[m][0][0] + tmp[m][1][0] + tmp[m][2][0]; + int tmp01 = tmp[m][0][1] + tmp[m][1][1] + tmp[m][2][1]; + int tmp10 = tmp[m][1][0] - tmp[m][2][0] + tmp[m][3][0]; + int tmp11 = tmp[m][1][1] - tmp[m][2][1] + tmp[m][3][1]; + + tmp00 = tmp00 >> 2; + tmp01 = tmp01 >> 2; + tmp10 = tmp10 >> 2; + tmp11 = tmp11 >> 2; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[2][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m] = r0[0] + r1[0] + r2[0]; + tmp[1][m] = r1[0] - r2[0] + r3[0]; + + r0 += max_jj * 4; + r1 += max_jj * 4; + r2 += max_jj * 4; + r3 += max_jj * 4; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp0 = tmp[m][0] + tmp[m][1] + tmp[m][2]; + int tmp1 = tmp[m][1] - tmp[m][2] + tmp[m][3]; + + tmp0 = tmp0 >> 2; + tmp1 = tmp1 >> 2; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 2 + 1 < outw) outptr0[1] = tmp1; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd23_int8(Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 2n+2, winograd F(2,3) + int w_tiles = (outw + 1) / 2; + int h_tiles = (outh + 1) / 2; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 16; + + // NCNN_LOGE("conv3x3s1_winograd23_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 2u, 
opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 2u, opt.workspace_allocator); + + // #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + bottom_blob.release(); + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk); + } + + // transform output + conv3x3s1_winograd23_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} + +static inline void conv3x3s1_winograd43_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const short ktm[6][3] = { + // {6, 0, 0}, + // {-4, -4, -4}, + // {-4, 4, -4}, + // {1, 2, 4}, + // {1, -2, 4}, + // {0, 0, 6} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[6][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 6; + tmp[1][m] = -r0 * 4 - r1 * 4 - r2 * 4; + tmp[2][m] = -r0 * 4 + r1 * 4 - r2 * 4; + tmp[3][m] = r0 + r1 * 2 + r2 * 4; + tmp[4][m] = r0 - r1 * 2 + r2 * 4; + tmp[5][m] = r2 * 6; + + k0 += 3; + } + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 6; + short z1 = -r0 * 4 - r1 * 4 - r2 * 4; + short z2 = -r0 * 4 + r1 * 4 - r2 * 4; + short z3 = r0 + r1 * 2 + r2 * 4; + short z4 = r0 - r1 * 2 + r2 * 4; + short z5 = r2 * 6; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp[4] = z4; + ptmp[5] = z5; + ptmp += 6; + } + } + } +} + +static void conv3x3s1_winograd43_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ + const 
int M = outch; + const int K = inch; + const int B = 36; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 2u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd43_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd43_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const float itm[4][4] = { + // {4, 0, -5, 0, 1, 0}, + // {0, -4, -4, 1, 1, 0}, + // {0, 4, -4, -1, 1, 0}, + // {0, -2, -1, 2, 1, 0}, + // {0, 2, -1, -2, 1, 0}, + // {0, 4, 0, -5, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w + 1) / 4; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __ARM_NEON + nn_max_kk = max_kk / 8; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + + short tmp[6][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + int8x8_t _v5 = vdup_n_s8(5); + + for (int m = 0; m < 6; m++) + { + int8x8_t _r0 = vdup_n_s8(0); + int8x8_t _r1 = vdup_n_s8(0); + int8x8_t _r2 = vdup_n_s8(0); + int8x8_t _r3 = vdup_n_s8(0); + int8x8_t _r4 = vdup_n_s8(0); + int8x8_t _r5 = vdup_n_s8(0); + + if (ti * 4 + m < h) + { + if (elempack == 8) + { + _r0 = vld1_s8(r0); + if (tj * 4 + 1 < w) _r1 = vld1_s8(r0 + 8); + if (tj * 4 + 2 < w) _r2 = vld1_s8(r0 + 16); + if (tj * 4 + 3 < w) _r3 = vld1_s8(r0 + 24); + if (tj * 4 + 4 < w) _r4 = vld1_s8(r0 + 32); + if (tj * 4 + 5 < w) _r5 = vld1_s8(r0 + 40); + } + if (elempack == 1) + { + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + int8x8_t _t0 = vld1_s8(r0); + int8x8_t _t1 = vld1_s8(r1); + int8x8_t _t2 = vld1_s8(r2); + int8x8_t _t3 = vld1_s8(r3); + int8x8_t _t4 = vld1_s8(r4); + int8x8_t _t5 = vld1_s8(r5); + int8x8_t _t6 = vld1_s8(r6); + int8x8_t _t7 = vld1_s8(r7); + + int8x8_t _t01 = vzip_s8(_t0, _t1).val[0]; + int8x8_t _t23 = vzip_s8(_t2, _t3).val[0]; + int8x8_t _t45 = vzip_s8(_t4, _t5).val[0]; + int8x8_t _t67 = vzip_s8(_t6, _t7).val[0]; + int16x4x2_t _t0123 = vzip_s16(vreinterpret_s16_s8(_t01), vreinterpret_s16_s8(_t23)); + int16x4x2_t _t4567 = vzip_s16(vreinterpret_s16_s8(_t45), vreinterpret_s16_s8(_t67)); + int16x8_t _ta = vcombine_s16(_t0123.val[0], _t0123.val[1]); + int16x8_t _tb = vcombine_s16(_t4567.val[0], _t4567.val[1]); + int32x4x2_t 
_tab = vzipq_s32(vreinterpretq_s32_s16(_ta), vreinterpretq_s32_s16(_tb)); + + _r0 = vreinterpret_s8_s32(vget_low_s32(_tab.val[0])); + if (tj * 4 + 1 < w) _r1 = vreinterpret_s8_s32(vget_high_s32(_tab.val[0])); + if (tj * 4 + 2 < w) _r2 = vreinterpret_s8_s32(vget_low_s32(_tab.val[1])); + if (tj * 4 + 3 < w) _r3 = vreinterpret_s8_s32(vget_high_s32(_tab.val[1])); + if (tj * 4 + 4 < w) + { + _t01 = vzip_s8(_t0, _t1).val[1]; + _t23 = vzip_s8(_t2, _t3).val[1]; + _t45 = vzip_s8(_t4, _t5).val[1]; + _t67 = vzip_s8(_t6, _t7).val[1]; + int16x4_t _tc = vzip_s16(vreinterpret_s16_s8(_t01), vreinterpret_s16_s8(_t23)).val[0]; + int16x4_t _td = vzip_s16(vreinterpret_s16_s8(_t45), vreinterpret_s16_s8(_t67)).val[0]; + int32x2x2_t _tcd = vzip_s32(vreinterpret_s32_s16(_tc), vreinterpret_s32_s16(_td)); + + _r4 = vreinterpret_s8_s32(_tcd.val[0]); + if (tj * 4 + 5 < w) _r5 = vreinterpret_s8_s32(_tcd.val[1]); + } + } + } + + int16x8_t _tmp12a = vsubw_s8(vshll_n_s8(_r1, 2), _r3); + int16x8_t _tmp12b = vsubw_s8(vshll_n_s8(_r2, 2), _r4); + int16x8_t _tmp34a = vshlq_n_s16(vsubl_s8(_r3, _r1), 1); + int16x8_t _tmp34b = vsubl_s8(_r4, _r2); + + int16x8_t _tmp0 = vaddq_s16(vmovl_s8(_r4), vsubq_s16(vshll_n_s8(_r0, 2), vmull_s8(_r2, _v5))); + int16x8_t _tmp1 = vnegq_s16(vaddq_s16(_tmp12a, _tmp12b)); + int16x8_t _tmp2 = vsubq_s16(_tmp12a, _tmp12b); + int16x8_t _tmp3 = vaddq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp4 = vsubq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp5 = vaddq_s16(vmovl_s8(_r5), vsubq_s16(vshll_n_s8(_r1, 2), vmull_s8(_r3, _v5))); + + vst1q_s16(tmp[0][m], _tmp0); + vst1q_s16(tmp[1][m], _tmp1); + vst1q_s16(tmp[2][m], _tmp2); + vst1q_s16(tmp[3][m], _tmp3); + vst1q_s16(tmp[4][m], _tmp4); + vst1q_s16(tmp[5][m], _tmp5); + + r0 += w * elempack; + } + + int16x8_t _v5q = vdupq_n_s16(5); + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + short* p4 = p0 + max_jj * 8 * 4; + short* p5 = p0 + max_jj * 8 * 5; + + for (int m = 0; m < 6; m++) + { + int16x8_t _r0 = vld1q_s16(tmp[m][0]); + int16x8_t _r1 = vld1q_s16(tmp[m][1]); + int16x8_t _r2 = vld1q_s16(tmp[m][2]); + int16x8_t _r3 = vld1q_s16(tmp[m][3]); + int16x8_t _r4 = vld1q_s16(tmp[m][4]); + int16x8_t _r5 = vld1q_s16(tmp[m][5]); + + int16x8_t _tmp12a = vsubq_s16(_r3, vshlq_n_s16(_r1, 2)); + int16x8_t _tmp12b = vsubq_s16(_r4, vshlq_n_s16(_r2, 2)); + int16x8_t _tmp34a = vshlq_n_s16(vsubq_s16(_r3, _r1), 1); + int16x8_t _tmp34b = vsubq_s16(_r4, _r2); + + int16x8_t _tmp0 = vaddq_s16(_r4, vsubq_s16(vshlq_n_s16(_r0, 2), vmulq_s16(_r2, _v5q))); + int16x8_t _tmp1 = vaddq_s16(_tmp12b, _tmp12a); + int16x8_t _tmp2 = vsubq_s16(_tmp12b, _tmp12a); + int16x8_t _tmp3 = vaddq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp4 = vsubq_s16(_tmp34b, _tmp34a); + int16x8_t _tmp5 = vaddq_s16(_r5, vsubq_s16(vshlq_n_s16(_r1, 2), vmulq_s16(_r3, _v5q))); + + vst1q_s16(p0, _tmp0); + vst1q_s16(p1, _tmp1); + vst1q_s16(p2, _tmp2); + vst1q_s16(p3, _tmp3); + vst1q_s16(p4, _tmp4); + vst1q_s16(p5, _tmp5); + + p0 += max_jj * 6 * 8; + p1 += max_jj * 6 * 8; + p2 += max_jj * 6 * 8; + p3 += max_jj * 6 * 8; + p4 += max_jj * 6 * 8; + p5 += max_jj * 6 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __ARM_NEON + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __ARM_NEON + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[6][6][2]; + + int jj = 0; + 
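// remaining input channels are handled two at a time here (also the main path on non-NEON builds): each iteration applies the two-pass B^T d B input transform to one 6x6 tile for a pair of channels, mirroring the packed loop +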
for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + signed char r40 = 0; + signed char r41 = 0; + signed char r50 = 0; + signed char r51 = 0; + + if (ti * 4 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 4 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 4 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 4 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + if (tj * 4 + 4 < w) + { + r40 = r0[4]; + r41 = r1[4]; + } + if (tj * 4 + 5 < w) + { + r50 = r0[5]; + r51 = r1[5]; + } + } + } + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + tmp[0][m][0] = r40 + r00 * 4 - r20 * 5; + tmp[0][m][1] = r41 + r01 * 4 - r21 * 5; + tmp[1][m][0] = tmp120b + tmp120a; + tmp[1][m][1] = tmp121b + tmp121a; + tmp[2][m][0] = tmp120b - tmp120a; + tmp[2][m][1] = tmp121b - tmp121a; + tmp[3][m][0] = tmp340b + tmp340a; + tmp[3][m][1] = tmp341b + tmp341a; + tmp[4][m][0] = tmp340b - tmp340a; + tmp[4][m][1] = tmp341b - tmp341a; + tmp[5][m][0] = r50 + r10 * 4 - r30 * 5; + tmp[5][m][1] = r51 + r11 * 4 - r31 * 5; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + short* p4 = p0 + max_jj * 2 * 4; + short* p5 = p0 + max_jj * 2 * 5; + + for (int m = 0; m < 6; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + short r40 = tmp[m][4][0]; + short r41 = tmp[m][4][1]; + short r50 = tmp[m][5][0]; + short r51 = tmp[m][5][1]; + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + p0[0] = r40 + r00 * 4 - r20 * 5; + p0[1] = r41 + r01 * 4 - r21 * 5; + p1[0] = tmp120b + tmp120a; + p1[1] = tmp121b + tmp121a; + p2[0] = tmp120b - tmp120a; + p2[1] = tmp121b - tmp121a; + p3[0] = tmp340b + tmp340a; + p3[1] = tmp341b + tmp341a; + p4[0] = tmp340b - tmp340a; + p4[1] = tmp341b - tmp341a; + p5[0] = r50 + r10 * 4 - r30 * 5; + p5[1] = r51 + r11 * 4 - r31 * 5; + + p0 += max_jj * 6 * 2; + p1 += max_jj * 6 * 2; + p2 += max_jj * 6 * 2; + p3 += max_jj * 6 * 2; + p4 += max_jj * 6 * 2; + p5 += max_jj * 6 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[6][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + signed char r4 = 0; + signed char r5 = 0; + + if (ti * 4 + m < h) + { + 
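// taps past the right border keep their zero initialization, so partial edge tiles are transformed with implicit zero padding +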
// if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 4 + 1 < w) r1 = r0123[1]; + if (tj * 4 + 2 < w) r2 = r0123[2]; + if (tj * 4 + 3 < w) r3 = r0123[3]; + if (tj * 4 + 4 < w) r4 = r0123[4]; + if (tj * 4 + 5 < w) r5 = r0123[5]; + } + } + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + tmp[0][m] = r4 + r0 * 4 - r2 * 5; + tmp[1][m] = tmp12b + tmp12a; + tmp[2][m] = tmp12b - tmp12a; + tmp[3][m] = tmp34b + tmp34a; + tmp[4][m] = tmp34b - tmp34a; + tmp[5][m] = r5 + r1 * 4 - r3 * 5; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + short* p4 = p0 + max_jj * 4; + short* p5 = p0 + max_jj * 5; + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + short r4 = tmp[m][4]; + short r5 = tmp[m][5]; + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + p0[0] = r4 + r0 * 4 - r2 * 5; + p1[0] = tmp12b + tmp12a; + p2[0] = tmp12b - tmp12a; + p3[0] = tmp34b + tmp34a; + p4[0] = tmp34b - tmp34a; + p5[0] = r5 + r1 * 4 - r3 * 5; + + p0 += max_jj * 6; + p1 += max_jj * 6; + p2 += max_jj * 6; + p3 += max_jj * 6; + p4 += max_jj * 6; + p5 += max_jj * 6; + } + } + } +} + +static inline void conv3x3s1_winograd43_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[4][6] = { + // {1, 1, 1, 1, 1, 0}, + // {0, 1, -1, 2, -2, 0}, + // {0, 1, 1, 4, 4, 0}, + // {0, 1, -1, 8, -8, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 3) / 4; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + int tmp[4][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + const int* r4 = r0 + max_jj * 8 * 4; + const int* r5 = r0 + max_jj * 8 * 5; + + for (int m = 0; m < 5; m++) + { + int32x4_t _r00 = vld1q_s32(r0); + int32x4_t _r01 = vld1q_s32(r0 + 4); + int32x4_t _r10 = vld1q_s32(r1); + int32x4_t _r11 = vld1q_s32(r1 + 4); + int32x4_t _r20 = vld1q_s32(r2); + int32x4_t _r21 = vld1q_s32(r2 + 4); + int32x4_t _r30 = vld1q_s32(r3); + int32x4_t _r31 = vld1q_s32(r3 + 4); + int32x4_t _r40 = vld1q_s32(r4); + int32x4_t _r41 = vld1q_s32(r4 + 4); + int32x4_t _r50 = vld1q_s32(r5); + int32x4_t _r51 = vld1q_s32(r5 + 4); + + int32x4_t _tmp02a0 = vaddq_s32(_r10, _r20); + int32x4_t _tmp02a1 = vaddq_s32(_r11, _r21); + int32x4_t _tmp02b0 = vaddq_s32(_r30, _r40); + int32x4_t _tmp02b1 = vaddq_s32(_r31, _r41); + int32x4_t _tmp13a0 = vsubq_s32(_r10, _r20); + int32x4_t _tmp13a1 = vsubq_s32(_r11, _r21); + int32x4_t _tmp13b0 = vsubq_s32(_r30, _r40); + int32x4_t _tmp13b1 = vsubq_s32(_r31, _r41); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_tmp02a0, _tmp02b0), _r00); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_tmp02a1, _tmp02b1), _r01); + int32x4_t _tmp10 = vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 1)); + int32x4_t _tmp11 = vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 1)); + int32x4_t _tmp20 = vaddq_s32(_tmp02a0, vshlq_n_s32(_tmp02b0, 2)); + int32x4_t _tmp21 = 
vaddq_s32(_tmp02a1, vshlq_n_s32(_tmp02b1, 2)); + int32x4_t _tmp30 = vaddq_s32(vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 3)), vshlq_n_s32(_r50, 2)); + int32x4_t _tmp31 = vaddq_s32(vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 3)), vshlq_n_s32(_r51, 2)); + + vst1q_s32(tmp[0][m], _tmp00); + vst1q_s32(tmp[0][m] + 4, _tmp01); + vst1q_s32(tmp[1][m], _tmp10); + vst1q_s32(tmp[1][m] + 4, _tmp11); + vst1q_s32(tmp[2][m], _tmp20); + vst1q_s32(tmp[2][m] + 4, _tmp21); + vst1q_s32(tmp[3][m], _tmp30); + vst1q_s32(tmp[3][m] + 4, _tmp31); + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + for (int m = 5; m < 6; m++) + { + int32x4_t _r00 = vld1q_s32(r0); + int32x4_t _r01 = vld1q_s32(r0 + 4); + int32x4_t _r10 = vld1q_s32(r1); + int32x4_t _r11 = vld1q_s32(r1 + 4); + int32x4_t _r20 = vld1q_s32(r2); + int32x4_t _r21 = vld1q_s32(r2 + 4); + int32x4_t _r30 = vld1q_s32(r3); + int32x4_t _r31 = vld1q_s32(r3 + 4); + int32x4_t _r40 = vld1q_s32(r4); + int32x4_t _r41 = vld1q_s32(r4 + 4); + int32x4_t _r50 = vld1q_s32(r5); + int32x4_t _r51 = vld1q_s32(r5 + 4); + + int32x4_t _tmp02a0 = vaddq_s32(_r10, _r20); + int32x4_t _tmp02a1 = vaddq_s32(_r11, _r21); + int32x4_t _tmp02b0 = vaddq_s32(_r30, _r40); + int32x4_t _tmp02b1 = vaddq_s32(_r31, _r41); + int32x4_t _tmp13a0 = vsubq_s32(_r10, _r20); + int32x4_t _tmp13a1 = vsubq_s32(_r11, _r21); + int32x4_t _tmp13b0 = vsubq_s32(_r30, _r40); + int32x4_t _tmp13b1 = vsubq_s32(_r31, _r41); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_tmp02a0, _tmp02b0), _r00); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_tmp02a1, _tmp02b1), _r01); + int32x4_t _tmp10 = vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 1)); + int32x4_t _tmp11 = vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 1)); + int32x4_t _tmp20 = vaddq_s32(_tmp02a0, vshlq_n_s32(_tmp02b0, 2)); + int32x4_t _tmp21 = vaddq_s32(_tmp02a1, vshlq_n_s32(_tmp02b1, 2)); + int32x4_t _tmp30 = vaddq_s32(vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 3)), vshlq_n_s32(_r50, 2)); + int32x4_t _tmp31 = vaddq_s32(vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 3)), vshlq_n_s32(_r51, 2)); + + _tmp00 = vshlq_n_s32(_tmp00, 2); + _tmp01 = vshlq_n_s32(_tmp01, 2); + _tmp10 = vshlq_n_s32(_tmp10, 2); + _tmp11 = vshlq_n_s32(_tmp11, 2); + _tmp20 = vshlq_n_s32(_tmp20, 2); + _tmp21 = vshlq_n_s32(_tmp21, 2); + _tmp30 = vshlq_n_s32(_tmp30, 2); + _tmp31 = vshlq_n_s32(_tmp31, 2); + + vst1q_s32(tmp[0][m], _tmp00); + vst1q_s32(tmp[0][m] + 4, _tmp01); + vst1q_s32(tmp[1][m], _tmp10); + vst1q_s32(tmp[1][m] + 4, _tmp11); + vst1q_s32(tmp[2][m], _tmp20); + vst1q_s32(tmp[2][m] + 4, _tmp21); + vst1q_s32(tmp[3][m], _tmp30); + vst1q_s32(tmp[3][m] + 4, _tmp31); + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int32x4_t _r00 = vld1q_s32(tmp[m][0]); + int32x4_t _r01 = vld1q_s32(tmp[m][0] + 4); + int32x4_t _r10 = vld1q_s32(tmp[m][1]); + int32x4_t _r11 = vld1q_s32(tmp[m][1] + 4); + int32x4_t _r20 = vld1q_s32(tmp[m][2]); + int32x4_t _r21 = vld1q_s32(tmp[m][2] + 4); + int32x4_t _r30 = vld1q_s32(tmp[m][3]); + int32x4_t _r31 = vld1q_s32(tmp[m][3] + 4); + int32x4_t _r40 = vld1q_s32(tmp[m][4]); + int32x4_t _r41 = vld1q_s32(tmp[m][4] + 4); + int32x4_t _r50 = vld1q_s32(tmp[m][5]); + int32x4_t _r51 = vld1q_s32(tmp[m][5] + 4); + + int32x4_t _tmp02a0 = 
vaddq_s32(_r10, _r20); + int32x4_t _tmp02a1 = vaddq_s32(_r11, _r21); + int32x4_t _tmp02b0 = vaddq_s32(_r30, _r40); + int32x4_t _tmp02b1 = vaddq_s32(_r31, _r41); + int32x4_t _tmp13a0 = vsubq_s32(_r10, _r20); + int32x4_t _tmp13a1 = vsubq_s32(_r11, _r21); + int32x4_t _tmp13b0 = vsubq_s32(_r30, _r40); + int32x4_t _tmp13b1 = vsubq_s32(_r31, _r41); + + int32x4_t _tmp00 = vaddq_s32(vaddq_s32(_tmp02a0, _tmp02b0), _r00); + int32x4_t _tmp01 = vaddq_s32(vaddq_s32(_tmp02a1, _tmp02b1), _r01); + int32x4_t _tmp10 = vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 1)); + int32x4_t _tmp11 = vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 1)); + int32x4_t _tmp20 = vaddq_s32(_tmp02a0, vshlq_n_s32(_tmp02b0, 2)); + int32x4_t _tmp21 = vaddq_s32(_tmp02a1, vshlq_n_s32(_tmp02b1, 2)); + int32x4_t _tmp30 = vaddq_s32(vaddq_s32(_tmp13a0, vshlq_n_s32(_tmp13b0, 3)), _r50); + int32x4_t _tmp31 = vaddq_s32(vaddq_s32(_tmp13a1, vshlq_n_s32(_tmp13b1, 3)), _r51); + + // TODO use integer trick for division by 576 + float32x4_t _v576 = vdupq_n_f32(1.0 / 576); + _tmp00 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp00), _v576)); + _tmp01 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp01), _v576)); + _tmp10 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp10), _v576)); + _tmp11 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp11), _v576)); + _tmp20 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp20), _v576)); + _tmp21 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp21), _v576)); + _tmp30 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp30), _v576)); + _tmp31 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp31), _v576)); + + if (out_elempack == 8) + { + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr0 + 4, _tmp01); + if (tj * 4 + 1 < outw) + { + vst1q_s32(outptr0 + 8, _tmp10); + vst1q_s32(outptr0 + 12, _tmp11); + } + if (tj * 4 + 2 < outw) + { + vst1q_s32(outptr0 + 16, _tmp20); + vst1q_s32(outptr0 + 20, _tmp21); + } + if (tj * 4 + 3 < outw) + { + vst1q_s32(outptr0 + 24, _tmp30); + vst1q_s32(outptr0 + 28, _tmp31); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + vst1q_s32(outptr0, _tmp00); + vst1q_s32(outptr1, _tmp01); + if (tj * 4 + 1 < outw) + { + vst1q_s32(outptr0 + 4, _tmp10); + vst1q_s32(outptr1 + 4, _tmp11); + } + if (tj * 4 + 2 < outw) + { + vst1q_s32(outptr0 + 8, _tmp20); + vst1q_s32(outptr1 + 8, _tmp21); + } + if (tj * 4 + 3 < outw) + { + vst1q_s32(outptr0 + 12, _tmp30); + vst1q_s32(outptr1 + 12, _tmp31); + } + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = vgetq_lane_s32(_tmp00, 0); + outptr1[0] = vgetq_lane_s32(_tmp00, 1); + outptr2[0] = vgetq_lane_s32(_tmp00, 2); + outptr3[0] = vgetq_lane_s32(_tmp00, 3); + outptr4[0] = vgetq_lane_s32(_tmp01, 0); + outptr5[0] = vgetq_lane_s32(_tmp01, 1); + outptr6[0] = vgetq_lane_s32(_tmp01, 2); + outptr7[0] = vgetq_lane_s32(_tmp01, 3); + if (tj * 4 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp10, 0); + outptr1[1] = vgetq_lane_s32(_tmp10, 1); + outptr2[1] = vgetq_lane_s32(_tmp10, 2); + outptr3[1] = vgetq_lane_s32(_tmp10, 3); + outptr4[1] = vgetq_lane_s32(_tmp11, 0); + outptr5[1] = vgetq_lane_s32(_tmp11, 1); + outptr6[1] = vgetq_lane_s32(_tmp11, 2); + outptr7[1] = vgetq_lane_s32(_tmp11, 3); + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = vgetq_lane_s32(_tmp20, 0); + outptr1[2] = vgetq_lane_s32(_tmp20, 1); + outptr2[2] = vgetq_lane_s32(_tmp20, 2); + outptr3[2] = 
vgetq_lane_s32(_tmp20, 3); + outptr4[2] = vgetq_lane_s32(_tmp21, 0); + outptr5[2] = vgetq_lane_s32(_tmp21, 1); + outptr6[2] = vgetq_lane_s32(_tmp21, 2); + outptr7[2] = vgetq_lane_s32(_tmp21, 3); + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = vgetq_lane_s32(_tmp30, 0); + outptr1[3] = vgetq_lane_s32(_tmp30, 1); + outptr2[3] = vgetq_lane_s32(_tmp30, 2); + outptr3[3] = vgetq_lane_s32(_tmp30, 3); + outptr4[3] = vgetq_lane_s32(_tmp31, 0); + outptr5[3] = vgetq_lane_s32(_tmp31, 1); + outptr6[3] = vgetq_lane_s32(_tmp31, 2); + outptr7[3] = vgetq_lane_s32(_tmp31, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + int tmp[4][6][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + const int* r4 = r0 + max_jj * 4 * 4; + const int* r5 = r0 + max_jj * 4 * 5; + + for (int m = 0; m < 5; m++) + { + int32x4_t _r0 = vld1q_s32(r0); + int32x4_t _r1 = vld1q_s32(r1); + int32x4_t _r2 = vld1q_s32(r2); + int32x4_t _r3 = vld1q_s32(r3); + int32x4_t _r4 = vld1q_s32(r4); + int32x4_t _r5 = vld1q_s32(r5); + + int32x4_t _tmp02a = vaddq_s32(_r1, _r2); + int32x4_t _tmp02b = vaddq_s32(_r3, _r4); + int32x4_t _tmp13a = vsubq_s32(_r1, _r2); + int32x4_t _tmp13b = vsubq_s32(_r3, _r4); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_tmp02a, _tmp02b), _r0); + int32x4_t _tmp1 = vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 1)); + int32x4_t _tmp2 = vaddq_s32(_tmp02a, vshlq_n_s32(_tmp02b, 2)); + int32x4_t _tmp3 = vaddq_s32(vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 3)), vshlq_n_s32(_r5, 2)); + + vst1q_s32(tmp[0][m], _tmp0); + vst1q_s32(tmp[1][m], _tmp1); + vst1q_s32(tmp[2][m], _tmp2); + vst1q_s32(tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + for (int m = 5; m < 6; m++) + { + int32x4_t _r0 = vld1q_s32(r0); + int32x4_t _r1 = vld1q_s32(r1); + int32x4_t _r2 = vld1q_s32(r2); + int32x4_t _r3 = vld1q_s32(r3); + int32x4_t _r4 = vld1q_s32(r4); + int32x4_t _r5 = vld1q_s32(r5); + + int32x4_t _tmp02a = vaddq_s32(_r1, _r2); + int32x4_t _tmp02b = vaddq_s32(_r3, _r4); + int32x4_t _tmp13a = vsubq_s32(_r1, _r2); + int32x4_t _tmp13b = vsubq_s32(_r3, _r4); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_tmp02a, _tmp02b), _r0); + int32x4_t _tmp1 = vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 1)); + int32x4_t _tmp2 = vaddq_s32(_tmp02a, vshlq_n_s32(_tmp02b, 2)); + int32x4_t _tmp3 = vaddq_s32(vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 3)), vshlq_n_s32(_r5, 2)); + + _tmp0 = vshlq_n_s32(_tmp0, 2); + _tmp1 = vshlq_n_s32(_tmp1, 2); + _tmp2 = vshlq_n_s32(_tmp2, 2); + _tmp3 = vshlq_n_s32(_tmp3, 2); + + vst1q_s32(tmp[0][m], _tmp0); + vst1q_s32(tmp[1][m], _tmp1); + vst1q_s32(tmp[2][m], _tmp2); + vst1q_s32(tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int32x4_t _r0 = vld1q_s32(tmp[m][0]); + int32x4_t _r1 = vld1q_s32(tmp[m][1]); + int32x4_t _r2 = vld1q_s32(tmp[m][2]); + int32x4_t _r3 = vld1q_s32(tmp[m][3]); + int32x4_t _r4 = vld1q_s32(tmp[m][4]); + int32x4_t _r5 = vld1q_s32(tmp[m][5]); 
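+ // second pass of the A^T M A output transform: tmp already holds the vertical reduction,
+ // so only the horizontal combination remains; the 1/576 rescale below cancels the scale
+ // factor that the integer winograd43 kernel transform bakes into the accumulators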
+ + int32x4_t _tmp02a = vaddq_s32(_r1, _r2); + int32x4_t _tmp02b = vaddq_s32(_r3, _r4); + int32x4_t _tmp13a = vsubq_s32(_r1, _r2); + int32x4_t _tmp13b = vsubq_s32(_r3, _r4); + + int32x4_t _tmp0 = vaddq_s32(vaddq_s32(_tmp02a, _tmp02b), _r0); + int32x4_t _tmp1 = vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 1)); + int32x4_t _tmp2 = vaddq_s32(_tmp02a, vshlq_n_s32(_tmp02b, 2)); + int32x4_t _tmp3 = vaddq_s32(vaddq_s32(_tmp13a, vshlq_n_s32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + float32x4_t _v576 = vdupq_n_f32(1.0 / 576); + _tmp0 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp0), _v576)); + _tmp1 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp1), _v576)); + _tmp2 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp2), _v576)); + _tmp3 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_tmp3), _v576)); + + if (out_elempack == 4) + { + vst1q_s32(outptr0, _tmp0); + if (tj * 4 + 1 < outw) vst1q_s32(outptr0 + 4, _tmp1); + if (tj * 4 + 2 < outw) vst1q_s32(outptr0 + 8, _tmp2); + if (tj * 4 + 3 < outw) vst1q_s32(outptr0 + 12, _tmp3); + } + if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = vgetq_lane_s32(_tmp0, 0); + outptr1[0] = vgetq_lane_s32(_tmp0, 1); + outptr2[0] = vgetq_lane_s32(_tmp0, 2); + outptr3[0] = vgetq_lane_s32(_tmp0, 3); + if (tj * 4 + 1 < outw) + { + outptr0[1] = vgetq_lane_s32(_tmp1, 0); + outptr1[1] = vgetq_lane_s32(_tmp1, 1); + outptr2[1] = vgetq_lane_s32(_tmp1, 2); + outptr3[1] = vgetq_lane_s32(_tmp1, 3); + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = vgetq_lane_s32(_tmp2, 0); + outptr1[2] = vgetq_lane_s32(_tmp2, 1); + outptr2[2] = vgetq_lane_s32(_tmp2, 2); + outptr3[2] = vgetq_lane_s32(_tmp2, 3); + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = vgetq_lane_s32(_tmp3, 0); + outptr1[3] = vgetq_lane_s32(_tmp3, 1); + outptr2[3] = vgetq_lane_s32(_tmp3, 2); + outptr3[3] = vgetq_lane_s32(_tmp3, 3); + } + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[4][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + const int* r4 = r0 + max_jj * 2 * 4; + const int* r5 = r0 + max_jj * 2 * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + for (int m = 5; m < 6; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int 
tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp00 = tmp00 * 4; + tmp01 = tmp01 * 4; + tmp10 = tmp10 * 4; + tmp11 = tmp11 * 4; + tmp20 = tmp20 * 4; + tmp21 = tmp21 * 4; + tmp30 = tmp30 * 4; + tmp31 = tmp31 * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a0 = tmp[m][1][0] + tmp[m][2][0]; + int tmp02a1 = tmp[m][1][1] + tmp[m][2][1]; + int tmp02b0 = tmp[m][3][0] + tmp[m][4][0]; + int tmp02b1 = tmp[m][3][1] + tmp[m][4][1]; + int tmp13a0 = tmp[m][1][0] - tmp[m][2][0]; + int tmp13a1 = tmp[m][1][1] - tmp[m][2][1]; + int tmp13b0 = tmp[m][3][0] - tmp[m][4][0]; + int tmp13b1 = tmp[m][3][1] - tmp[m][4][1]; + + int tmp00 = tmp02a0 + tmp02b0 + tmp[m][0][0]; + int tmp01 = tmp02a1 + tmp02b1 + tmp[m][0][1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + tmp[m][5][0]; + int tmp31 = tmp13a1 + tmp13b1 * 8 + tmp[m][5][1]; + + tmp00 = tmp00 / 576; + tmp01 = tmp01 / 576; + tmp10 = tmp10 / 576; + tmp11 = tmp11 / 576; + tmp20 = tmp20 / 576; + tmp21 = tmp21 / 576; + tmp30 = tmp30 / 576; + tmp31 = tmp31 / 576; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp20; + outptr1[2] = tmp21; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp30; + outptr1[3] = tmp31; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[4][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + const int* r4 = r0 + max_jj * 4; + const int* r5 = r0 + max_jj * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + for (int m = 5; m < 6; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 
4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp0 = tmp0 * 4; + tmp1 = tmp1 * 4; + tmp2 = tmp2 * 4; + tmp3 = tmp3 * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a = tmp[m][1] + tmp[m][2]; + int tmp02b = tmp[m][3] + tmp[m][4]; + int tmp13a = tmp[m][1] - tmp[m][2]; + int tmp13b = tmp[m][3] - tmp[m][4]; + + int tmp0 = tmp02a + tmp02b + tmp[m][0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + tmp[m][5]; + + tmp0 = tmp0 / 576; + tmp1 = tmp1 / 576; + tmp2 = tmp2 / 576; + tmp3 = tmp3 / 576; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 4 + 1 < outw) outptr0[1] = tmp1; + if (tj * 4 + 2 < outw) outptr0[2] = tmp2; + if (tj * 4 + 3 < outw) outptr0[3] = tmp3; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd43_int8(Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 4n+2, winograd F(4,3) + int w_tiles = (outw + 3) / 4; + int h_tiles = (outh + 3) / 4; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 36; + + // NCNN_LOGE("conv3x3s1_winograd43_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 2u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 2u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + bottom_blob.release(); + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp 
parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk); + } + + // transform output + conv3x3s1_winograd43_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} diff --git a/src/layer/arm/convolution_arm.cpp b/src/layer/arm/convolution_arm.cpp index c8d48aec7622..849a8daea6ba 100644 --- a/src/layer/arm/convolution_arm.cpp +++ b/src/layer/arm/convolution_arm.cpp @@ -49,10 +49,9 @@ namespace ncnn { #if NCNN_INT8 #include "convolution_im2col_gemm_int8.h" +#include "convolution_3x3_winograd_int8.h" -#include "convolution_winograd_transform_int8.h" -#include "convolution_winograd_dot_int8.h" -#include "convolution_3x3_int8.h" +// #include "convolution_3x3_int8.h" #include "convolution_int8.h" #endif // NCNN_INT8 @@ -74,12 +73,6 @@ namespace ncnn { #include "convolution_pack8to4_int8.h" #include "convolution_pack1to4_int8.h" #include "convolution_pack8to1_int8.h" -#include "convolution_winograd_transform_pack4_int8.h" -#include "convolution_winograd_transform_pack8_int8.h" -#include "convolution_winograd_dot_pack8to4_int8.h" -#include "convolution_winograd_dot_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" -#include "convolution_3x3_pack8to1_int8.h" #endif // NCNN_INT8 #endif // __ARM_NEON @@ -1285,6 +1278,14 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) const int maxk = kernel_w * kernel_h; const int num_input = weight_data_size / maxk / num_output; + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input >= 8 && num_output >= 8) && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1; +#if NCNN_ARM82DOT + if (ncnn::cpu_support_arm_asimddp()) + { + prefer_winograd = false; + } +#endif + int elempack = 1; int out_elempack = 1; #if __ARM_NEON @@ -1295,25 +1296,12 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) } #endif // __ARM_NEON -#if NCNN_ARM82DOT - if (elempack == 8 && out_elempack == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && (!ncnn::cpu_support_arm_asimddp() || (ncnn::cpu_support_arm_asimddp() && num_input >= 256 && num_output >= 256))) -#else - if (elempack == 8 && out_elempack == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) -#endif - { -#if __ARM_NEON - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_neon(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __ARM_NEON - } - else if (elempack == 8 && out_elempack == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + if (opt.use_winograd_convolution && prefer_winograd) { -#if __ARM_NEON - 
conv3x3s1_winograd43_transform_kernel_pack8to1_int8_neon(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __ARM_NEON - } - else if (elempack == 1 && out_elempack == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_winograd43_transform_kernel_int8_neon(weight_data, weight_winograd43_data, num_input, num_output, opt); + if (opt.use_winograd43_convolution) + conv3x3s1_winograd43_transform_kernel_int8(weight_data, weight_winograd43_data, num_input, num_output, opt); + else + conv3x3s1_winograd23_transform_kernel_int8(weight_data, weight_winograd23_data, num_input, num_output, opt); } else if (opt.use_sgemm_convolution) { @@ -1321,10 +1309,6 @@ int Convolution_arm::create_pipeline_int8_arm(const Option& opt) } else if (elempack == 1 && out_elempack == 1) { - // if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) - // { - // conv3x3s2_transform_kernel_int8_neon(weight_data, weight_3x3s2_data_int8, num_input, num_output); - // } weight_data_tm = weight_data; } else @@ -1405,20 +1389,29 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con // NCNN_LOGE("forward_int8_arm %d %d %d %d %d", w, h, bottom_blob_bordered.c, elempack, out_elempack); - top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); - if (top_blob.empty()) - return -100; - -#if NCNN_ARM82DOT int channels = bottom_blob_bordered.c; const int num_input = channels * elempack; + + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input >= 8 && num_output >= 8) && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1; +#if NCNN_ARM82DOT + if (ncnn::cpu_support_arm_asimddp()) + { + prefer_winograd = false; + } #endif int out_elempack_int32 = 1; #if __ARM_NEON if (opt.use_packing_layout) { - out_elempack_int32 = num_output % 4 == 0 ? 4 : 1; + if ((opt.use_winograd_convolution && prefer_winograd) || opt.use_sgemm_convolution) + { + out_elempack_int32 = num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + } + else + { + out_elempack_int32 = num_output % 4 == 0 ? 
4 : 1; + } } #endif // __ARM_NEON @@ -1435,25 +1428,12 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con NCNN_LOGE("opt.num_threads %d changed, convolution gemm will use load-time value %d", opt.num_threads, nT); } -#if NCNN_ARM82DOT - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && (!ncnn::cpu_support_arm_asimddp() || (ncnn::cpu_support_arm_asimddp() && num_input >= 256 && num_output >= 256))) -#else - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) -#endif + if (opt.use_winograd_convolution && prefer_winograd) { -#if __ARM_NEON - conv3x3s1_winograd43_pack8to4_int8_neon(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __ARM_NEON - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __ARM_NEON - conv3x3s1_winograd43_pack8to1_int8_neon(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __ARM_NEON - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - conv3x3s1_winograd43_int8_neon(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); + if (opt.use_winograd43_convolution && !weight_winograd43_data.empty()) + conv3x3s1_winograd43_int8(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, _nT, opt); + else + conv3x3s1_winograd23_int8(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, _nT, opt); } else if (opt.use_sgemm_convolution) { @@ -1478,6 +1458,12 @@ int Convolution_arm::forward_int8_arm(const Mat& bottom_blob, Mat& top_blob, con convolution_int8(bottom_blob_bordered, top_blob_int32, weight_data_tm, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); } + bottom_blob_bordered.release(); + + top_blob.create(outw, outh, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + if (use_int8_requantize) { requantize_from_int32_to_int8(top_blob_int32, top_blob, scale_in_data, top_blob_int8_scales, bias_data, activation_type, activation_params, opt); diff --git a/src/layer/arm/convolution_winograd_dot_int8.h b/src/layer/arm/convolution_winograd_dot_int8.h deleted file mode 100644 index d5cf1bcd87eb..000000000000 --- a/src/layer/arm/convolution_winograd_dot_int8.h +++ /dev/null @@ -1,1005 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations under the License. - -static void convolution_winograd_dot_int8_neon(Mat& bottom_blob_tm, int outch, const Mat& kernel_tm, Mat& top_blob_tm, const Option& opt) -{ - // Mat bottom_blob_tm(tiles, 16/36/64, inch, 2u, 1, opt.workspace_allocator); - - const int tiles = bottom_blob_tm.w; - const int batch = bottom_blob_tm.h; - const int inch = bottom_blob_tm.c; - - // permute - Mat bottom_blob_tm2; -#if __ARM_NEON -#if __aarch64__ - if (tiles >= 8) - bottom_blob_tm2.create(inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 4) - bottom_blob_tm2.create(inch, tiles / 4 + tiles % 4, batch, 8u, 4, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(inch, tiles, batch, 2u, 1, opt.workspace_allocator); -#else - if (tiles >= 4) - bottom_blob_tm2.create(inch, tiles / 4 + tiles % 4, batch, 8u, 4, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(inch, tiles, batch, 2u, 1, opt.workspace_allocator); -#endif -#else // __ARM_NEON - if (tiles >= 2) - bottom_blob_tm2.create(inch, tiles / 2 + tiles % 2, batch, 4u, 2, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(inch, tiles, batch, 2u, 1, opt.workspace_allocator); -#endif // __ARM_NEON - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < batch; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __ARM_NEON -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - short* tmpptr = tm2.row(i / 8); - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; - for (; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - vst1q_s16(tmpptr, _r0); - r0 += bottom_blob_tm.cstep; - tmpptr += 8; - } - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 8 + (i % 8) / 4); -#else - short* tmpptr = tm2.row(i / 4); -#endif - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; - for (; q < inch; q++) - { - int16x4_t _r0 = vld1_s16(r0); - vst1_s16(tmpptr, _r0); - r0 += bottom_blob_tm.cstep; - tmpptr += 4; - } - } -#else // __ARM_NEON - for (; i + 1 < tiles; i += 2) - { - short* tmpptr = tm2.row(i / 2); - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; -#if __ARM_FEATURE_SIMD32 - for (; q + 1 < inch; q += 2) - { - tmpptr[0] = r0[0]; - tmpptr[2] = r0[1]; - r0 += bottom_blob_tm.cstep; - tmpptr[1] = r0[0]; - tmpptr[3] = r0[1]; - r0 += bottom_blob_tm.cstep; - tmpptr += 4; - } -#endif // __ARM_FEATURE_SIMD32 - for (; q < inch; q++) - { - tmpptr[0] = r0[0]; - tmpptr[1] = r0[1]; - r0 += bottom_blob_tm.cstep; - tmpptr += 2; - } - } -#endif // __ARM_NEON - for (; i < tiles; i++) - { -#if __ARM_NEON -#if __aarch64__ - short* tmpptr = tm2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - short* tmpptr = tm2.row(i / 4 + i % 4); -#endif -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - const short* r0 = (const short*)bottom_blob_tm + r * tiles + i; - - int q = 0; - for (; q < inch; q++) - { - tmpptr[0] = r0[0]; - r0 += bottom_blob_tm.cstep; - tmpptr += 1; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, batch, outch, 4u, 1, opt.workspace_allocator); - -#if __ARM_NEON - int nn_outch = outch >> 3; - int remain_outch_start = nn_outch << 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 8; - - int* 
output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - int* output4_tm = top_blob_tm.channel(p + 4); - int* output5_tm = top_blob_tm.channel(p + 5); - int* output6_tm = top_blob_tm.channel(p + 6); - int* output7_tm = top_blob_tm.channel(p + 7); - - const Mat kernel0_tm = kernel_tm.channel(p / 8); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum00 = vdupq_n_s32(0); - int32x4_t _sum10 = vdupq_n_s32(0); - int32x4_t _sum20 = vdupq_n_s32(0); - int32x4_t _sum30 = vdupq_n_s32(0); - int32x4_t _sum40 = vdupq_n_s32(0); - int32x4_t _sum50 = vdupq_n_s32(0); - int32x4_t _sum60 = vdupq_n_s32(0); - int32x4_t _sum70 = vdupq_n_s32(0); - int32x4_t _sum01 = vdupq_n_s32(0); - int32x4_t _sum11 = vdupq_n_s32(0); - int32x4_t _sum21 = vdupq_n_s32(0); - int32x4_t _sum31 = vdupq_n_s32(0); - int32x4_t _sum41 = vdupq_n_s32(0); - int32x4_t _sum51 = vdupq_n_s32(0); - int32x4_t _sum61 = vdupq_n_s32(0); - int32x4_t _sum71 = vdupq_n_s32(0); - - int j = 0; - for (; j < inch; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val0), vget_low_s16(_w0), 0); - _sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val0), vget_low_s16(_w0), 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val0), vget_low_s16(_w0), 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val0), vget_low_s16(_w0), 3); - _sum40 = vmlal_lane_s16(_sum40, vget_low_s16(_val0), vget_high_s16(_w0), 0); - _sum50 = vmlal_lane_s16(_sum50, vget_low_s16(_val0), vget_high_s16(_w0), 1); - _sum60 = vmlal_lane_s16(_sum60, vget_low_s16(_val0), vget_high_s16(_w0), 2); - _sum70 = vmlal_lane_s16(_sum70, vget_low_s16(_val0), vget_high_s16(_w0), 3); - - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val0), vget_low_s16(_w0), 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val0), vget_low_s16(_w0), 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val0), vget_low_s16(_w0), 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val0), vget_low_s16(_w0), 3); - _sum41 = vmlal_lane_s16(_sum41, vget_high_s16(_val0), vget_high_s16(_w0), 0); - _sum51 = vmlal_lane_s16(_sum51, vget_high_s16(_val0), vget_high_s16(_w0), 1); - _sum61 = vmlal_lane_s16(_sum61, vget_high_s16(_val0), vget_high_s16(_w0), 2); - _sum71 = vmlal_lane_s16(_sum71, vget_high_s16(_val0), vget_high_s16(_w0), 3); - - r0 += 8; - k0 += 8; - } - - vst1q_s32(output0_tm, _sum00); - vst1q_s32(output0_tm + 4, _sum01); - vst1q_s32(output1_tm, _sum10); - vst1q_s32(output1_tm + 4, _sum11); - vst1q_s32(output2_tm, _sum20); - vst1q_s32(output2_tm + 4, _sum21); - vst1q_s32(output3_tm, _sum30); - vst1q_s32(output3_tm + 4, _sum31); - vst1q_s32(output4_tm, _sum40); - vst1q_s32(output4_tm + 4, _sum41); - vst1q_s32(output5_tm, _sum50); - vst1q_s32(output5_tm + 4, _sum51); - vst1q_s32(output6_tm, _sum60); - vst1q_s32(output6_tm + 4, _sum61); - vst1q_s32(output7_tm, _sum70); - vst1q_s32(output7_tm + 4, _sum71); - - output0_tm += 8; - output1_tm += 8; - output2_tm += 8; - output3_tm += 8; - output4_tm += 8; - output5_tm += 8; - output6_tm += 8; - output7_tm += 8; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - 
const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - int j = 0; - for (; j + 1 < inch; j += 2) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), vget_low_s16(_w0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val01), vget_low_s16(_w0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val01), vget_low_s16(_w0), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val01), vget_low_s16(_w0), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val01), vget_high_s16(_w0), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val01), vget_high_s16(_w0), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val01), vget_high_s16(_w0), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val01), vget_high_s16(_w0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val01), vget_low_s16(_w1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), vget_low_s16(_w1), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val01), vget_low_s16(_w1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val01), vget_low_s16(_w1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val01), vget_high_s16(_w1), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val01), vget_high_s16(_w1), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val01), vget_high_s16(_w1), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val01), vget_high_s16(_w1), 3); - - r0 += 8; - k0 += 16; - } - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, _val0, vget_low_s16(_w0), 0); - _sum1 = vmlal_lane_s16(_sum1, _val0, vget_low_s16(_w0), 1); - _sum2 = vmlal_lane_s16(_sum2, _val0, vget_low_s16(_w0), 2); - _sum3 = vmlal_lane_s16(_sum3, _val0, vget_low_s16(_w0), 3); - _sum4 = vmlal_lane_s16(_sum4, _val0, vget_high_s16(_w0), 0); - _sum5 = vmlal_lane_s16(_sum5, _val0, vget_high_s16(_w0), 1); - _sum6 = vmlal_lane_s16(_sum6, _val0, vget_high_s16(_w0), 2); - _sum7 = vmlal_lane_s16(_sum7, _val0, vget_high_s16(_w0), 3); - - r0 += 4; - k0 += 8; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output2_tm, _sum2); - vst1q_s32(output3_tm, _sum3); - vst1q_s32(output4_tm, _sum4); - vst1q_s32(output5_tm, _sum5); - vst1q_s32(output6_tm, _sum6); - vst1q_s32(output7_tm, _sum7); - - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - output4_tm += 4; - output5_tm += 4; - output6_tm += 4; - output7_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - int j = 0; - for (; j + 3 < inch; j += 4) - { - int16x4_t _val0123 = vld1_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), _val0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), _val0123, 0); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), _val0123, 1); - _sum1 = vmlal_lane_s16(_sum1, 
vget_high_s16(_w1), _val0123, 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), _val0123, 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), _val0123, 2); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), _val0123, 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), _val0123, 3); - - r0 += 4; - k0 += 32; - } - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_dup_s16(r0); - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val0, vget_low_s16(_w0)); - _sum1 = vmlal_s16(_sum1, _val0, vget_high_s16(_w0)); - - r0 += 1; - k0 += 8; - } - - output0_tm[0] = vgetq_lane_s32(_sum0, 0); - output1_tm[0] = vgetq_lane_s32(_sum0, 1); - output2_tm[0] = vgetq_lane_s32(_sum0, 2); - output3_tm[0] = vgetq_lane_s32(_sum0, 3); - output4_tm[0] = vgetq_lane_s32(_sum1, 0); - output5_tm[0] = vgetq_lane_s32(_sum1, 1); - output6_tm[0] = vgetq_lane_s32(_sum1, 2); - output7_tm[0] = vgetq_lane_s32(_sum1, 3); - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - output4_tm += 1; - output5_tm += 1; - output6_tm += 1; - output7_tm += 1; - } - } - } - - nn_outch = (outch - remain_outch_start) >> 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = remain_outch_start + pp * 4; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - - const Mat kernel0_tm = kernel_tm.channel(p / 8 + (p % 8) / 4); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum00 = vdupq_n_s32(0); - int32x4_t _sum10 = vdupq_n_s32(0); - int32x4_t _sum20 = vdupq_n_s32(0); - int32x4_t _sum30 = vdupq_n_s32(0); - int32x4_t _sum01 = vdupq_n_s32(0); - int32x4_t _sum11 = vdupq_n_s32(0); - int32x4_t _sum21 = vdupq_n_s32(0); - int32x4_t _sum31 = vdupq_n_s32(0); - - int j = 0; - for (; j + 1 < inch; j += 2) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _val23 = vld1q_s16(r0 + 8); - int16x8_t _w01 = vld1q_s16(k0); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val01), vget_low_s16(_w01), 0); - _sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val01), vget_low_s16(_w01), 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val01), vget_low_s16(_w01), 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val01), vget_low_s16(_w01), 3); - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val01), vget_low_s16(_w01), 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val01), vget_low_s16(_w01), 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val01), vget_low_s16(_w01), 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val01), vget_low_s16(_w01), 3); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val23), vget_high_s16(_w01), 0); - _sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val23), vget_high_s16(_w01), 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val23), vget_high_s16(_w01), 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val23), vget_high_s16(_w01), 3); - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val23), vget_high_s16(_w01), 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val23), vget_high_s16(_w01), 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val23), vget_high_s16(_w01), 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val23), vget_high_s16(_w01), 3); - - r0 += 16; - k0 += 8; - } - for (; 
j < inch; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x4_t _w0 = vld1_s16(k0); - - _sum00 = vmlal_lane_s16(_sum00, vget_low_s16(_val0), _w0, 0); - _sum10 = vmlal_lane_s16(_sum10, vget_low_s16(_val0), _w0, 1); - _sum20 = vmlal_lane_s16(_sum20, vget_low_s16(_val0), _w0, 2); - _sum30 = vmlal_lane_s16(_sum30, vget_low_s16(_val0), _w0, 3); - _sum01 = vmlal_lane_s16(_sum01, vget_high_s16(_val0), _w0, 0); - _sum11 = vmlal_lane_s16(_sum11, vget_high_s16(_val0), _w0, 1); - _sum21 = vmlal_lane_s16(_sum21, vget_high_s16(_val0), _w0, 2); - _sum31 = vmlal_lane_s16(_sum31, vget_high_s16(_val0), _w0, 3); - - r0 += 8; - k0 += 4; - } - - vst1q_s32(output0_tm, _sum00); - vst1q_s32(output0_tm + 4, _sum01); - vst1q_s32(output1_tm, _sum10); - vst1q_s32(output1_tm + 4, _sum11); - vst1q_s32(output2_tm, _sum20); - vst1q_s32(output2_tm + 4, _sum21); - vst1q_s32(output3_tm, _sum30); - vst1q_s32(output3_tm + 4, _sum31); - - output0_tm += 8; - output1_tm += 8; - output2_tm += 8; - output3_tm += 8; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j + 1 < inch; j += 2) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _w01 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), vget_low_s16(_w01), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val01), vget_low_s16(_w01), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val01), vget_low_s16(_w01), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val01), vget_low_s16(_w01), 3); - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val01), vget_high_s16(_w01), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), vget_high_s16(_w01), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val01), vget_high_s16(_w01), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val01), vget_high_s16(_w01), 3); - - r0 += 8; - k0 += 8; - } - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_s16(r0); - int16x4_t _w0 = vld1_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, _val0, _w0, 0); - _sum1 = vmlal_lane_s16(_sum1, _val0, _w0, 1); - _sum2 = vmlal_lane_s16(_sum2, _val0, _w0, 2); - _sum3 = vmlal_lane_s16(_sum3, _val0, _w0, 3); - - r0 += 4; - k0 += 4; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output2_tm, _sum2); - vst1q_s32(output3_tm, _sum3); - - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j + 3 < inch; j += 4) - { - int16x4_t _val0123 = vld1_s16(r0); - int16x8_t _w01 = vld1q_s16(k0); - int16x8_t _w23 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w01), _val0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w01), _val0123, 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w23), _val0123, 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w23), _val0123, 3); - - r0 += 4; - k0 += 16; - } - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - _sum0 = 
vaddq_s32(_sum0, _sum2); - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_dup_s16(r0); - int16x4_t _w0 = vld1_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val0, _w0); - - r0 += 1; - k0 += 4; - } - - output0_tm[0] = vgetq_lane_s32(_sum0, 0); - output1_tm[0] = vgetq_lane_s32(_sum0, 1); - output2_tm[0] = vgetq_lane_s32(_sum0, 2); - output3_tm[0] = vgetq_lane_s32(_sum0, 3); - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - } - } - } - - remain_outch_start += nn_outch << 2; -#else // __ARM_NEON - int nn_outch = outch >> 1; - int remain_outch_start = nn_outch << 1; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 2; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - - const Mat kernel0_tm = kernel_tm.channel(p / 2); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; - for (; i + 1 < tiles; i += 2) - { - const short* r0 = bb2.row(i / 2); - const short* k0 = kernel0_tm.row(r); - - int sum00 = 0; - int sum10 = 0; - int sum01 = 0; - int sum11 = 0; - - int j = 0; -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - // fomit-frame-pointer implied in optimized flag spare one register - // let us stay away from error: ‘asm’ operand has impossible constraints --- nihui -#if __OPTIMIZE__ - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val02 = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%0], #4 \n" // int16x2_t _val13 = *((int16x2_t*)r0); r0 += 2; - "ldr r4, [%1], #4 \n" // int16x2_t _w02 = *((int16x2_t*)k0); k0 += 2; - "ldr r5, [%1], #4 \n" // int16x2_t _w13 = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_val02, _w02, sum00); - "smlad %3, r3, r4, %3 \n" // sum01 = __smlad(_val13, _w02, sum01); - "smlad %4, r2, r5, %4 \n" // sum10 = __smlad(_val02, _w13, sum10); - "smlad %5, r3, r5, %5 \n" // sum11 = __smlad(_val13, _w13, sum11); - : "=r"(r0), - "=r"(k0), - "=r"(sum00), - "=r"(sum01), - "=r"(sum10), - "=r"(sum11) - : "0"(r0), - "1"(k0), - "2"(sum00), - "3"(sum01), - "4"(sum10), - "5"(sum11) - : "memory", "r2", "r3", "r4", "r5"); -#else - int _val02 = *((int*)r0); - int _val13 = *((int*)(r0 + 2)); - int _w02 = *((int*)k0); - int _w13 = *((int*)(k0 + 2)); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum00) - : "0"(sum00), "r"(_val02), "r"(_w02) - :); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum01) - : "0"(sum01), "r"(_val13), "r"(_w02) - :); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum10) - : "0"(sum10), "r"(_val02), "r"(_w13) - :); - asm volatile("smlad %0, %2, %3, %0" - : "=r"(sum11) - : "0"(sum11), "r"(_val13), "r"(_w13) - :); - r0 += 4; - k0 += 4; -#endif - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val0 = r0[0]; - signed short val1 = r0[1]; - - signed short w0 = k0[0]; - signed short w1 = k0[1]; - - sum00 += val0 * w0; - sum10 += val0 * w1; - sum01 += val1 * w0; - sum11 += val1 * w1; - - r0 += 2; - k0 += 2; - } - - output0_tm[0] = sum00; - output1_tm[0] = sum10; - output0_tm[1] = sum01; - output1_tm[1] = sum11; - output0_tm += 2; - output1_tm += 2; - } - for (; i < tiles; i++) - { - const short* r0 = bb2.row(i / 2 + i % 2); - const short* k0 = kernel0_tm.row(r); - - int sum0 = 0; - int sum1 = 0; - - int j = 0; -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val01 = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%1], #4 \n" // int16x2_t _w02 = *((int16x2_t*)k0); k0 += 2; 
- "ldr r4, [%1], #4 \n" // int16x2_t _w13 = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r3, %2 \n" // sum00 = __smlad(_val01, _w02, sum00); - "smlad %3, r2, r4, %3 \n" // sum01 = __smlad(_val01, _w02, sum01); - : "=r"(r0), - "=r"(k0), - "=r"(sum0), - "=r"(sum1) - : "0"(r0), - "1"(k0), - "2"(sum0), - "3"(sum1) - : "memory", "r2", "r3", "r4"); - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val = r0[0]; - - sum0 += val * k0[0]; - sum1 += val * k0[1]; - - r0 += 1; - k0 += 2; - } - - output0_tm[0] = sum0; - output1_tm[0] = sum1; - output0_tm += 1; - output1_tm += 1; - } - } - } -#endif // __ARM_NEON - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - -#if __ARM_NEON - const Mat kernel0_tm = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4); -#else - const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2); -#endif - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __ARM_NEON -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j + 3 < inch; j += 4) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _val23 = vld1q_s16(r0 + 8); - int16x8_t _val45 = vld1q_s16(r0 + 16); - int16x8_t _val67 = vld1q_s16(r0 + 24); - int16x4_t _w0123 = vld1_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), _w0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), _w0123, 0); - - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val23), _w0123, 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val23), _w0123, 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val45), _w0123, 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val45), _w0123, 2); - - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val67), _w0123, 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val67), _w0123, 3); - - k0 += 4; - r0 += 32; - } - _sum0 = vaddq_s32(_sum0, _sum2); - _sum1 = vaddq_s32(_sum1, _sum3); - for (; j < inch; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x4_t _w0 = vld1_dup_s16(k0); - - _sum0 = vmlal_s16(_sum0, _w0, vget_low_s16(_val0)); - _sum1 = vmlal_s16(_sum1, _w0, vget_high_s16(_val0)); - - k0 += 1; - r0 += 8; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - output0_tm += 8; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - int j = 0; - for (; j + 3 < inch; j += 4) - { - int16x8_t _val01 = vld1q_s16(r0); - int16x8_t _val23 = vld1q_s16(r0 + 8); - int16x4_t _w0123 = vld1_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val01), _w0123, 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val01), _w0123, 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val23), _w0123, 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val23), _w0123, 3); - - k0 += 4; - r0 += 16; - } - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - _sum0 = vaddq_s32(_sum0, _sum2); - for (; j < inch; j++) - { - int16x4_t _val0 = vld1_s16(r0); - 
int16x4_t _w0 = vld1_dup_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val0, _w0); - - k0 += 1; - r0 += 4; - } - - vst1q_s32(output0_tm, _sum0); - output0_tm += 4; - } -#else - for (; i + 1 < tiles; i += 2) - { - const short* r0 = bb2.row(i / 2); - const short* k0 = kernel0_tm.row(r); - - int sum0 = 0; - int sum1 = 0; - - int j = 0; -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val02 = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%0], #4 \n" // int16x2_t _val13 = *((int16x2_t*)r0); r0 += 2; - "ldr r4, [%1], #4 \n" // int16x2_t _w01 = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_val02, _w01, sum00); - "smlad %3, r3, r4, %3 \n" // sum01 = __smlad(_val13, _w01, sum01); - : "=r"(r0), - "=r"(k0), - "=r"(sum0), - "=r"(sum1) - : "0"(r0), - "1"(k0), - "2"(sum0), - "3"(sum1) - : "memory", "r2", "r3", "r4"); - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val0 = r0[0]; - signed short val1 = r0[1]; - signed short w = k0[0]; - - sum0 += val0 * w; - sum1 += val1 * w; - - k0 += 1; - r0 += 2; - } - - output0_tm[0] = sum0; - output0_tm[1] = sum1; - output0_tm += 2; - } -#endif - for (; i < tiles; i++) - { -#if __ARM_NEON -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int sum = 0; - - int j = 0; -#if __ARM_NEON - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - for (; j + 7 < inch; j += 8) - { - int16x8_t _val = vld1q_s16(r0); - int16x8_t _w = vld1q_s16(k0); - - _sum0 = vmlal_s16(_sum0, vget_low_s16(_val), vget_low_s16(_w)); - _sum1 = vmlal_s16(_sum1, vget_high_s16(_val), vget_high_s16(_w)); - - k0 += 8; - r0 += 8; - } - _sum0 = vaddq_s32(_sum0, _sum1); - for (; j + 3 < inch; j += 4) - { - int16x4_t _val = vld1_s16(r0); - int16x4_t _w = vld1_s16(k0); - - _sum0 = vmlal_s16(_sum0, _val, _w); - - k0 += 4; - r0 += 4; - } -#if __aarch64__ - sum = vaddvq_s32(_sum0); -#else - int32x2_t _ss = vadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); - _ss = vpadd_s32(_ss, _ss); - - sum = vget_lane_s32(_ss, 0); -#endif -#endif // __ARM_NEON -#if __ARM_FEATURE_SIMD32 - for (; j + 1 < inch; j += 2) - { - asm volatile( - "ldr r2, [%0], #4 \n" // int16x2_t _val = *((int16x2_t*)r0); r0 += 2; - "ldr r3, [%1], #4 \n" // int16x2_t _w = *((int16x2_t*)k0); k0 += 2; - "smlad %2, r2, r3, %2 \n" // sum = __smlad(_val, _w, sum); - : "=r"(r0), - "=r"(k0), - "=r"(sum) - : "0"(r0), - "1"(k0), - "2"(sum) - : "memory", "r2", "r3"); - } -#endif // __ARM_FEATURE_SIMD32 - for (; j < inch; j++) - { - signed short val = r0[0]; - signed short w = k0[0]; - - sum += val * w; - - k0 += 1; - r0 += 1; - } - - output0_tm[0] = sum; - output0_tm++; - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_dot_pack8to1_int8.h b/src/layer/arm/convolution_winograd_dot_pack8to1_int8.h deleted file mode 100644 index 6192be128465..000000000000 --- a/src/layer/arm/convolution_winograd_dot_pack8to1_int8.h +++ /dev/null @@ -1,774 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. 
You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -static void convolution_winograd_dot_pack8to1_int8_neon(Mat& bottom_blob_tm, int outch, const Mat& kernel_tm, Mat& top_blob_tm, const Option& opt) -{ - // Mat bottom_blob_tm(tiles, 16/36/64, inch, 16u, 8, opt.workspace_allocator); - - const int tiles = bottom_blob_tm.w; - const int batch = bottom_blob_tm.h; - const int inch = bottom_blob_tm.c; - - // permute - Mat bottom_blob_tm2; -#if __aarch64__ - if (tiles >= 8) - bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#else - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#endif // __aarch64__ - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < batch; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - short* tm2p = tm2.row(i / 8); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 8x8 - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" - "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n" - "sub %0, %0, #64 \n" - - "uzp1 v16.8h, v0.8h, v4.8h \n" - "uzp2 v20.8h, v0.8h, v4.8h \n" - "uzp1 v17.8h, v1.8h, v5.8h \n" - "uzp2 v21.8h, v1.8h, v5.8h \n" - "uzp1 v18.8h, v2.8h, v6.8h \n" - "uzp2 v22.8h, v2.8h, v6.8h \n" - "uzp1 v19.8h, v3.8h, v7.8h \n" - "uzp2 v23.8h, v3.8h, v7.8h \n" - - "st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n" - "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); - - r0 += bottom_blob_tm.cstep * 8; - } - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - short* tm2p = tm2.row(i / 8 + (i % 8) / 4); -#else - short* tm2p = tm2.row(i / 4); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 8x4 -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n" - "st4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0", "v1", "v2", "v3"); -#else - asm volatile( - "pld [%0, #512] \n" - "vldm %0, {d0-d7} \n" - "vswp d1, d2 \n" - "vswp d5, d6 \n" - "vswp q1, q2 \n" - "vst4.s16 {d0-d3}, [%1 :64]! \n" - "vst4.s16 {d4-d7}, [%1 :64]! 
\n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "q0", "q1", "q2", "q3"); -#endif // __aarch64__ - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i < tiles; i++) - { -#if __aarch64__ - short* tm2p = tm2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - short* tm2p = tm2.row(i / 4 + i % 4); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #128] \n" - "ld1 {v0.8h}, [%0] \n" - "st1 {v0.8h}, [%1], #16 \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0"); -#else - asm volatile( - "pld [%0, #128] \n" - "vld1.s16 {d0-d1}, [%0 :64] \n" - "vst1.s16 {d0-d1}, [%1 :64]! \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "q0"); -#endif // __aarch64__ - r0 += bottom_blob_tm.cstep * 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, batch, outch, 4u, 1, opt.workspace_allocator); - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 3; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 8; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - int* output4_tm = top_blob_tm.channel(p + 4); - int* output5_tm = top_blob_tm.channel(p + 5); - int* output6_tm = top_blob_tm.channel(p + 6); - int* output7_tm = top_blob_tm.channel(p + 7); - - const Mat kernel01_tm = kernel_tm.channel(p / 8); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - const short* kptr = kernel01_tm.row(r); - - int nn = inch; // inch always > 0 - - asm volatile( - "eor v16.16b, v16.16b, v16.16b \n" - "eor v17.16b, v17.16b, v17.16b \n" - "eor v18.16b, v18.16b, v18.16b \n" - "eor v19.16b, v19.16b, v19.16b \n" - "eor v20.16b, v20.16b, v20.16b \n" - "eor v21.16b, v21.16b, v21.16b \n" - "eor v22.16b, v22.16b, v22.16b \n" - "eor v23.16b, v23.16b, v23.16b \n" - "eor v24.16b, v24.16b, v24.16b \n" - "eor v25.16b, v25.16b, v25.16b \n" - "eor v26.16b, v26.16b, v26.16b \n" - "eor v27.16b, v27.16b, v27.16b \n" - "eor v28.16b, v28.16b, v28.16b \n" - "eor v29.16b, v29.16b, v29.16b \n" - "eor v30.16b, v30.16b, v30.16b \n" - "eor v31.16b, v31.16b, v31.16b \n" - - "0: \n" - - "prfm pldl1keep, [%9, #512] \n" - "ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [%9], #64 \n" - - "prfm pldl1keep, [%10, #512] \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%10], #64 \n" - - "smlal v16.4s, v8.4h, v0.h[0] \n" - "smlal2 v17.4s, v8.8h, v0.h[0] \n" - "smlal v18.4s, v8.4h, v0.h[1] \n" - "smlal2 v19.4s, v8.8h, v0.h[1] \n" - "smlal v20.4s, v8.4h, v0.h[2] \n" - "smlal2 v21.4s, v8.8h, v0.h[2] \n" - "smlal v22.4s, v8.4h, v0.h[3] \n" - "smlal2 v23.4s, v8.8h, v0.h[3] \n" - "smlal v24.4s, v8.4h, v0.h[4] \n" - "smlal2 v25.4s, v8.8h, v0.h[4] \n" - "smlal v26.4s, v8.4h, v0.h[5] \n" - "smlal2 v27.4s, v8.8h, v0.h[5] \n" - "smlal v28.4s, v8.4h, v0.h[6] \n" - "smlal2 v29.4s, v8.8h, v0.h[6] \n" - "smlal v30.4s, v8.4h, v0.h[7] \n" - "smlal2 v31.4s, v8.8h, v0.h[7] \n" - - "smlal v16.4s, v9.4h, v1.h[0] \n" - "smlal2 v17.4s, v9.8h, v1.h[0] \n" - "smlal v18.4s, v9.4h, v1.h[1] \n" - "smlal2 v19.4s, v9.8h, v1.h[1] \n" - "smlal v20.4s, v9.4h, v1.h[2] \n" - "smlal2 v21.4s, v9.8h, v1.h[2] \n" - "smlal 
v22.4s, v9.4h, v1.h[3] \n" - "smlal2 v23.4s, v9.8h, v1.h[3] \n" - "smlal v24.4s, v9.4h, v1.h[4] \n" - "smlal2 v25.4s, v9.8h, v1.h[4] \n" - "smlal v26.4s, v9.4h, v1.h[5] \n" - "smlal2 v27.4s, v9.8h, v1.h[5] \n" - "smlal v28.4s, v9.4h, v1.h[6] \n" - "smlal2 v29.4s, v9.8h, v1.h[6] \n" - "smlal v30.4s, v9.4h, v1.h[7] \n" - "smlal2 v31.4s, v9.8h, v1.h[7] \n" - - "prfm pldl1keep, [%9, #512] \n" - "ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [%9], #64 \n" - - "smlal v16.4s, v10.4h, v2.h[0] \n" - "smlal2 v17.4s, v10.8h, v2.h[0] \n" - "smlal v18.4s, v10.4h, v2.h[1] \n" - "smlal2 v19.4s, v10.8h, v2.h[1] \n" - "smlal v20.4s, v10.4h, v2.h[2] \n" - "smlal2 v21.4s, v10.8h, v2.h[2] \n" - "smlal v22.4s, v10.4h, v2.h[3] \n" - "smlal2 v23.4s, v10.8h, v2.h[3] \n" - "smlal v24.4s, v10.4h, v2.h[4] \n" - "smlal2 v25.4s, v10.8h, v2.h[4] \n" - "smlal v26.4s, v10.4h, v2.h[5] \n" - "smlal2 v27.4s, v10.8h, v2.h[5] \n" - "smlal v28.4s, v10.4h, v2.h[6] \n" - "smlal2 v29.4s, v10.8h, v2.h[6] \n" - "smlal v30.4s, v10.4h, v2.h[7] \n" - "smlal2 v31.4s, v10.8h, v2.h[7] \n" - - "prfm pldl1keep, [%10, #512] \n" - "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%10], #64 \n" - - "smlal v16.4s, v11.4h, v3.h[0] \n" - "smlal2 v17.4s, v11.8h, v3.h[0] \n" - "smlal v18.4s, v11.4h, v3.h[1] \n" - "smlal2 v19.4s, v11.8h, v3.h[1] \n" - "smlal v20.4s, v11.4h, v3.h[2] \n" - "smlal2 v21.4s, v11.8h, v3.h[2] \n" - "smlal v22.4s, v11.4h, v3.h[3] \n" - "smlal2 v23.4s, v11.8h, v3.h[3] \n" - "smlal v24.4s, v11.4h, v3.h[4] \n" - "smlal2 v25.4s, v11.8h, v3.h[4] \n" - "smlal v26.4s, v11.4h, v3.h[5] \n" - "smlal2 v27.4s, v11.8h, v3.h[5] \n" - "smlal v28.4s, v11.4h, v3.h[6] \n" - "smlal2 v29.4s, v11.8h, v3.h[6] \n" - "smlal v30.4s, v11.4h, v3.h[7] \n" - "smlal2 v31.4s, v11.8h, v3.h[7] \n" - - "smlal v16.4s, v12.4h, v4.h[0] \n" - "smlal2 v17.4s, v12.8h, v4.h[0] \n" - "smlal v18.4s, v12.4h, v4.h[1] \n" - "smlal2 v19.4s, v12.8h, v4.h[1] \n" - "smlal v20.4s, v12.4h, v4.h[2] \n" - "smlal2 v21.4s, v12.8h, v4.h[2] \n" - "smlal v22.4s, v12.4h, v4.h[3] \n" - "smlal2 v23.4s, v12.8h, v4.h[3] \n" - "smlal v24.4s, v12.4h, v4.h[4] \n" - "smlal2 v25.4s, v12.8h, v4.h[4] \n" - "smlal v26.4s, v12.4h, v4.h[5] \n" - "smlal2 v27.4s, v12.8h, v4.h[5] \n" - "smlal v28.4s, v12.4h, v4.h[6] \n" - "smlal2 v29.4s, v12.8h, v4.h[6] \n" - "smlal v30.4s, v12.4h, v4.h[7] \n" - "smlal2 v31.4s, v12.8h, v4.h[7] \n" - - "smlal v16.4s, v13.4h, v5.h[0] \n" - "smlal2 v17.4s, v13.8h, v5.h[0] \n" - "smlal v18.4s, v13.4h, v5.h[1] \n" - "smlal2 v19.4s, v13.8h, v5.h[1] \n" - "smlal v20.4s, v13.4h, v5.h[2] \n" - "smlal2 v21.4s, v13.8h, v5.h[2] \n" - "smlal v22.4s, v13.4h, v5.h[3] \n" - "smlal2 v23.4s, v13.8h, v5.h[3] \n" - "smlal v24.4s, v13.4h, v5.h[4] \n" - "smlal2 v25.4s, v13.8h, v5.h[4] \n" - "smlal v26.4s, v13.4h, v5.h[5] \n" - "smlal2 v27.4s, v13.8h, v5.h[5] \n" - "smlal v28.4s, v13.4h, v5.h[6] \n" - "smlal2 v29.4s, v13.8h, v5.h[6] \n" - "smlal v30.4s, v13.4h, v5.h[7] \n" - "smlal2 v31.4s, v13.8h, v5.h[7] \n" - - "smlal v16.4s, v14.4h, v6.h[0] \n" - "smlal2 v17.4s, v14.8h, v6.h[0] \n" - "smlal v18.4s, v14.4h, v6.h[1] \n" - "smlal2 v19.4s, v14.8h, v6.h[1] \n" - "smlal v20.4s, v14.4h, v6.h[2] \n" - "smlal2 v21.4s, v14.8h, v6.h[2] \n" - "smlal v22.4s, v14.4h, v6.h[3] \n" - "smlal2 v23.4s, v14.8h, v6.h[3] \n" - "smlal v24.4s, v14.4h, v6.h[4] \n" - "smlal2 v25.4s, v14.8h, v6.h[4] \n" - "smlal v26.4s, v14.4h, v6.h[5] \n" - "smlal2 v27.4s, v14.8h, v6.h[5] \n" - "smlal v28.4s, v14.4h, v6.h[6] \n" - "smlal2 v29.4s, v14.8h, v6.h[6] \n" - "smlal v30.4s, v14.4h, v6.h[7] \n" - "smlal2 v31.4s, v14.8h, v6.h[7] \n" - - 
"subs %w0, %w0, #1 \n" - - "smlal v16.4s, v15.4h, v7.h[0] \n" - "smlal2 v17.4s, v15.8h, v7.h[0] \n" - "smlal v18.4s, v15.4h, v7.h[1] \n" - "smlal2 v19.4s, v15.8h, v7.h[1] \n" - "smlal v20.4s, v15.4h, v7.h[2] \n" - "smlal2 v21.4s, v15.8h, v7.h[2] \n" - "smlal v22.4s, v15.4h, v7.h[3] \n" - "smlal2 v23.4s, v15.8h, v7.h[3] \n" - "smlal v24.4s, v15.4h, v7.h[4] \n" - "smlal2 v25.4s, v15.8h, v7.h[4] \n" - "smlal v26.4s, v15.4h, v7.h[5] \n" - "smlal2 v27.4s, v15.8h, v7.h[5] \n" - "smlal v28.4s, v15.4h, v7.h[6] \n" - "smlal2 v29.4s, v15.8h, v7.h[6] \n" - "smlal v30.4s, v15.4h, v7.h[7] \n" - "smlal2 v31.4s, v15.8h, v7.h[7] \n" - - "bne 0b \n" - - "st1 {v16.4s, v17.4s}, [%1], #32 \n" - "st1 {v18.4s, v19.4s}, [%2], #32 \n" - "st1 {v20.4s, v21.4s}, [%3], #32 \n" - "st1 {v22.4s, v23.4s}, [%4], #32 \n" - "st1 {v24.4s, v25.4s}, [%5], #32 \n" - "st1 {v26.4s, v27.4s}, [%6], #32 \n" - "st1 {v28.4s, v29.4s}, [%7], #32 \n" - "st1 {v30.4s, v31.4s}, [%8], #32 \n" - - : "=r"(nn), // %0 - "=r"(output0_tm), // %1 - "=r"(output1_tm), // %2 - "=r"(output2_tm), // %3 - "=r"(output3_tm), // %4 - "=r"(output4_tm), // %5 - "=r"(output5_tm), // %6 - "=r"(output6_tm), // %7 - "=r"(output7_tm), // %8 - "=r"(r0), // %9 - "=r"(kptr) // %10 - : "0"(nn), - "1"(output0_tm), - "2"(output1_tm), - "3"(output2_tm), - "4"(output3_tm), - "5"(output4_tm), - "6"(output5_tm), - "7"(output6_tm), - "8"(output7_tm), - "9"(r0), - "10"(kptr) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel01_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val0), vget_low_s16(_w0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val0), vget_low_s16(_w0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val0), vget_low_s16(_w0), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val0), vget_low_s16(_w0), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val0), vget_high_s16(_w0), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val0), vget_high_s16(_w0), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val0), vget_high_s16(_w0), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val0), vget_high_s16(_w0), 3); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val0), vget_low_s16(_w1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val0), vget_low_s16(_w1), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val0), vget_low_s16(_w1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val0), vget_low_s16(_w1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val0), vget_high_s16(_w1), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val0), vget_high_s16(_w1), 1); - _sum6 = 
vmlal_lane_s16(_sum6, vget_high_s16(_val0), vget_high_s16(_w1), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val0), vget_high_s16(_w1), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val1), vget_low_s16(_w2), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val1), vget_low_s16(_w2), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val1), vget_low_s16(_w2), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val1), vget_low_s16(_w2), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val1), vget_high_s16(_w2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val1), vget_high_s16(_w2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val1), vget_high_s16(_w2), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val1), vget_high_s16(_w2), 3); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val1), vget_low_s16(_w3), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val1), vget_low_s16(_w3), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val1), vget_low_s16(_w3), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val1), vget_low_s16(_w3), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val1), vget_high_s16(_w3), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val1), vget_high_s16(_w3), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val1), vget_high_s16(_w3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val1), vget_high_s16(_w3), 3); - - int16x8_t _w4 = vld1q_s16(k0 + 32); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val2), vget_low_s16(_w4), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val2), vget_low_s16(_w4), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val2), vget_low_s16(_w4), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val2), vget_low_s16(_w4), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val2), vget_high_s16(_w4), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val2), vget_high_s16(_w4), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val2), vget_high_s16(_w4), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val2), vget_high_s16(_w4), 3); - - int16x8_t _w5 = vld1q_s16(k0 + 40); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val2), vget_low_s16(_w5), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val2), vget_low_s16(_w5), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val2), vget_low_s16(_w5), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val2), vget_low_s16(_w5), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val2), vget_high_s16(_w5), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val2), vget_high_s16(_w5), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val2), vget_high_s16(_w5), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val2), vget_high_s16(_w5), 3); - - int16x8_t _w6 = vld1q_s16(k0 + 48); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_val3), vget_low_s16(_w6), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_val3), vget_low_s16(_w6), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_val3), vget_low_s16(_w6), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_val3), vget_low_s16(_w6), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_val3), vget_high_s16(_w6), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_val3), vget_high_s16(_w6), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_val3), vget_high_s16(_w6), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_val3), vget_high_s16(_w6), 3); - - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_val3), vget_low_s16(_w7), 0); - 
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_val3), vget_low_s16(_w7), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_val3), vget_low_s16(_w7), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_val3), vget_low_s16(_w7), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_val3), vget_high_s16(_w7), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_val3), vget_high_s16(_w7), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_val3), vget_high_s16(_w7), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_val3), vget_high_s16(_w7), 3); - - r0 += 32; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output2_tm, _sum2); - vst1q_s32(output3_tm, _sum3); - vst1q_s32(output4_tm, _sum4); - vst1q_s32(output5_tm, _sum5); - vst1q_s32(output6_tm, _sum6); - vst1q_s32(output7_tm, _sum7); - - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - output4_tm += 4; - output5_tm += 4; - output6_tm += 4; - output7_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif - const short* k0 = kernel01_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - int16x8_t _w4 = vld1q_s16(k0 + 32); - int16x8_t _w5 = vld1q_s16(k0 + 40); - int16x8_t _w6 = vld1q_s16(k0 + 48); - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - - r0 += 8; - k0 += 64; - } - - output0_tm[0] = vgetq_lane_s32(_sum0, 0); - output1_tm[0] = vgetq_lane_s32(_sum0, 1); - output2_tm[0] = vgetq_lane_s32(_sum0, 2); - output3_tm[0] = vgetq_lane_s32(_sum0, 3); - output4_tm[0] = vgetq_lane_s32(_sum1, 0); - output5_tm[0] = vgetq_lane_s32(_sum1, 1); - output6_tm[0] = vgetq_lane_s32(_sum1, 2); - output7_tm[0] = vgetq_lane_s32(_sum1, 3); - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - output4_tm += 1; - output5_tm += 1; - output6_tm += 1; - output7_tm += 1; - } - } - } - - remain_outch_start += nn_outch << 3; - - #pragma omp parallel for 
num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p / 8 + p % 8); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 8); - - const short* kptr = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - for (int q = 0; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - int16x8_t _r1 = vld1q_s16(r0 + 8); - int16x8_t _r2 = vld1q_s16(r0 + 16); - int16x8_t _r3 = vld1q_s16(r0 + 24); - int16x8_t _r4 = vld1q_s16(r0 + 32); - int16x8_t _r5 = vld1q_s16(r0 + 40); - int16x8_t _r6 = vld1q_s16(r0 + 48); - int16x8_t _r7 = vld1q_s16(r0 + 56); - - int16x8_t _k0 = vld1q_s16(kptr); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r0), vget_low_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r0), vget_low_s16(_k0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r1), vget_low_s16(_k0), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r1), vget_low_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r2), vget_low_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r2), vget_low_s16(_k0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r3), vget_low_s16(_k0), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r3), vget_low_s16(_k0), 3); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r4), vget_high_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r4), vget_high_s16(_k0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r5), vget_high_s16(_k0), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r5), vget_high_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r6), vget_high_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r6), vget_high_s16(_k0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_r7), vget_high_s16(_k0), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_r7), vget_high_s16(_k0), 3); - - kptr += 8; - r0 += 64; - } - - _sum0 = vaddq_s32(_sum0, _sum2); - _sum1 = vaddq_s32(_sum1, _sum3); - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - - output0_tm += 8; - } -#endif - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* kptr = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int q = 0; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - int16x8_t _r1 = vld1q_s16(r0 + 8); - int16x8_t _r2 = vld1q_s16(r0 + 16); - int16x8_t _r3 = vld1q_s16(r0 + 24); - - int16x8_t _k0 = vld1q_s16(kptr); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r0), vget_low_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r0), vget_low_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r1), vget_low_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r1), vget_low_s16(_k0), 3); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r2), vget_high_s16(_k0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r2), vget_high_s16(_k0), 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_r3), vget_high_s16(_k0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_r3), vget_high_s16(_k0), 3); - - kptr += 8; - r0 += 32; - } - - int32x4_t _sum01 = vaddq_s32(_sum0, _sum1); - - vst1q_s32(output0_tm, 
_sum01); - - output0_tm += 4; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 8 + (i % 8) / 4 + i % 4); -#else - const short* r0 = bb2.row(i / 4 + i % 4); -#endif - const short* kptr = kernel0_tm.row(r); - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int q = 0; q < inch; q++) - { - int16x8_t _r0 = vld1q_s16(r0); - - int16x8_t _k0 = vld1q_s16(kptr); - - _sum0 = vmlal_s16(_sum0, vget_low_s16(_r0), vget_low_s16(_k0)); - _sum1 = vmlal_s16(_sum1, vget_high_s16(_r0), vget_high_s16(_k0)); - - kptr += 8; - r0 += 8; - } - - int32x4_t _sum = vaddq_s32(_sum0, _sum1); -#if __aarch64__ - int sum = vaddvq_s32(_sum); // dot -#else - int32x2_t _ss = vadd_s32(vget_low_s32(_sum), vget_high_s32(_sum)); - _ss = vpadd_s32(_ss, _ss); - int sum = vget_lane_s32(_ss, 0); -#endif - - output0_tm[0] = sum; - - output0_tm++; - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_dot_pack8to4_int8.h b/src/layer/arm/convolution_winograd_dot_pack8to4_int8.h deleted file mode 100644 index a17559f6cc2f..000000000000 --- a/src/layer/arm/convolution_winograd_dot_pack8to4_int8.h +++ /dev/null @@ -1,1835 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -static void convolution_winograd_dot_pack8to4_int8_neon(Mat& bottom_blob_tm, int outch, const Mat& kernel_tm, Mat& top_blob_tm, const Option& opt) -{ - // Mat bottom_blob_tm(tiles, 16/36/64, inch, 16u, 8, opt.workspace_allocator); - - const int tiles = bottom_blob_tm.w; - const int batch = bottom_blob_tm.h; - const int inch = bottom_blob_tm.c; - - // permute - Mat bottom_blob_tm2; -#if __aarch64__ - if (tiles >= 12) - bottom_blob_tm2.create(12 * inch, tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 8) - bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#else - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, batch, 16u, 8, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 8, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < batch; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __aarch64__ - for (; i + 11 < tiles; i += 12) - { - short* tm2p = tm2.row(i / 12); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 12x8 - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" - "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n" - "ld4 {v16.8h, v17.8h, v18.8h, v19.8h}, [%0] \n" - - "sub %0, %0, #128 \n" - - "uzp1 v20.8h, v0.8h, v4.8h \n" // 0 - "uzp1 v21.8h, v16.8h, v1.8h \n" // 1 - "uzp1 v22.8h, v5.8h, v17.8h \n" // 2 - "uzp1 v23.8h, v2.8h, v6.8h \n" // 3 - "uzp1 v24.8h, v18.8h, v3.8h \n" // 4 - "uzp1 v25.8h, v7.8h, v19.8h \n" // 5 - "uzp2 v26.8h, v0.8h, v4.8h \n" // 6 - "uzp2 v27.8h, v16.8h, v1.8h \n" // 7 - "uzp2 v28.8h, v5.8h, v17.8h \n" // 8 - "uzp2 v29.8h, v2.8h, v6.8h \n" // 9 - "uzp2 v30.8h, v18.8h, v3.8h \n" // 10 - "uzp2 v31.8h, v7.8h, v19.8h \n" // 11 - - "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" - "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [%1], #64 \n" - "st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tm2p) // %1 - : "0"(r0), - "1"(tm2p) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); - - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i + 7 < tiles; i += 8) - { - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - // transpose 8x8 - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n" - "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n" - "sub %0, %0, #64 \n" - - "uzp1 v16.8h, v0.8h, v4.8h \n" - "uzp2 v20.8h, v0.8h, v4.8h \n" - "uzp1 v17.8h, v1.8h, v5.8h \n" - "uzp2 v21.8h, v1.8h, v5.8h \n" - "uzp1 
v18.8h, v2.8h, v6.8h \n" - "uzp2 v22.8h, v2.8h, v6.8h \n" - "uzp1 v19.8h, v3.8h, v7.8h \n" - "uzp2 v23.8h, v3.8h, v7.8h \n" - - "st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n" - "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); - - r0 += bottom_blob_tm.cstep * 8; - } - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); -#else - short* tmpptr = tm2.row(i / 4); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #512] \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n" - "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0", "v1", "v2", "v3"); -#else - asm volatile( - "pld [%0, #512] \n" - "vldm %0, {d0-d7} \n" - "vstm %1!, {d0-d7} \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "q0", "q1", "q2", "q3"); -#endif - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i + 1 < tiles; i += 2) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); -#else - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #256] \n" - "ld1 {v0.8h, v1.8h}, [%0] \n" - "st1 {v0.8h, v1.8h}, [%1], #32 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0", "v1"); -#else - asm volatile( - "pld [%0, #256] \n" - "vld1.s16 {d0-d3}, [%0 :128] \n" - "vst1.s16 {d0-d3}, [%1 :128]! \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "q0", "q1"); -#endif - r0 += bottom_blob_tm.cstep * 8; - } - } - for (; i < tiles; i++) - { -#if __aarch64__ - short* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); -#else - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { -#if __aarch64__ - asm volatile( - "prfm pldl1keep, [%0, #128] \n" - "ld1 {v0.8h}, [%0] \n" - "st1 {v0.8h}, [%1], #16 \n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "v0"); -#else - asm volatile( - "pld [%0, #128] \n" - "vld1.s16 {d0-d1}, [%0 :128] \n" - "vst1.s16 {d0-d1}, [%1 :128]! 
\n" - : "=r"(r0), // %0 - "=r"(tmpptr) // %1 - : "0"(r0), - "1"(tmpptr) - : "memory", "q0"); -#endif - r0 += bottom_blob_tm.cstep * 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, batch, outch, 16u, 4, opt.workspace_allocator); - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 1; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 2; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - - const Mat kernel0_tm = kernel_tm.channel(p / 2); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 11 < tiles; i += 12) - { - const short* r0 = bb2.row(i / 12); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - asm volatile( - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r01 - - "eor v8.16b, v8.16b, v8.16b \n" - "eor v9.16b, v9.16b, v9.16b \n" - - "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w01 - - "eor v10.16b, v10.16b, v10.16b \n" - "eor v11.16b, v11.16b, v11.16b \n" - - "prfm pldl1keep, [%3, #256] \n" - - "eor v12.16b, v12.16b, v12.16b \n" - "eor v13.16b, v13.16b, v13.16b \n" - - "prfm pldl1keep, [%4, #256] \n" - - "eor v14.16b, v14.16b, v14.16b \n" - "eor v15.16b, v15.16b, v15.16b \n" - "eor v16.16b, v16.16b, v16.16b \n" - "eor v17.16b, v17.16b, v17.16b \n" - "eor v18.16b, v18.16b, v18.16b \n" - "eor v19.16b, v19.16b, v19.16b \n" - "eor v20.16b, v20.16b, v20.16b \n" - "eor v21.16b, v21.16b, v21.16b \n" - "eor v22.16b, v22.16b, v22.16b \n" - "eor v23.16b, v23.16b, v23.16b \n" - "eor v24.16b, v24.16b, v24.16b \n" - "eor v25.16b, v25.16b, v25.16b \n" - "eor v26.16b, v26.16b, v26.16b \n" - "eor v27.16b, v27.16b, v27.16b \n" - "eor v28.16b, v28.16b, v28.16b \n" - "eor v29.16b, v29.16b, v29.16b \n" - "eor v30.16b, v30.16b, v30.16b \n" - "eor v31.16b, v31.16b, v31.16b \n" - - "0: \n" - - "smlal v8.4s, v4.4h, v0.h[0] \n" - "smlal2 v20.4s, v4.8h, v0.h[0] \n" - "smlal v9.4s, v4.4h, v0.h[1] \n" - "smlal2 v21.4s, v4.8h, v0.h[1] \n" - "smlal v10.4s, v4.4h, v0.h[2] \n" - "smlal2 v22.4s, v4.8h, v0.h[2] \n" - "smlal v11.4s, v4.4h, v0.h[3] \n" - "smlal2 v23.4s, v4.8h, v0.h[3] \n" - "smlal v12.4s, v4.4h, v0.h[4] \n" - "smlal2 v24.4s, v4.8h, v0.h[4] \n" - "smlal v13.4s, v4.4h, v0.h[5] \n" - "smlal2 v25.4s, v4.8h, v0.h[5] \n" - "smlal v14.4s, v4.4h, v0.h[6] \n" - "smlal2 v26.4s, v4.8h, v0.h[6] \n" - "smlal v15.4s, v4.4h, v0.h[7] \n" - "smlal2 v27.4s, v4.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r23 - - "smlal v16.4s, v4.4h, v1.h[0] \n" - "smlal2 v28.4s, v4.8h, v1.h[0] \n" - "smlal v17.4s, v4.4h, v1.h[1] \n" - "smlal2 v29.4s, v4.8h, v1.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v18.4s, v4.4h, v1.h[2] \n" - "smlal2 v30.4s, v4.8h, v1.h[2] \n" - "smlal v19.4s, v4.4h, v1.h[3] \n" - "smlal2 v31.4s, v4.8h, v1.h[3] \n" - - "ld1 {v6.8h, v7.8h}, [%4], #32 \n" // w23 - - "smlal v8.4s, v5.4h, v1.h[4] \n" - "smlal2 v20.4s, v5.8h, v1.h[4] \n" - "smlal v9.4s, v5.4h, v1.h[5] \n" - "smlal2 v21.4s, v5.8h, v1.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal v10.4s, v5.4h, v1.h[6] \n" - "smlal2 v22.4s, v5.8h, v1.h[6] \n" - "smlal v11.4s, v5.4h, v1.h[7] \n" - "smlal2 v23.4s, v5.8h, v1.h[7] \n" - "smlal v12.4s, v5.4h, v2.h[0] \n" - "smlal2 v24.4s, v5.8h, v2.h[0] \n" - "smlal v13.4s, v5.4h, v2.h[1] \n" - "smlal2 v25.4s, v5.8h, v2.h[1] \n" - "smlal v14.4s, v5.4h, v2.h[2] \n" - "smlal2 v26.4s, v5.8h, v2.h[2] \n" - "smlal v15.4s, v5.4h, 
v2.h[3] \n" - "smlal2 v27.4s, v5.8h, v2.h[3] \n" - "smlal v16.4s, v5.4h, v2.h[4] \n" - "smlal2 v28.4s, v5.8h, v2.h[4] \n" - "smlal v17.4s, v5.4h, v2.h[5] \n" - "smlal2 v29.4s, v5.8h, v2.h[5] \n" - "smlal v18.4s, v5.4h, v2.h[6] \n" - "smlal2 v30.4s, v5.8h, v2.h[6] \n" - "smlal v19.4s, v5.4h, v2.h[7] \n" - "smlal2 v31.4s, v5.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r45 - - "smlal v8.4s, v6.4h, v3.h[0] \n" - "smlal2 v20.4s, v6.8h, v3.h[0] \n" - "smlal v9.4s, v6.4h, v3.h[1] \n" - "smlal2 v21.4s, v6.8h, v3.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v10.4s, v6.4h, v3.h[2] \n" - "smlal2 v22.4s, v6.8h, v3.h[2] \n" - "smlal v11.4s, v6.4h, v3.h[3] \n" - "smlal2 v23.4s, v6.8h, v3.h[3] \n" - "smlal v12.4s, v6.4h, v3.h[4] \n" - "smlal2 v24.4s, v6.8h, v3.h[4] \n" - "smlal v13.4s, v6.4h, v3.h[5] \n" - "smlal2 v25.4s, v6.8h, v3.h[5] \n" - "smlal v14.4s, v6.4h, v3.h[6] \n" - "smlal2 v26.4s, v6.8h, v3.h[6] \n" - "smlal v15.4s, v6.4h, v3.h[7] \n" - "smlal2 v27.4s, v6.8h, v3.h[7] \n" - - "smlal v16.4s, v6.4h, v0.h[0] \n" - "smlal2 v28.4s, v6.8h, v0.h[0] \n" - "smlal v17.4s, v6.4h, v0.h[1] \n" - "smlal2 v29.4s, v6.8h, v0.h[1] \n" - "smlal v18.4s, v6.4h, v0.h[2] \n" - "smlal2 v30.4s, v6.8h, v0.h[2] \n" - "smlal v19.4s, v6.4h, v0.h[3] \n" - "smlal2 v31.4s, v6.8h, v0.h[3] \n" - - "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w45 - - "smlal v8.4s, v7.4h, v0.h[4] \n" - "smlal2 v20.4s, v7.8h, v0.h[4] \n" - "smlal v9.4s, v7.4h, v0.h[5] \n" - "smlal2 v21.4s, v7.8h, v0.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal v10.4s, v7.4h, v0.h[6] \n" - "smlal2 v22.4s, v7.8h, v0.h[6] \n" - "smlal v11.4s, v7.4h, v0.h[7] \n" - "smlal2 v23.4s, v7.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r67 - - "smlal v12.4s, v7.4h, v1.h[0] \n" - "smlal2 v24.4s, v7.8h, v1.h[0] \n" - "smlal v13.4s, v7.4h, v1.h[1] \n" - "smlal2 v25.4s, v7.8h, v1.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v14.4s, v7.4h, v1.h[2] \n" - "smlal2 v26.4s, v7.8h, v1.h[2] \n" - "smlal v15.4s, v7.4h, v1.h[3] \n" - "smlal2 v27.4s, v7.8h, v1.h[3] \n" - "smlal v16.4s, v7.4h, v1.h[4] \n" - "smlal2 v28.4s, v7.8h, v1.h[4] \n" - "smlal v17.4s, v7.4h, v1.h[5] \n" - "smlal2 v29.4s, v7.8h, v1.h[5] \n" - "smlal v18.4s, v7.4h, v1.h[6] \n" - "smlal2 v30.4s, v7.8h, v1.h[6] \n" - "smlal v19.4s, v7.4h, v1.h[7] \n" - "smlal2 v31.4s, v7.8h, v1.h[7] \n" - - "smlal v8.4s, v4.4h, v2.h[0] \n" - "smlal2 v20.4s, v4.8h, v2.h[0] \n" - "smlal v9.4s, v4.4h, v2.h[1] \n" - "smlal2 v21.4s, v4.8h, v2.h[1] \n" - "smlal v10.4s, v4.4h, v2.h[2] \n" - "smlal2 v22.4s, v4.8h, v2.h[2] \n" - "smlal v11.4s, v4.4h, v2.h[3] \n" - "smlal2 v23.4s, v4.8h, v2.h[3] \n" - "smlal v12.4s, v4.4h, v2.h[4] \n" - "smlal2 v24.4s, v4.8h, v2.h[4] \n" - "smlal v13.4s, v4.4h, v2.h[5] \n" - "smlal2 v25.4s, v4.8h, v2.h[5] \n" - "smlal v14.4s, v4.4h, v2.h[6] \n" - "smlal2 v26.4s, v4.8h, v2.h[6] \n" - "smlal v15.4s, v4.4h, v2.h[7] \n" - "smlal2 v27.4s, v4.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r89 - - "smlal v16.4s, v4.4h, v3.h[0] \n" - "smlal2 v28.4s, v4.8h, v3.h[0] \n" - "smlal v17.4s, v4.4h, v3.h[1] \n" - "smlal2 v29.4s, v4.8h, v3.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v18.4s, v4.4h, v3.h[2] \n" - "smlal2 v30.4s, v4.8h, v3.h[2] \n" - "smlal v19.4s, v4.4h, v3.h[3] \n" - "smlal2 v31.4s, v4.8h, v3.h[3] \n" - - "ld1 {v6.8h, v7.8h}, [%4], #32 \n" // w67 - - "smlal v8.4s, v5.4h, v3.h[4] \n" - "smlal2 v20.4s, v5.8h, v3.h[4] \n" - "smlal v9.4s, v5.4h, v3.h[5] \n" - "smlal2 v21.4s, v5.8h, v3.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal 
v10.4s, v5.4h, v3.h[6] \n" - "smlal2 v22.4s, v5.8h, v3.h[6] \n" - "smlal v11.4s, v5.4h, v3.h[7] \n" - "smlal2 v23.4s, v5.8h, v3.h[7] \n" - - "smlal v12.4s, v5.4h, v0.h[0] \n" - "smlal2 v24.4s, v5.8h, v0.h[0] \n" - "smlal v13.4s, v5.4h, v0.h[1] \n" - "smlal2 v25.4s, v5.8h, v0.h[1] \n" - "smlal v14.4s, v5.4h, v0.h[2] \n" - "smlal2 v26.4s, v5.8h, v0.h[2] \n" - "smlal v15.4s, v5.4h, v0.h[3] \n" - "smlal2 v27.4s, v5.8h, v0.h[3] \n" - "smlal v16.4s, v5.4h, v0.h[4] \n" - "smlal2 v28.4s, v5.8h, v0.h[4] \n" - "smlal v17.4s, v5.4h, v0.h[5] \n" - "smlal2 v29.4s, v5.8h, v0.h[5] \n" - "smlal v18.4s, v5.4h, v0.h[6] \n" - "smlal2 v30.4s, v5.8h, v0.h[6] \n" - "smlal v19.4s, v5.4h, v0.h[7] \n" - "smlal2 v31.4s, v5.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%3], #32 \n" // r1011 - - "smlal v8.4s, v6.4h, v1.h[0] \n" - "smlal2 v20.4s, v6.8h, v1.h[0] \n" - "smlal v9.4s, v6.4h, v1.h[1] \n" - "smlal2 v21.4s, v6.8h, v1.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v10.4s, v6.4h, v1.h[2] \n" - "smlal2 v22.4s, v6.8h, v1.h[2] \n" - "smlal v11.4s, v6.4h, v1.h[3] \n" - "smlal2 v23.4s, v6.8h, v1.h[3] \n" - "smlal v12.4s, v6.4h, v1.h[4] \n" - "smlal2 v24.4s, v6.8h, v1.h[4] \n" - "smlal v13.4s, v6.4h, v1.h[5] \n" - "smlal2 v25.4s, v6.8h, v1.h[5] \n" - "smlal v14.4s, v6.4h, v1.h[6] \n" - "smlal2 v26.4s, v6.8h, v1.h[6] \n" - "smlal v15.4s, v6.4h, v1.h[7] \n" - "smlal2 v27.4s, v6.8h, v1.h[7] \n" - "smlal v16.4s, v6.4h, v2.h[0] \n" - "smlal2 v28.4s, v6.8h, v2.h[0] \n" - "smlal v17.4s, v6.4h, v2.h[1] \n" - "smlal2 v29.4s, v6.8h, v2.h[1] \n" - "smlal v18.4s, v6.4h, v2.h[2] \n" - "smlal2 v30.4s, v6.8h, v2.h[2] \n" - "smlal v19.4s, v6.4h, v2.h[3] \n" - "smlal2 v31.4s, v6.8h, v2.h[3] \n" - - "ld1 {v4.8h, v5.8h}, [%4], #32 \n" // w01 - - "smlal v8.4s, v7.4h, v2.h[4] \n" - "smlal2 v20.4s, v7.8h, v2.h[4] \n" - "smlal v9.4s, v7.4h, v2.h[5] \n" - "smlal2 v21.4s, v7.8h, v2.h[5] \n" - - "prfm pldl1keep, [%4, #256] \n" - - "smlal v10.4s, v7.4h, v2.h[6] \n" - "smlal2 v22.4s, v7.8h, v2.h[6] \n" - "smlal v11.4s, v7.4h, v2.h[7] \n" - "smlal2 v23.4s, v7.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%3], #32 \n" // r01 - - "smlal v12.4s, v7.4h, v3.h[0] \n" - "smlal2 v24.4s, v7.8h, v3.h[0] \n" - "smlal v13.4s, v7.4h, v3.h[1] \n" - "smlal2 v25.4s, v7.8h, v3.h[1] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal v14.4s, v7.4h, v3.h[2] \n" - "smlal2 v26.4s, v7.8h, v3.h[2] \n" - "smlal v15.4s, v7.4h, v3.h[3] \n" - "smlal2 v27.4s, v7.8h, v3.h[3] \n" - "smlal v16.4s, v7.4h, v3.h[4] \n" - "smlal2 v28.4s, v7.8h, v3.h[4] \n" - "smlal v17.4s, v7.4h, v3.h[5] \n" - "smlal2 v29.4s, v7.8h, v3.h[5] \n" - - "subs %w0, %w0, #1 \n" - - "smlal v18.4s, v7.4h, v3.h[6] \n" - "smlal2 v30.4s, v7.8h, v3.h[6] \n" - "smlal v19.4s, v7.4h, v3.h[7] \n" - "smlal2 v31.4s, v7.8h, v3.h[7] \n" - - "bne 0b \n" - - "sub %3, %3, #32 \n" - "sub %4, %4, #32 \n" - - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" - "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%2], #64 \n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n" - "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%2], #64 \n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n" - "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%2], #64 \n" - - : "=r"(nn), // %0 - "=r"(output0_tm), // %1 - "=r"(output1_tm), // %2 - "=r"(r0), // %3 - "=r"(k0) // %4 - : "0"(nn), - "1"(output0_tm), - "2"(output1_tm), - "3"(r0), - "4"(k0) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", 
"v30", "v31"); - } - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 12 + (i % 12) / 8); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - int32x4_t _sum8 = vdupq_n_s32(0); - int32x4_t _sum9 = vdupq_n_s32(0); - int32x4_t _suma = vdupq_n_s32(0); - int32x4_t _sumb = vdupq_n_s32(0); - int32x4_t _sumc = vdupq_n_s32(0); - int32x4_t _sumd = vdupq_n_s32(0); - int32x4_t _sume = vdupq_n_s32(0); - int32x4_t _sumf = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - int16x8_t _val4 = vld1q_s16(r0 + 32); - int16x8_t _val5 = vld1q_s16(r0 + 40); - int16x8_t _val6 = vld1q_s16(r0 + 48); - int16x8_t _val7 = vld1q_s16(r0 + 56); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val0), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val0), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_low_s16(_val0), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_low_s16(_val0), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_low_s16(_val0), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_low_s16(_val0), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w0), vget_high_s16(_val0), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w0), vget_high_s16(_val0), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w0), vget_high_s16(_val0), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w0), vget_high_s16(_val0), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w0), vget_high_s16(_val0), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w0), vget_high_s16(_val0), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w0), vget_high_s16(_val0), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w0), vget_high_s16(_val0), 3); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val1), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_low_s16(_val1), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_low_s16(_val1), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_low_s16(_val1), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_low_s16(_val1), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w1), vget_high_s16(_val1), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w1), vget_high_s16(_val1), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w1), vget_high_s16(_val1), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w1), vget_high_s16(_val1), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w1), vget_high_s16(_val1), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w1), vget_high_s16(_val1), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w1), 
vget_high_s16(_val1), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w1), vget_high_s16(_val1), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val2), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val2), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val2), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val2), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_low_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_low_s16(_val2), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_low_s16(_val2), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_low_s16(_val2), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w2), vget_high_s16(_val2), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w2), vget_high_s16(_val2), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w2), vget_high_s16(_val2), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w2), vget_high_s16(_val2), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w2), vget_high_s16(_val2), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w2), vget_high_s16(_val2), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w2), vget_high_s16(_val2), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w2), vget_high_s16(_val2), 3); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val3), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val3), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val3), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val3), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_low_s16(_val3), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_low_s16(_val3), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_low_s16(_val3), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_low_s16(_val3), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w3), vget_high_s16(_val3), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w3), vget_high_s16(_val3), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w3), vget_high_s16(_val3), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w3), vget_high_s16(_val3), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w3), vget_high_s16(_val3), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w3), vget_high_s16(_val3), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w3), vget_high_s16(_val3), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w3), vget_high_s16(_val3), 3); - - int16x8_t _w4 = vld1q_s16(k0 + 32); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_low_s16(_val4), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_low_s16(_val4), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w4), vget_low_s16(_val4), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w4), vget_low_s16(_val4), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w4), vget_low_s16(_val4), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w4), vget_low_s16(_val4), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w4), vget_low_s16(_val4), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w4), vget_low_s16(_val4), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w4), vget_high_s16(_val4), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w4), vget_high_s16(_val4), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w4), vget_high_s16(_val4), 1); - _sumb = 
vmlal_lane_s16(_sumb, vget_high_s16(_w4), vget_high_s16(_val4), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w4), vget_high_s16(_val4), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w4), vget_high_s16(_val4), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w4), vget_high_s16(_val4), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w4), vget_high_s16(_val4), 3); - - int16x8_t _w5 = vld1q_s16(k0 + 40); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_low_s16(_val5), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_low_s16(_val5), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w5), vget_low_s16(_val5), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w5), vget_low_s16(_val5), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w5), vget_low_s16(_val5), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w5), vget_low_s16(_val5), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w5), vget_low_s16(_val5), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w5), vget_low_s16(_val5), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w5), vget_high_s16(_val5), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w5), vget_high_s16(_val5), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w5), vget_high_s16(_val5), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w5), vget_high_s16(_val5), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w5), vget_high_s16(_val5), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w5), vget_high_s16(_val5), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w5), vget_high_s16(_val5), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w5), vget_high_s16(_val5), 3); - - int16x8_t _w6 = vld1q_s16(k0 + 48); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_low_s16(_val6), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_low_s16(_val6), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w6), vget_low_s16(_val6), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w6), vget_low_s16(_val6), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w6), vget_low_s16(_val6), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w6), vget_low_s16(_val6), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w6), vget_low_s16(_val6), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w6), vget_low_s16(_val6), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w6), vget_high_s16(_val6), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w6), vget_high_s16(_val6), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w6), vget_high_s16(_val6), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w6), vget_high_s16(_val6), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w6), vget_high_s16(_val6), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w6), vget_high_s16(_val6), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w6), vget_high_s16(_val6), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w6), vget_high_s16(_val6), 3); - - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_low_s16(_val7), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_low_s16(_val7), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w7), vget_low_s16(_val7), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w7), vget_low_s16(_val7), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w7), vget_low_s16(_val7), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w7), vget_low_s16(_val7), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w7), vget_low_s16(_val7), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w7), 
vget_low_s16(_val7), 3); - _sum8 = vmlal_lane_s16(_sum8, vget_low_s16(_w7), vget_high_s16(_val7), 0); - _sum9 = vmlal_lane_s16(_sum9, vget_high_s16(_w7), vget_high_s16(_val7), 0); - _suma = vmlal_lane_s16(_suma, vget_low_s16(_w7), vget_high_s16(_val7), 1); - _sumb = vmlal_lane_s16(_sumb, vget_high_s16(_w7), vget_high_s16(_val7), 1); - _sumc = vmlal_lane_s16(_sumc, vget_low_s16(_w7), vget_high_s16(_val7), 2); - _sumd = vmlal_lane_s16(_sumd, vget_high_s16(_w7), vget_high_s16(_val7), 2); - _sume = vmlal_lane_s16(_sume, vget_low_s16(_w7), vget_high_s16(_val7), 3); - _sumf = vmlal_lane_s16(_sumf, vget_high_s16(_w7), vget_high_s16(_val7), 3); - - r0 += 64; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output1_tm + 4, _sum3); - vst1q_s32(output0_tm + 8, _sum4); - vst1q_s32(output1_tm + 8, _sum5); - vst1q_s32(output0_tm + 12, _sum6); - vst1q_s32(output1_tm + 12, _sum7); - vst1q_s32(output0_tm + 16, _sum8); - vst1q_s32(output1_tm + 16, _sum9); - vst1q_s32(output0_tm + 20, _suma); - vst1q_s32(output1_tm + 20, _sumb); - vst1q_s32(output0_tm + 24, _sumc); - vst1q_s32(output1_tm + 24, _sumd); - vst1q_s32(output0_tm + 28, _sume); - vst1q_s32(output1_tm + 28, _sumf); - output0_tm += 32; - output1_tm += 32; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __aarch64__ - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 0); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_low_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_low_s16(_val2), 0); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_low_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_low_s16(_val3), 0); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_low_s16(_val2), 1); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_low_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_low_s16(_val3), 1); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_low_s16(_val3), 1); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), 
vget_low_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val1), 2); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_low_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_low_s16(_val2), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_low_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_low_s16(_val3), 2); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_low_s16(_val2), 3); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_low_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_low_s16(_val3), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_low_s16(_val3), 3); - - int16x8_t _w4 = vld1q_s16(k0 + 32); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w4), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w4), vget_high_s16(_val1), 0); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w4), vget_high_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w4), vget_high_s16(_val2), 0); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w4), vget_high_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w4), vget_high_s16(_val3), 0); - - int16x8_t _w5 = vld1q_s16(k0 + 40); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w5), vget_high_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w5), vget_high_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w5), vget_high_s16(_val2), 1); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w5), vget_high_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w5), vget_high_s16(_val3), 1); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w5), vget_high_s16(_val3), 1); - - int16x8_t _w6 = vld1q_s16(k0 + 48); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w6), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w6), vget_high_s16(_val1), 2); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w6), vget_high_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w6), vget_high_s16(_val2), 2); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w6), vget_high_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w6), vget_high_s16(_val3), 2); - - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w7), vget_high_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w7), vget_high_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w7), vget_high_s16(_val2), 3); - 
_sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w7), vget_high_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w7), vget_high_s16(_val3), 3); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w7), vget_high_s16(_val3), 3); - - r0 += 32; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output1_tm + 4, _sum3); - vst1q_s32(output0_tm + 8, _sum4); - vst1q_s32(output1_tm + 8, _sum5); - vst1q_s32(output0_tm + 12, _sum6); - vst1q_s32(output1_tm + 12, _sum7); - output0_tm += 16; - output1_tm += 16; -#else - asm volatile( - "veor q8, q8 \n" - "veor q9, q9 \n" - "veor q10, q10 \n" - "veor q11, q11 \n" - "veor q12, q12 \n" - "veor q13, q13 \n" - "veor q14, q14 \n" - "veor q15, q15 \n" - - "0: \n" - - "pld [%3, #256] \n" - "pld [%3, #512] \n" - "vldm %3!, {d0-d7} \n" - - "pld [%4, #256] \n" - "vld1.s16 {d8-d11}, [%4 :128]! \n" - - "vmlal.s16 q8, d8, d0[0] \n" - "vmlal.s16 q12, d9, d0[0] \n" - "vmlal.s16 q9, d8, d2[0] \n" - "vmlal.s16 q13, d9, d2[0] \n" - "vmlal.s16 q10, d8, d4[0] \n" - "vmlal.s16 q14, d9, d4[0] \n" - "vmlal.s16 q11, d8, d6[0] \n" - "vmlal.s16 q15, d9, d6[0] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d8-d9}, [%4 :128]! \n" - - "vmlal.s16 q8, d10, d0[1] \n" - "vmlal.s16 q12, d11, d0[1] \n" - "vmlal.s16 q9, d10, d2[1] \n" - "vmlal.s16 q13, d11, d2[1] \n" - "vmlal.s16 q10, d10, d4[1] \n" - "vmlal.s16 q14, d11, d4[1] \n" - "vmlal.s16 q11, d10, d6[1] \n" - "vmlal.s16 q15, d11, d6[1] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d10-d11}, [%4 :128]! \n" - - "vmlal.s16 q8, d8, d0[2] \n" - "vmlal.s16 q12, d9, d0[2] \n" - "vmlal.s16 q9, d8, d2[2] \n" - "vmlal.s16 q13, d9, d2[2] \n" - "vmlal.s16 q10, d8, d4[2] \n" - "vmlal.s16 q14, d9, d4[2] \n" - "vmlal.s16 q11, d8, d6[2] \n" - "vmlal.s16 q15, d9, d6[2] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d8-d9}, [%4 :128]! \n" - - "vmlal.s16 q8, d10, d0[3] \n" - "vmlal.s16 q12, d11, d0[3] \n" - "vmlal.s16 q9, d10, d2[3] \n" - "vmlal.s16 q13, d11, d2[3] \n" - "vmlal.s16 q10, d10, d4[3] \n" - "vmlal.s16 q14, d11, d4[3] \n" - "vmlal.s16 q11, d10, d6[3] \n" - "vmlal.s16 q15, d11, d6[3] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d10-d11}, [%4 :128]! \n" - - "vmlal.s16 q8, d8, d1[0] \n" - "vmlal.s16 q12, d9, d1[0] \n" - "vmlal.s16 q9, d8, d3[0] \n" - "vmlal.s16 q13, d9, d3[0] \n" - "vmlal.s16 q10, d8, d5[0] \n" - "vmlal.s16 q14, d9, d5[0] \n" - "vmlal.s16 q11, d8, d7[0] \n" - "vmlal.s16 q15, d9, d7[0] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d8-d9}, [%4 :128]! \n" - - "vmlal.s16 q8, d10, d1[1] \n" - "vmlal.s16 q12, d11, d1[1] \n" - "vmlal.s16 q9, d10, d3[1] \n" - "vmlal.s16 q13, d11, d3[1] \n" - "vmlal.s16 q10, d10, d5[1] \n" - "vmlal.s16 q14, d11, d5[1] \n" - "vmlal.s16 q11, d10, d7[1] \n" - "vmlal.s16 q15, d11, d7[1] \n" - - "pld [%4, #128] \n" - "vld1.s16 {d10-d11}, [%4 :128]! 
\n" - - "vmlal.s16 q8, d8, d1[2] \n" - "vmlal.s16 q12, d9, d1[2] \n" - "vmlal.s16 q9, d8, d3[2] \n" - "vmlal.s16 q13, d9, d3[2] \n" - "vmlal.s16 q10, d8, d5[2] \n" - "vmlal.s16 q14, d9, d5[2] \n" - "vmlal.s16 q11, d8, d7[2] \n" - "vmlal.s16 q15, d9, d7[2] \n" - - "subs %0, %0, #1 \n" - - "vmlal.s16 q8, d10, d1[3] \n" - "vmlal.s16 q12, d11, d1[3] \n" - "vmlal.s16 q9, d10, d3[3] \n" - "vmlal.s16 q13, d11, d3[3] \n" - "vmlal.s16 q10, d10, d5[3] \n" - "vmlal.s16 q14, d11, d5[3] \n" - "vmlal.s16 q11, d10, d7[3] \n" - "vmlal.s16 q15, d11, d7[3] \n" - - "bne 0b \n" - - "vstm %1!, {d16-d23} \n" - "vstm %2!, {d24-d31} \n" - - : "=r"(nn), - "=r"(output0_tm), - "=r"(output1_tm), - "=r"(r0), - "=r"(k0) - : "0"(nn), - "1"(output0_tm), - "2"(output1_tm), - "3"(r0), - "4"(k0) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#endif - } - for (; i + 1 < tiles; i += 2) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - int16x8_t _w4 = vld1q_s16(k0 + 32); - int16x8_t _w5 = vld1q_s16(k0 + 40); - int16x8_t _w6 = vld1q_s16(k0 + 48); - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 1); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val1), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val1), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w4), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w4), vget_high_s16(_val1), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w5), vget_high_s16(_val1), 1); - _sum3 = 
vmlal_lane_s16(_sum3, vget_high_s16(_w5), vget_high_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w6), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w6), vget_high_s16(_val1), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w7), vget_high_s16(_val1), 3); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w7), vget_high_s16(_val1), 3); - - r0 += 16; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output1_tm + 4, _sum3); - output0_tm += 8; - output1_tm += 8; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - int16x8_t _w4 = vld1q_s16(k0 + 32); - int16x8_t _w5 = vld1q_s16(k0 + 40); - int16x8_t _w6 = vld1q_s16(k0 + 48); - int16x8_t _w7 = vld1q_s16(k0 + 56); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w4), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w4), vget_high_s16(_val0), 0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w5), vget_high_s16(_val0), 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w5), vget_high_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w6), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w6), vget_high_s16(_val0), 2); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w7), vget_high_s16(_val0), 3); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w7), vget_high_s16(_val0), 3); - - r0 += 8; - k0 += 64; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output1_tm, _sum1); - output0_tm += 4; - output1_tm += 4; - } - } - } - - remain_outch_start += nn_outch << 1; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2); - - for (int r = 0; r < batch; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __aarch64__ - for (; i + 11 < tiles; i += 12) - { - const short* r0 = bb2.row(i / 12); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // 
inch always > 0 - - asm volatile( - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 - - "eor v8.16b, v8.16b, v8.16b \n" - "eor v9.16b, v9.16b, v9.16b \n" - - "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 - - "eor v10.16b, v10.16b, v10.16b \n" - "eor v11.16b, v11.16b, v11.16b \n" - - "prfm pldl1keep, [%2, #256] \n" - - "eor v12.16b, v12.16b, v12.16b \n" - "eor v13.16b, v13.16b, v13.16b \n" - - "prfm pldl1keep, [%3, #256] \n" - - "eor v14.16b, v14.16b, v14.16b \n" - "eor v15.16b, v15.16b, v15.16b \n" - "eor v16.16b, v16.16b, v16.16b \n" - "eor v17.16b, v17.16b, v17.16b \n" - "eor v18.16b, v18.16b, v18.16b \n" - "eor v19.16b, v19.16b, v19.16b \n" - - "0: \n" - - "smlal v8.4s, v4.4h, v0.h[0] \n" - "smlal v9.4s, v4.4h, v0.h[1] \n" - "smlal v10.4s, v4.4h, v0.h[2] \n" - "smlal v11.4s, v4.4h, v0.h[3] \n" - "smlal v12.4s, v4.4h, v0.h[4] \n" - "smlal v13.4s, v4.4h, v0.h[5] \n" - "smlal v14.4s, v4.4h, v0.h[6] \n" - "smlal v15.4s, v4.4h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r23 - - "smlal v16.4s, v4.4h, v1.h[0] \n" - "smlal v17.4s, v4.4h, v1.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v18.4s, v4.4h, v1.h[2] \n" - "smlal v19.4s, v4.4h, v1.h[3] \n" - - "smlal2 v8.4s, v4.8h, v1.h[4] \n" - "smlal2 v9.4s, v4.8h, v1.h[5] \n" - "smlal2 v10.4s, v4.8h, v1.h[6] \n" - "smlal2 v11.4s, v4.8h, v1.h[7] \n" - "smlal2 v12.4s, v4.8h, v2.h[0] \n" - "smlal2 v13.4s, v4.8h, v2.h[1] \n" - "smlal2 v14.4s, v4.8h, v2.h[2] \n" - "smlal2 v15.4s, v4.8h, v2.h[3] \n" - "smlal2 v16.4s, v4.8h, v2.h[4] \n" - "smlal2 v17.4s, v4.8h, v2.h[5] \n" - "smlal2 v18.4s, v4.8h, v2.h[6] \n" - "smlal2 v19.4s, v4.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r45 - - "smlal v8.4s, v5.4h, v3.h[0] \n" - "smlal v9.4s, v5.4h, v3.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v10.4s, v5.4h, v3.h[2] \n" - "smlal v11.4s, v5.4h, v3.h[3] \n" - "smlal v12.4s, v5.4h, v3.h[4] \n" - "smlal v13.4s, v5.4h, v3.h[5] \n" - "smlal v14.4s, v5.4h, v3.h[6] \n" - "smlal v15.4s, v5.4h, v3.h[7] \n" - "smlal v16.4s, v5.4h, v0.h[0] \n" - "smlal v17.4s, v5.4h, v0.h[1] \n" - "smlal v18.4s, v5.4h, v0.h[2] \n" - "smlal v19.4s, v5.4h, v0.h[3] \n" - - "ld1 {v6.8h, v7.8h}, [%3], #32 \n" // w23 - - "smlal2 v8.4s, v5.8h, v0.h[4] \n" - "smlal2 v9.4s, v5.8h, v0.h[5] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal2 v10.4s, v5.8h, v0.h[6] \n" - "smlal2 v11.4s, v5.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r67 - - "smlal2 v12.4s, v5.8h, v1.h[0] \n" - "smlal2 v13.4s, v5.8h, v1.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal2 v14.4s, v5.8h, v1.h[2] \n" - "smlal2 v15.4s, v5.8h, v1.h[3] \n" - "smlal2 v16.4s, v5.8h, v1.h[4] \n" - "smlal2 v17.4s, v5.8h, v1.h[5] \n" - "smlal2 v18.4s, v5.8h, v1.h[6] \n" - "smlal2 v19.4s, v5.8h, v1.h[7] \n" - - "smlal v8.4s, v6.4h, v2.h[0] \n" - "smlal v9.4s, v6.4h, v2.h[1] \n" - "smlal v10.4s, v6.4h, v2.h[2] \n" - "smlal v11.4s, v6.4h, v2.h[3] \n" - "smlal v12.4s, v6.4h, v2.h[4] \n" - "smlal v13.4s, v6.4h, v2.h[5] \n" - "smlal v14.4s, v6.4h, v2.h[6] \n" - "smlal v15.4s, v6.4h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r89 - - "smlal v16.4s, v6.4h, v3.h[0] \n" - "smlal v17.4s, v6.4h, v3.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v18.4s, v6.4h, v3.h[2] \n" - "smlal v19.4s, v6.4h, v3.h[3] \n" - - "smlal2 v8.4s, v6.8h, v3.h[4] \n" - "smlal2 v9.4s, v6.8h, v3.h[5] \n" - "smlal2 v10.4s, v6.8h, v3.h[6] \n" - "smlal2 v11.4s, v6.8h, v3.h[7] \n" - "smlal2 v12.4s, v6.8h, v0.h[0] \n" - "smlal2 v13.4s, v6.8h, v0.h[1] \n" - "smlal2 v14.4s, v6.8h, v0.h[2] \n" - "smlal2 v15.4s, v6.8h, 
v0.h[3] \n" - "smlal2 v16.4s, v6.8h, v0.h[4] \n" - "smlal2 v17.4s, v6.8h, v0.h[5] \n" - "smlal2 v18.4s, v6.8h, v0.h[6] \n" - "smlal2 v19.4s, v6.8h, v0.h[7] \n" - - "ld1 {v2.8h, v3.8h}, [%2], #32 \n" // r1011 - - "smlal v8.4s, v7.4h, v1.h[0] \n" - "smlal v9.4s, v7.4h, v1.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal v10.4s, v7.4h, v1.h[2] \n" - "smlal v11.4s, v7.4h, v1.h[3] \n" - "smlal v12.4s, v7.4h, v1.h[4] \n" - "smlal v13.4s, v7.4h, v1.h[5] \n" - "smlal v14.4s, v7.4h, v1.h[6] \n" - "smlal v15.4s, v7.4h, v1.h[7] \n" - "smlal v16.4s, v7.4h, v2.h[0] \n" - "smlal v17.4s, v7.4h, v2.h[1] \n" - "smlal v18.4s, v7.4h, v2.h[2] \n" - "smlal v19.4s, v7.4h, v2.h[3] \n" - - "ld1 {v4.8h, v5.8h}, [%3], #32 \n" // w01 - - "smlal2 v8.4s, v7.8h, v2.h[4] \n" - "smlal2 v9.4s, v7.8h, v2.h[5] \n" - - "prfm pldl1keep, [%3, #256] \n" - - "smlal2 v10.4s, v7.8h, v2.h[6] \n" - "smlal2 v11.4s, v7.8h, v2.h[7] \n" - - "ld1 {v0.8h, v1.8h}, [%2], #32 \n" // r01 - - "smlal2 v12.4s, v7.8h, v3.h[0] \n" - "smlal2 v13.4s, v7.8h, v3.h[1] \n" - - "prfm pldl1keep, [%2, #256] \n" - - "smlal2 v14.4s, v7.8h, v3.h[2] \n" - "smlal2 v15.4s, v7.8h, v3.h[3] \n" - "smlal2 v16.4s, v7.8h, v3.h[4] \n" - "smlal2 v17.4s, v7.8h, v3.h[5] \n" - - "subs %w0, %w0, #1 \n" - - "smlal2 v18.4s, v7.8h, v3.h[6] \n" - "smlal2 v19.4s, v7.8h, v3.h[7] \n" - - "bne 0b \n" - - "sub %2, %2, #32 \n" - "sub %3, %3, #32 \n" - - "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" - "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n" - "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n" - - : "=r"(nn), // %0 - "=r"(output0_tm), // %1 - "=r"(r0), // %2 - "=r"(k0) // %3 - : "0"(nn), - "1"(output0_tm), - "2"(r0), - "3"(k0) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"); - } - for (; i + 7 < tiles; i += 8) - { - const short* r0 = bb2.row(i / 12 + (i % 12) / 8); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - int16x8_t _val4 = vld1q_s16(r0 + 32); - int16x8_t _val5 = vld1q_s16(r0 + 40); - int16x8_t _val6 = vld1q_s16(r0 + 48); - int16x8_t _val7 = vld1q_s16(r0 + 56); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w0), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val0), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w0), vget_low_s16(_val0), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_high_s16(_val0), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w0), vget_high_s16(_val0), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_high_s16(_val0), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w0), vget_high_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w0), vget_low_s16(_val1), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val1), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w0), vget_low_s16(_val1), 2); - _sum3 = 
vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w0), vget_high_s16(_val1), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_high_s16(_val1), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w0), vget_high_s16(_val1), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_high_s16(_val1), 3); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val2), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w1), vget_low_s16(_val2), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val2), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w1), vget_low_s16(_val2), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_high_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w1), vget_high_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_high_s16(_val2), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w1), vget_high_s16(_val2), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w1), vget_low_s16(_val3), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val3), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w1), vget_low_s16(_val3), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val3), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w1), vget_high_s16(_val3), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_high_s16(_val3), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w1), vget_high_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_high_s16(_val3), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_low_s16(_val4), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w2), vget_low_s16(_val4), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_low_s16(_val4), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w2), vget_low_s16(_val4), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_high_s16(_val4), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w2), vget_high_s16(_val4), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_high_s16(_val4), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_low_s16(_w2), vget_high_s16(_val4), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w2), vget_low_s16(_val5), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_low_s16(_val5), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w2), vget_low_s16(_val5), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_low_s16(_val5), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w2), vget_high_s16(_val5), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w2), vget_high_s16(_val5), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w2), vget_high_s16(_val5), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_high_s16(_val5), 3); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_low_s16(_val6), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_low_s16(_w3), vget_low_s16(_val6), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_low_s16(_val6), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_low_s16(_w3), vget_low_s16(_val6), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_high_s16(_val6), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_low_s16(_w3), vget_high_s16(_val6), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_high_s16(_val6), 2); - _sum7 = vmlal_lane_s16(_sum7, 
vget_low_s16(_w3), vget_high_s16(_val6), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_high_s16(_w3), vget_low_s16(_val7), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_low_s16(_val7), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_high_s16(_w3), vget_low_s16(_val7), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_low_s16(_val7), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_high_s16(_w3), vget_high_s16(_val7), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_high_s16(_val7), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_high_s16(_w3), vget_high_s16(_val7), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_high_s16(_val7), 3); - - r0 += 64; - k0 += 32; - } - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum1); - vst1q_s32(output0_tm + 8, _sum2); - vst1q_s32(output0_tm + 12, _sum3); - vst1q_s32(output0_tm + 16, _sum4); - vst1q_s32(output0_tm + 20, _sum5); - vst1q_s32(output0_tm + 24, _sum6); - vst1q_s32(output0_tm + 28, _sum7); - output0_tm += 32; - } -#endif // __aarch64__ - for (; i + 3 < tiles; i += 4) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); -#else - const short* r0 = bb2.row(i / 4); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __aarch64__ - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - int32x4_t _sum4 = vdupq_n_s32(0); - int32x4_t _sum5 = vdupq_n_s32(0); - int32x4_t _sum6 = vdupq_n_s32(0); - int32x4_t _sum7 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - int16x8_t _val2 = vld1q_s16(r0 + 16); - int16x8_t _val3 = vld1q_s16(r0 + 24); - - int16x8_t _w0 = vld1q_s16(k0); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w0), vget_low_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w0), vget_low_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w0), vget_low_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w0), vget_low_s16(_val3), 1); - - int16x8_t _w1 = vld1q_s16(k0 + 8); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w1), vget_low_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w1), vget_low_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w1), vget_low_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w1), vget_low_s16(_val3), 3); - - int16x8_t _w2 = vld1q_s16(k0 + 16); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_high_s16(_val1), 1); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w2), vget_high_s16(_val2), 0); - _sum5 = vmlal_lane_s16(_sum5, 
vget_high_s16(_w2), vget_high_s16(_val2), 1); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w2), vget_high_s16(_val3), 0); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w2), vget_high_s16(_val3), 1); - - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_high_s16(_val1), 3); - _sum4 = vmlal_lane_s16(_sum4, vget_low_s16(_w3), vget_high_s16(_val2), 2); - _sum5 = vmlal_lane_s16(_sum5, vget_high_s16(_w3), vget_high_s16(_val2), 3); - _sum6 = vmlal_lane_s16(_sum6, vget_low_s16(_w3), vget_high_s16(_val3), 2); - _sum7 = vmlal_lane_s16(_sum7, vget_high_s16(_w3), vget_high_s16(_val3), 3); - - r0 += 32; - k0 += 32; - } - - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - _sum4 = vaddq_s32(_sum4, _sum5); - _sum6 = vaddq_s32(_sum6, _sum7); - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum2); - vst1q_s32(output0_tm + 8, _sum4); - vst1q_s32(output0_tm + 12, _sum6); - output0_tm += 16; -#else - asm volatile( - "veor q8, q8 \n" - "veor q9, q9 \n" - "veor q10, q10 \n" - "veor q11, q11 \n" - "veor q12, q12 \n" - "veor q13, q13 \n" - "veor q14, q14 \n" - "veor q15, q15 \n" - - "0: \n" - - "pld [%2, #256] \n" - "pld [%2, #512] \n" - "vldm %2!, {d0-d7} \n" - - "pld [%3, #256] \n" - "vld1.s16 {d8-d11}, [%3 :128]! \n" - - "vmlal.s16 q8, d8, d0[0] \n" - "vmlal.s16 q12, d9, d0[1] \n" - "vmlal.s16 q9, d8, d2[0] \n" - "vmlal.s16 q13, d9, d2[1] \n" - "vmlal.s16 q10, d8, d4[0] \n" - "vmlal.s16 q14, d9, d4[1] \n" - "vmlal.s16 q11, d8, d6[0] \n" - "vmlal.s16 q15, d9, d6[1] \n" - - "pld [%3, #128] \n" - "vld1.s16 {d8-d9}, [%3 :128]! \n" - - "vmlal.s16 q8, d10, d0[2] \n" - "vmlal.s16 q12, d11, d0[3] \n" - "vmlal.s16 q9, d10, d2[2] \n" - "vmlal.s16 q13, d11, d2[3] \n" - "vmlal.s16 q10, d10, d4[2] \n" - "vmlal.s16 q14, d11, d4[3] \n" - "vmlal.s16 q11, d10, d6[2] \n" - "vmlal.s16 q15, d11, d6[3] \n" - - "pld [%3, #128] \n" - "vld1.s16 {d10-d11}, [%3 :128]! 
\n" - - "vmlal.s16 q8, d8, d1[0] \n" - "vmlal.s16 q12, d9, d1[1] \n" - "vmlal.s16 q9, d8, d3[0] \n" - "vmlal.s16 q13, d9, d3[1] \n" - "vmlal.s16 q10, d8, d5[0] \n" - "vmlal.s16 q14, d9, d5[1] \n" - "vmlal.s16 q11, d8, d7[0] \n" - "vmlal.s16 q15, d9, d7[1] \n" - - "subs %0, %0, #1 \n" - - "vmlal.s16 q8, d10, d1[2] \n" - "vmlal.s16 q12, d11, d1[3] \n" - "vmlal.s16 q9, d10, d3[2] \n" - "vmlal.s16 q13, d11, d3[3] \n" - "vmlal.s16 q10, d10, d5[2] \n" - "vmlal.s16 q14, d11, d5[3] \n" - "vmlal.s16 q11, d10, d7[2] \n" - "vmlal.s16 q15, d11, d7[3] \n" - - "bne 0b \n" - - "vadd.s32 q8, q8, q12 \n" - "vadd.s32 q9, q9, q13 \n" - "vadd.s32 q10, q10, q14 \n" - "vadd.s32 q11, q11, q15 \n" - - "vstm %1!, {d16-d23} \n" - - : "=r"(nn), - "=r"(output0_tm), - "=r"(r0), - "=r"(k0) - : "0"(nn), - "1"(output0_tm), - "2"(r0), - "3"(k0) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); -#endif - } - for (; i + 1 < tiles; i += 2) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - int32x4_t _sum2 = vdupq_n_s32(0); - int32x4_t _sum3 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - int16x8_t _val1 = vld1q_s16(r0 + 8); - - int16x8_t _w0 = vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w0), vget_low_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w0), vget_low_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w1), vget_low_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w1), vget_low_s16(_val1), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w2), vget_high_s16(_val1), 0); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w2), vget_high_s16(_val1), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); - _sum2 = vmlal_lane_s16(_sum2, vget_low_s16(_w3), vget_high_s16(_val1), 2); - _sum3 = vmlal_lane_s16(_sum3, vget_high_s16(_w3), vget_high_s16(_val1), 3); - - r0 += 16; - k0 += 32; - } - - _sum0 = vaddq_s32(_sum0, _sum1); - _sum2 = vaddq_s32(_sum2, _sum3); - - vst1q_s32(output0_tm, _sum0); - vst1q_s32(output0_tm + 4, _sum2); - output0_tm += 8; - } - for (; i < tiles; i++) - { -#if __aarch64__ - const short* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); -#else - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - int32x4_t _sum0 = vdupq_n_s32(0); - int32x4_t _sum1 = vdupq_n_s32(0); - - for (int j = 0; j < nn; j++) - { - int16x8_t _val0 = vld1q_s16(r0); - - int16x8_t _w0 = 
vld1q_s16(k0); - int16x8_t _w1 = vld1q_s16(k0 + 8); - int16x8_t _w2 = vld1q_s16(k0 + 16); - int16x8_t _w3 = vld1q_s16(k0 + 24); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w0), vget_low_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w0), vget_low_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w1), vget_low_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w1), vget_low_s16(_val0), 3); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w2), vget_high_s16(_val0), 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w2), vget_high_s16(_val0), 1); - - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_w3), vget_high_s16(_val0), 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_w3), vget_high_s16(_val0), 3); - - r0 += 8; - k0 += 32; - } - - _sum0 = vaddq_s32(_sum0, _sum1); - - vst1q_s32(output0_tm, _sum0); - output0_tm += 4; - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_transform_int8.h b/src/layer/arm/convolution_winograd_transform_int8.h deleted file mode 100644 index 4e27e8c6287e..000000000000 --- a/src/layer/arm/convolution_winograd_transform_int8.h +++ /dev/null @@ -1,230 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
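Note on the file removed below (convolution_winograd_transform_int8.h): the commented itm[6][6] table is the transposed input-transform matrix of Winograd F(4x4, 3x3). Assuming the usual Winograd notation, where a 6x6 input tile d is transformed as V = B^T d B, its rows reproduce the six per-row formulas spelled out in the code comments:

    B^T =
    \begin{bmatrix}
     4 &  0 & -5 &  0 & 1 & 0 \\
     0 & -4 & -4 &  1 & 1 & 0 \\
     0 &  4 & -4 & -1 & 1 & 0 \\
     0 & -2 & -1 &  2 & 1 & 0 \\
     0 &  2 & -1 & -2 & 1 & 0 \\
     0 &  4 &  0 & -5 & 0 & 1
    \end{bmatrix},
    \qquad V = B^T d B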
- -static void conv3x3s1_winograd43_transform_input_int8_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - - const int w_tiles = (w - 2) / 4; - const int h_tiles = (h - 2) / 4; - const int tiles = w_tiles * h_tiles; - - // const float itm[6][6] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4); - - for (int m = 0; m < 6; m++) - { - signed char r00 = r0[0]; - signed char r01 = r0[1]; - signed char r02 = r0[2]; - signed char r03 = r0[3]; - signed char r04 = r0[4]; - signed char r05 = r0[5]; - - short tmp0m = 4 * r00 - 5 * r02 + r04; - short tmp1m = -4 * (r01 + r02) + r04 + r03; - short tmp2m = 4 * (r01 - r02) + r04 - r03; - short tmp3m = -2 * (r01 - r03) + r04 - r02; - short tmp4m = 2 * (r01 - r03) + r04 - r02; - short tmp5m = 4 * r01 - 5 * r03 + r05; - - tmp[0][m] = tmp0m; - tmp[1][m] = tmp1m; - tmp[2][m] = tmp2m; - tmp[3][m] = tmp3m; - tmp[4][m] = tmp4m; - tmp[5][m] = tmp5m; - - r0 += w; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tiles + j); - short* r0_tm_1 = r0_tm_0 + tiles; - short* r0_tm_2 = r0_tm_0 + tiles * 2; - short* r0_tm_3 = r0_tm_0 + tiles * 3; - short* r0_tm_4 = r0_tm_0 + tiles * 4; - short* r0_tm_5 = r0_tm_0 + tiles * 5; - - for (int m = 0; m < 6; m++) - { - short tmp00 = tmp[m][0]; - short tmp01 = tmp[m][1]; - short tmp02 = tmp[m][2]; - short tmp03 = tmp[m][3]; - short tmp04 = tmp[m][4]; - short tmp05 = tmp[m][5]; - - short r0tm0 = 4 * tmp00 - 5 * tmp02 + tmp04; - short r0tm1 = -4 * (tmp01 + tmp02) + tmp04 + tmp03; - short r0tm2 = 4 * (tmp01 - tmp02) + tmp04 - tmp03; - short r0tm3 = -2 * (tmp01 - tmp03) + tmp04 - tmp02; - short r0tm4 = 2 * (tmp01 - tmp03) + tmp04 - tmp02; - short r0tm5 = 4 * tmp01 - 5 * tmp03 + tmp05; - - r0_tm_0[0] = r0tm0; - r0_tm_1[0] = r0tm1; - r0_tm_2[0] = r0tm2; - r0_tm_3[0] = r0tm3; - r0_tm_4[0] = r0tm4; - r0_tm_5[0] = r0tm5; - - r0_tm_0 += tiles * 6; - r0_tm_1 += tiles * 6; - r0_tm_2 += tiles * 6; - r0_tm_3 += tiles * 6; - r0_tm_4 += tiles * 6; - r0_tm_5 += tiles * 6; - } - } - } - } -} - -static void conv3x3s1_winograd43_transform_output_int8_neon(const Mat& top_blob_tm, Mat& top_blob, const Option& opt) -{ - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - - const int w_tiles = outw / 4; - const int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 
- r02) + (r03 - r04) * 8 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob.channel(p); - - int tmp[4][6]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tiles + j) * 1; - const int* output0_tm_1 = output0_tm_0 + tiles * 1; - const int* output0_tm_2 = output0_tm_0 + tiles * 2; - const int* output0_tm_3 = output0_tm_0 + tiles * 3; - const int* output0_tm_4 = output0_tm_0 + tiles * 4; - const int* output0_tm_5 = output0_tm_0 + tiles * 5; - - int* output0 = out0.row(i * 4) + j * 4; - - // TODO neon optimize - for (int m = 0; m < 5; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = output0_tm_0[0] + tmp02a + tmp02b; - tmp[1][m] = tmp13a + tmp13b * 2; - tmp[2][m] = tmp02a + tmp02b * 4; - tmp[3][m] = output0_tm_5[0] * 4 + tmp13a + tmp13b * 8; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - for (int m = 5; m < 6; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = (output0_tm_0[0] + tmp02a + tmp02b) * 4; - tmp[1][m] = (tmp13a + tmp13b * 2) * 4; - tmp[2][m] = (tmp02a + tmp02b * 4) * 4; - tmp[3][m] = (output0_tm_5[0] * 4 + tmp13a + tmp13b * 8) * 4; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - - for (int m = 0; m < 4; m++) - { - const int* tmp0 = tmp[m]; - - int tmp02a = tmp0[1] + tmp0[2]; - int tmp13a = tmp0[1] - tmp0[2]; - - int tmp02b = tmp0[3] + tmp0[4]; - int tmp13b = tmp0[3] - tmp0[4]; - - output0[0] = (tmp0[0] + tmp02a + tmp02b) / 576; - output0[1] = (tmp13a + tmp13b * 2) / 576; - output0[2] = (tmp02a + tmp02b * 4) / 576; - output0[3] = (tmp0[5] + tmp13a + tmp13b * 8) / 576; - - output0 += outw; - } - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_transform_pack4_int8.h b/src/layer/arm/convolution_winograd_transform_pack4_int8.h deleted file mode 100644 index fff5f7d66506..000000000000 --- a/src/layer/arm/convolution_winograd_transform_pack4_int8.h +++ /dev/null @@ -1,178 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
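Likewise, the otm[4][6] table commented in the scalar output transform above and in the pack4 variant below is the transposed output-transform matrix A^T of F(4x4, 3x3). Assuming the same notation, the 4x4 result tile is Y = A^T M A, where M is the 6x6 tile of accumulated dot products; the division by 576 that follows in the code undoes the integer scaling baked into the kernel and input transforms:

    A^T =
    \begin{bmatrix}
     1 & 1 &  1 & 1 &  1 & 0 \\
     0 & 1 & -1 & 2 & -2 & 0 \\
     0 & 1 &  1 & 4 &  4 & 0 \\
     0 & 1 & -1 & 8 & -8 & 1
    \end{bmatrix},
    \qquad Y = A^T M A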
- -static void conv3x3s1_winograd43_transform_output_pack4_int8_neon(const Mat& top_blob_tm, Mat& top_blob, const Option& opt) -{ - const int outw = top_blob.w; - const int outh = top_blob.h; - const int outch = top_blob.c; - - const int w_tiles = outw / 4; - const int h_tiles = outh / 4; - const int tiles = w_tiles * h_tiles; - - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob.channel(p); - - int tmp[4][6][4]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tiles + j) * 4; - const int* output0_tm_1 = output0_tm_0 + tiles * 4; - const int* output0_tm_2 = output0_tm_0 + tiles * 8; - const int* output0_tm_3 = output0_tm_0 + tiles * 12; - const int* output0_tm_4 = output0_tm_0 + tiles * 16; - const int* output0_tm_5 = output0_tm_0 + tiles * 20; - - int* output0 = out0.row(i * 4) + (j * 4) * 4; - - for (int m = 0; m < 5; m++) - { - int32x4_t _out0tm0 = vld1q_s32(output0_tm_0); - int32x4_t _out0tm1 = vld1q_s32(output0_tm_1); - int32x4_t _out0tm2 = vld1q_s32(output0_tm_2); - int32x4_t _out0tm3 = vld1q_s32(output0_tm_3); - int32x4_t _out0tm4 = vld1q_s32(output0_tm_4); - int32x4_t _out0tm5 = vld1q_s32(output0_tm_5); - - int32x4_t _tmp02a = vaddq_s32(_out0tm1, _out0tm2); - int32x4_t _tmp13a = vsubq_s32(_out0tm1, _out0tm2); - - int32x4_t _tmp02b = vaddq_s32(_out0tm3, _out0tm4); - int32x4_t _tmp13b = vsubq_s32(_out0tm3, _out0tm4); - - int32x4_t _v2 = vdupq_n_s32(2); - int32x4_t _v4 = vdupq_n_s32(4); - int32x4_t _v8 = vdupq_n_s32(8); - - int32x4_t _tmp0m = vaddq_s32(vaddq_s32(_out0tm0, _tmp02a), _tmp02b); - int32x4_t _tmp1m = vmlaq_s32(_tmp13a, _tmp13b, _v2); - int32x4_t _tmp2m = vmlaq_s32(_tmp02a, _tmp02b, _v4); - int32x4_t _tmp3m = vmlaq_s32(vmlaq_s32(_tmp13a, _out0tm5, _v4), _tmp13b, _v8); - - vst1q_s32(tmp[0][m], _tmp0m); - vst1q_s32(tmp[1][m], _tmp1m); - vst1q_s32(tmp[2][m], _tmp2m); - vst1q_s32(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - for (int m = 5; m < 6; m++) - { - int32x4_t _out0tm0 = vld1q_s32(output0_tm_0); - int32x4_t _out0tm1 = vld1q_s32(output0_tm_1); - int32x4_t _out0tm2 = vld1q_s32(output0_tm_2); - int32x4_t _out0tm3 = vld1q_s32(output0_tm_3); - int32x4_t _out0tm4 = vld1q_s32(output0_tm_4); - int32x4_t _out0tm5 = vld1q_s32(output0_tm_5); - - int32x4_t _tmp02a = vaddq_s32(_out0tm1, _out0tm2); - int32x4_t _tmp13a = vsubq_s32(_out0tm1, _out0tm2); - - int32x4_t _tmp02b = vaddq_s32(_out0tm3, _out0tm4); - int32x4_t _tmp13b = vsubq_s32(_out0tm3, _out0tm4); - - int32x4_t _v2 = vdupq_n_s32(2); - int32x4_t _v4 = vdupq_n_s32(4); - int32x4_t _v8 = vdupq_n_s32(8); - - int32x4_t _tmp0m = vaddq_s32(vaddq_s32(_out0tm0, _tmp02a), _tmp02b); - int32x4_t _tmp1m = vmlaq_s32(_tmp13a, _tmp13b, _v2); - int32x4_t _tmp2m = vmlaq_s32(_tmp02a, _tmp02b, _v4); - int32x4_t _tmp3m = vmlaq_s32(vmlaq_s32(_tmp13a, _out0tm5, _v4), _tmp13b, _v8); - - _tmp0m = 
vmulq_s32(_tmp0m, _v4); - _tmp1m = vmulq_s32(_tmp1m, _v4); - _tmp2m = vmulq_s32(_tmp2m, _v4); - _tmp3m = vmulq_s32(_tmp3m, _v4); - - vst1q_s32(tmp[0][m], _tmp0m); - vst1q_s32(tmp[1][m], _tmp1m); - vst1q_s32(tmp[2][m], _tmp2m); - vst1q_s32(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - - for (int m = 0; m < 4; m++) - { - int32x4_t _tmp00 = vld1q_s32(tmp[m][0]); - int32x4_t _tmp01 = vld1q_s32(tmp[m][1]); - int32x4_t _tmp02 = vld1q_s32(tmp[m][2]); - int32x4_t _tmp03 = vld1q_s32(tmp[m][3]); - int32x4_t _tmp04 = vld1q_s32(tmp[m][4]); - int32x4_t _tmp05 = vld1q_s32(tmp[m][5]); - - int32x4_t _tmp02a = vaddq_s32(_tmp01, _tmp02); - int32x4_t _tmp13a = vsubq_s32(_tmp01, _tmp02); - - int32x4_t _tmp02b = vaddq_s32(_tmp03, _tmp04); - int32x4_t _tmp13b = vsubq_s32(_tmp03, _tmp04); - - int32x4_t _v2 = vdupq_n_s32(2); - int32x4_t _v4 = vdupq_n_s32(4); - int32x4_t _v8 = vdupq_n_s32(8); - - int32x4_t _out00 = vaddq_s32(vaddq_s32(_tmp00, _tmp02a), _tmp02b); - int32x4_t _out01 = vmlaq_s32(_tmp13a, _tmp13b, _v2); - int32x4_t _out02 = vmlaq_s32(_tmp02a, _tmp02b, _v4); - int32x4_t _out03 = vmlaq_s32(vaddq_s32(_tmp05, _tmp13a), _tmp13b, _v8); - - // TODO use integer trick for division by 576 - float32x4_t _v576 = vdupq_n_f32(1.0 / 576); - _out00 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out00), _v576)); - _out01 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out01), _v576)); - _out02 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out02), _v576)); - _out03 = vcvtq_s32_f32(vmulq_f32(vcvtq_f32_s32(_out03), _v576)); - - vst1q_s32(output0, _out00); - vst1q_s32(output0 + 4, _out01); - vst1q_s32(output0 + 8, _out02); - vst1q_s32(output0 + 12, _out03); - - output0 += outw * 4; - } - } - } - } -} diff --git a/src/layer/arm/convolution_winograd_transform_pack8_int8.h b/src/layer/arm/convolution_winograd_transform_pack8_int8.h deleted file mode 100644 index f0d8981ef77d..000000000000 --- a/src/layer/arm/convolution_winograd_transform_pack8_int8.h +++ /dev/null @@ -1,131 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
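A back-of-the-envelope bound (not stated in the source) on why the transformed input fits in int16, which both the scalar transform above and the pack8 transform below rely on when they store the result as short: with int8 samples bounded in magnitude by 128 and the largest absolute row sum of B^T being 4 + 5 + 1 = 10, one 1-D pass is bounded by 10 * 128 = 1280, and the full 2-D transform by

    \max_{u,v} |V_{uv}| \le 10^2 \cdot 128 = 12800 < 2^{15} = 32768,

so no widening beyond 16 bits is needed before the dot-product stage.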
- -static void conv3x3s1_winograd43_transform_input_pack8_int8_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) -{ - const int w = bottom_blob.w; - const int h = bottom_blob.h; - const int inch = bottom_blob.c; - - const int w_tiles = (w - 2) / 4; - const int h_tiles = (h - 2) / 4; - const int tiles = w_tiles * h_tiles; - - // const float itm[6][6] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tiles; i++) - { - for (int j = 0; j < w_tiles; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - int8x8_t _r00 = vld1_s8(r0); - int8x8_t _r01 = vld1_s8(r0 + 8); - int8x8_t _r02 = vld1_s8(r0 + 16); - int8x8_t _r03 = vld1_s8(r0 + 24); - int8x8_t _r04 = vld1_s8(r0 + 32); - int8x8_t _r05 = vld1_s8(r0 + 40); - - int8x8_t _v4s8 = vdup_n_s8(4); - int8x8_t _v5s8 = vdup_n_s8(5); - int16x8_t _v2 = vdupq_n_s16(2); - int16x8_t _v4 = vdupq_n_s16(4); - - int16x8_t _tmp0m = vsubq_s16(vaddw_s8(vmull_s8(_r00, _v4s8), _r04), vmull_s8(_r02, _v5s8)); - int16x8_t _tmp1m = vmlsq_s16(vaddl_s8(_r04, _r03), vaddl_s8(_r01, _r02), _v4); - int16x8_t _tmp2m = vmlaq_s16(vsubl_s8(_r04, _r03), vsubl_s8(_r01, _r02), _v4); - int16x8_t _tmp3m = vmlsq_s16(vsubl_s8(_r04, _r02), vsubl_s8(_r01, _r03), _v2); - int16x8_t _tmp4m = vmlaq_s16(vsubl_s8(_r04, _r02), vsubl_s8(_r01, _r03), _v2); - int16x8_t _tmp5m = vsubq_s16(vaddw_s8(vmull_s8(_r01, _v4s8), _r05), vmull_s8(_r03, _v5s8)); - - vst1q_s16(tmp[0][m], _tmp0m); - vst1q_s16(tmp[1][m], _tmp1m); - vst1q_s16(tmp[2][m], _tmp2m); - vst1q_s16(tmp[3][m], _tmp3m); - vst1q_s16(tmp[4][m], _tmp4m); - vst1q_s16(tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tiles + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - int16x8_t _tmp00 = vld1q_s16(tmp[m][0]); - int16x8_t _tmp01 = vld1q_s16(tmp[m][1]); - int16x8_t _tmp02 = vld1q_s16(tmp[m][2]); - int16x8_t _tmp03 = vld1q_s16(tmp[m][3]); - int16x8_t _tmp04 = vld1q_s16(tmp[m][4]); - int16x8_t _tmp05 = vld1q_s16(tmp[m][5]); - - int16x8_t _v2 = vdupq_n_s16(2); - int16x8_t _v4 = vdupq_n_s16(4); - int16x8_t _v5 = vdupq_n_s16(5); - - int16x8_t _r0tm0 = vmlsq_s16(vmlaq_s16(_tmp04, _tmp00, _v4), _tmp02, _v5); - int16x8_t _r0tm1 = vmlsq_s16(vaddq_s16(_tmp04, _tmp03), vaddq_s16(_tmp01, _tmp02), _v4); - int16x8_t _r0tm2 = vmlaq_s16(vsubq_s16(_tmp04, _tmp03), vsubq_s16(_tmp01, _tmp02), _v4); - int16x8_t _r0tm3 = vmlsq_s16(vsubq_s16(_tmp04, _tmp02), vsubq_s16(_tmp01, _tmp03), _v2); - int16x8_t _r0tm4 = vmlaq_s16(vsubq_s16(_tmp04, _tmp02), vsubq_s16(_tmp01, _tmp03), _v2); - int16x8_t _r0tm5 = vmlsq_s16(vmlaq_s16(_tmp05, _tmp01, _v4), _tmp03, _v5); - - vst1q_s16(r0_tm_0, _r0tm0); - 
vst1q_s16(r0_tm_1, _r0tm1); - vst1q_s16(r0_tm_2, _r0tm2); - vst1q_s16(r0_tm_3, _r0tm3); - vst1q_s16(r0_tm_4, _r0tm4); - vst1q_s16(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } -} diff --git a/src/net.cpp b/src/net.cpp index aed2f20a48e0..f4e70e98ae08 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -610,67 +610,41 @@ int NetPrivate::forward_layer(int layer_index, std::vector& blob_mats, std: int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Option& opt) const { - // clang-format off - // *INDENT-OFF* -#if NCNN_ARM82 - if (opt.use_fp16_storage && cpu_support_arm_asimdhp()) + if (bottom_blob.elembits() == 32) { - if (bottom_blob.elembits() == 32 && layer->support_fp16_storage) + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && layer->support_fp16_storage) { Mat bottom_blob_fp16; cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt); bottom_blob = bottom_blob_fp16; } - if (bottom_blob.elembits() == 16 && !layer->support_fp16_storage) - { - Mat bottom_blob_fp32; - cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); - bottom_blob = bottom_blob_fp32; - } - } - else + else #endif // NCNN_ARM82 #if NCNN_RVV - if (opt.use_fp16_storage && cpu_support_riscv_v() && cpu_support_riscv_zfh()) - { - if (bottom_blob.elembits() == 32 && layer->support_fp16_storage) + if (opt.use_fp16_storage && cpu_support_riscv_v() && cpu_support_riscv_zfh() && layer->support_fp16_storage) { Mat bottom_blob_fp16; cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt); bottom_blob = bottom_blob_fp16; } - if (bottom_blob.elembits() == 16 && !layer->support_fp16_storage) - { - Mat bottom_blob_fp32; - cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); - bottom_blob = bottom_blob_fp32; - } - } - else + else #endif // NCNN_RVV #if NCNN_BF16 - if (opt.use_bf16_storage) - { - if (bottom_blob.elembits() == 32 && layer->support_bf16_storage) + if (opt.use_bf16_storage && layer->support_bf16_storage) { Mat bottom_blob_bf16; cast_float32_to_bfloat16(bottom_blob, bottom_blob_bf16, opt); bottom_blob = bottom_blob_bf16; } - if (bottom_blob.elembits() == 16 && !layer->support_bf16_storage) - { - Mat bottom_blob_fp32; - cast_bfloat16_to_float32(bottom_blob, bottom_blob_fp32, opt); - bottom_blob = bottom_blob_fp32; - } - } - else #endif // NCNN_BF16 - { - // no type conversion + + // *INDENT-ON* + // clang-format on } - // *INDENT-ON* - // clang-format on int dst_elempack = 1; if (opt.use_packing_layout) @@ -746,6 +720,42 @@ int NetPrivate::convert_layout(Mat& bottom_blob, const Layer* layer, const Optio bottom_blob = bottom_blob_packed; } + if (bottom_blob.elembits() == 16) + { + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && !layer->support_fp16_storage) + { + Mat bottom_blob_fp32; + cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); + bottom_blob = bottom_blob_fp32; + } + else +#endif // NCNN_ARM82 +#if NCNN_RVV + if (opt.use_fp16_storage && cpu_support_riscv_v() && cpu_support_riscv_zfh() && !layer->support_fp16_storage) + { + Mat bottom_blob_fp32; + cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt); + bottom_blob = bottom_blob_fp32; + } + else +#endif // NCNN_RVV +#if NCNN_BF16 + if (opt.use_bf16_storage && !layer->support_bf16_storage) + { + Mat bottom_blob_fp32; + 
cast_bfloat16_to_float32(bottom_blob, bottom_blob_fp32, opt); + bottom_blob = bottom_blob_fp32; + } +#endif // NCNN_BF16 + + // *INDENT-ON* + // clang-format on + } + return 0; }
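Taken together, the net.cpp hunks split the storage-type conversion in two around packing: an fp32 blob is down-cast to fp16/bf16 only when the target layer supports that storage, the elempack conversion runs in between (unchanged), and a separate pass afterwards up-casts any 16-bit blob that is about to enter a layer without 16-bit support. A condensed sketch of the resulting flow, with the NCNN_ARM82 / NCNN_RVV / NCNN_BF16 preprocessor guards and the RISC-V branch elided for brevity (a paraphrase of the patched convert_layout(), not a drop-in replacement):

    if (bottom_blob.elembits() == 32)
    {
        if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && layer->support_fp16_storage)
        {
            Mat bottom_blob_fp16;
            cast_float32_to_float16(bottom_blob, bottom_blob_fp16, opt);
            bottom_blob = bottom_blob_fp16;
        }
        else if (opt.use_bf16_storage && layer->support_bf16_storage)
        {
            Mat bottom_blob_bf16;
            cast_float32_to_bfloat16(bottom_blob, bottom_blob_bf16, opt);
            bottom_blob = bottom_blob_bf16;
        }
    }

    // ... elempack conversion (convert_packing) runs here, as before ...

    if (bottom_blob.elembits() == 16)
    {
        if (opt.use_fp16_storage && cpu_support_arm_asimdhp() && !layer->support_fp16_storage)
        {
            Mat bottom_blob_fp32;
            cast_float16_to_float32(bottom_blob, bottom_blob_fp32, opt);
            bottom_blob = bottom_blob_fp32;
        }
        else if (opt.use_bf16_storage && !layer->support_bf16_storage)
        {
            Mat bottom_blob_fp32;
            cast_bfloat16_to_float32(bottom_blob, bottom_blob_fp32, opt);
            bottom_blob = bottom_blob_fp32;
        }
    }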