From f38fc95e0c4388b8a040f6d381ff393553a1551a Mon Sep 17 00:00:00 2001 From: Baozhu Zuo Date: Thu, 28 Sep 2023 14:45:40 +0800 Subject: [PATCH 01/16] benchmark: add raspberry pi 5 8G benchmark (#5058) --- benchmark/README.md | 85 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/benchmark/README.md b/benchmark/README.md index 83347dec97bc..93feb65041a4 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1631,7 +1631,92 @@ cooling_down = 1 vision_transformer min = 6605.19 max = 6606.66 avg = 6605.73 FastestDet min = 52.11 max = 52.97 avg = 52.61 ``` +### Raspberry Pi 5 Broadcom BCM2712, Cortex-A76 (ARMv8) (2.4GHz x 4) +``` +pi@raspberrypi:~/ncnn/benchmark $ ./benchncnn 10 4 0 -1 1 +loop_count = 10 +num_threads = 4 +powersave = 0 +gpu_device = -1 +cooling_down = 1 + squeezenet min = 8.56 max = 8.65 avg = 8.61 + squeezenet_int8 min = 11.65 max = 12.64 avg = 11.94 + mobilenet min = 11.32 max = 13.46 avg = 11.75 + mobilenet_int8 min = 11.30 max = 11.60 avg = 11.45 + mobilenet_v2 min = 13.57 max = 13.77 avg = 13.63 + mobilenet_v3 min = 9.18 max = 10.52 avg = 9.48 + shufflenet min = 4.56 max = 6.19 avg = 5.98 + shufflenet_v2 min = 5.04 max = 5.13 avg = 5.09 + mnasnet min = 8.27 max = 9.86 avg = 8.65 + proxylessnasnet min = 9.36 max = 11.18 avg = 9.62 + efficientnet_b0 min = 14.77 max = 14.96 avg = 14.87 + efficientnetv2_b0 min = 19.91 max = 20.11 avg = 19.99 + regnety_400m min = 11.91 max = 12.10 avg = 11.96 + blazeface min = 2.26 max = 2.29 avg = 2.28 + googlenet min = 32.80 max = 33.17 avg = 32.97 + googlenet_int8 min = 32.63 max = 32.99 avg = 32.78 + resnet18 min = 23.95 max = 24.21 avg = 24.12 + resnet18_int8 min = 32.50 max = 32.79 avg = 32.68 + alexnet min = 25.31 max = 25.75 avg = 25.51 + vgg16 min = 162.19 max = 165.08 avg = 163.75 + vgg16_int8 min = 187.46 max = 191.21 avg = 189.09 + resnet50 min = 55.95 max = 56.61 avg = 56.29 + resnet50_int8 min = 73.34 max = 73.97 avg = 73.59 + squeezenet_ssd min = 40.48 max = 41.39 avg = 40.92 + squeezenet_ssd_int8 min = 45.67 max = 46.35 avg = 46.06 + mobilenet_ssd min = 31.15 max = 31.73 avg = 31.48 + mobilenet_ssd_int8 min = 31.09 max = 31.44 avg = 31.27 + mobilenet_yolo min = 71.51 max = 72.38 avg = 71.95 + mobilenetv2_yolov3 min = 47.86 max = 48.41 avg = 48.04 + yolov4-tiny min = 55.95 max = 56.51 avg = 56.19 + nanodet_m min = 14.26 max = 14.68 avg = 14.48 + yolo-fastest-1.1 min = 6.48 max = 8.10 avg = 7.30 + yolo-fastestv2 min = 6.03 max = 7.33 avg = 7.04 + vision_transformer min = 613.62 max = 637.97 avg = 629.51 + FastestDet min = 6.53 max = 6.66 avg = 6.59 +pi@raspberrypi:~/ncnn/benchmark $ ./benchncnn 10 1 0 -1 1 +loop_count = 10 +num_threads = 1 +powersave = 0 +gpu_device = -1 +cooling_down = 1 + squeezenet min = 13.18 max = 13.27 avg = 13.22 + squeezenet_int8 min = 15.69 max = 15.93 avg = 15.78 + mobilenet min = 21.42 max = 21.55 avg = 21.46 + mobilenet_int8 min = 14.92 max = 20.91 avg = 17.34 + mobilenet_v2 min = 18.56 max = 23.06 avg = 19.24 + mobilenet_v3 min = 13.16 max = 13.33 avg = 13.25 + shufflenet min = 7.25 max = 11.14 avg = 8.43 + shufflenet_v2 min = 7.17 max = 11.15 avg = 7.70 + mnasnet min = 13.89 max = 13.94 avg = 13.91 + proxylessnasnet min = 17.01 max = 17.26 avg = 17.07 + efficientnet_b0 min = 26.19 max = 26.30 avg = 26.24 + efficientnetv2_b0 min = 39.69 max = 40.12 avg = 39.97 + regnety_400m min = 17.30 max = 17.44 avg = 17.36 + blazeface min = 4.74 max = 4.78 avg = 4.76 + googlenet min = 57.64 max = 57.84 avg = 57.72 + googlenet_int8 min = 55.80 max = 56.01 avg = 55.93 + resnet18 min = 31.90 max = 32.09 avg = 32.00 + resnet18_int8 min = 56.92 max = 57.16 avg = 57.01 + alexnet min = 39.84 max = 40.12 avg = 39.92 + vgg16 min = 208.33 max = 211.06 avg = 209.64 + vgg16_int8 min = 437.53 max = 440.55 avg = 439.35 + resnet50 min = 95.75 max = 96.68 avg = 96.28 + resnet50_int8 min = 116.80 max = 118.01 avg = 117.57 + squeezenet_ssd min = 47.75 max = 47.97 avg = 47.86 + squeezenet_ssd_int8 min = 61.98 max = 62.90 avg = 62.47 + mobilenet_ssd min = 52.83 max = 53.39 avg = 53.07 + mobilenet_ssd_int8 min = 46.15 max = 46.60 avg = 46.35 + mobilenet_yolo min = 117.68 max = 117.97 avg = 117.81 + mobilenetv2_yolov3 min = 67.37 max = 67.67 avg = 67.48 + yolov4-tiny min = 73.85 max = 74.35 avg = 74.10 + nanodet_m min = 22.78 max = 23.33 avg = 22.96 + yolo-fastest-1.1 min = 8.82 max = 8.91 avg = 8.87 + yolo-fastestv2 min = 8.18 max = 11.42 avg = 8.59 + vision_transformer min = 1267.90 max = 1269.45 avg = 1268.82 + FastestDet min = 7.79 max = 11.14 avg = 9.03 +``` ### Raspberry Pi Zero 2 W Broadcom BCM2710A1, Cortex-A53 (ARMv8) (1.0GHz x 4) ``` From 75ad1cc7490c38e47df892b9025320e2a2d175c9 Mon Sep 17 00:00:00 2001 From: daquexian Date: Mon, 2 Oct 2023 08:04:09 +0800 Subject: [PATCH 02/16] support tag in memorydata layer (#5061) Signed-off-by: daquexian --- src/layer/memorydata.cpp | 9 +++++---- src/layer/memorydata.h | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/layer/memorydata.cpp b/src/layer/memorydata.cpp index 6cd314d76b97..02a0e0be078c 100644 --- a/src/layer/memorydata.cpp +++ b/src/layer/memorydata.cpp @@ -28,6 +28,7 @@ int MemoryData::load_param(const ParamDict& pd) h = pd.get(1, 0); d = pd.get(11, 0); c = pd.get(2, 0); + load_type = pd.get(21, 1); return 0; } @@ -36,19 +37,19 @@ int MemoryData::load_model(const ModelBin& mb) { if (d != 0) { - data = mb.load(w, h, d, c, 1); + data = mb.load(w, h, d, c, load_type); } else if (c != 0) { - data = mb.load(w, h, c, 1); + data = mb.load(w, h, c, load_type); } else if (h != 0) { - data = mb.load(w, h, 1); + data = mb.load(w, h, load_type); } else if (w != 0) { - data = mb.load(w, 1); + data = mb.load(w, load_type); } else // 0 0 0 { diff --git a/src/layer/memorydata.h b/src/layer/memorydata.h index 4b2c697912f8..d5175ad0dd8a 100644 --- a/src/layer/memorydata.h +++ b/src/layer/memorydata.h @@ -35,6 +35,7 @@ class MemoryData : public Layer int h; int d; int c; + int load_type; Mat data; }; From 54a9a563e9913d3142782dac0e497f2be50075b5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 4 Oct 2023 22:45:20 +0800 Subject: [PATCH 03/16] Bump pypa/cibuildwheel from 2.15.0 to 2.16.2 (#5064) Bumps [pypa/cibuildwheel](https://github.com/pypa/cibuildwheel) from 2.15.0 to 2.16.2. - [Release notes](https://github.com/pypa/cibuildwheel/releases) - [Changelog](https://github.com/pypa/cibuildwheel/blob/main/docs/changelog.md) - [Commits](https://github.com/pypa/cibuildwheel/compare/v2.15.0...v2.16.2) --- updated-dependencies: - dependency-name: pypa/cibuildwheel dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release-python.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index 9d99ffdd630d..da75661b0b4c 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -68,7 +68,7 @@ jobs: brew uninstall --ignore-dependencies libomp - name: Build wheels - uses: pypa/cibuildwheel@v2.15.0 + uses: pypa/cibuildwheel@v2.16.2 env: CIBW_ARCHS_MACOS: ${{ matrix.arch }} CIBW_ARCHS_LINUX: ${{ matrix.arch }} @@ -124,7 +124,7 @@ jobs: platforms: all - name: Build wheels - uses: pypa/cibuildwheel@v2.15.0 + uses: pypa/cibuildwheel@v2.16.2 env: CIBW_ARCHS_LINUX: ${{ matrix.arch }} CIBW_BUILD: ${{ matrix.build }} From c8f92b9f38d6c269402dfedefdcd5a3e4db0962e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 11:42:30 +0800 Subject: [PATCH 04/16] Bump stefanzweifel/git-auto-commit-action from 4 to 5 (#5073) Bumps [stefanzweifel/git-auto-commit-action](https://github.com/stefanzweifel/git-auto-commit-action) from 4 to 5. - [Release notes](https://github.com/stefanzweifel/git-auto-commit-action/releases) - [Changelog](https://github.com/stefanzweifel/git-auto-commit-action/blob/master/CHANGELOG.md) - [Commits](https://github.com/stefanzweifel/git-auto-commit-action/compare/v4...v5) --- updated-dependencies: - dependency-name: stefanzweifel/git-auto-commit-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/code-format.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 9ffb978181d1..8051371d3e51 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -51,7 +51,7 @@ jobs: rm -rf $GITHUB_WORKSPACE/clang-format-install export PATH=~/bin:$PATH sh codeformat.sh - - uses: stefanzweifel/git-auto-commit-action@v4 + - uses: stefanzweifel/git-auto-commit-action@v5 with: commit_message: apply code-format changes From b4f8fa6d38c0fb776b4b8b083ccb1c2d8013e816 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=B5=E5=B0=8F=E5=87=A1?= <2672931+whyb@users.noreply.github.com> Date: Tue, 10 Oct 2023 15:56:42 +0800 Subject: [PATCH 05/16] Fixed _mm256_set_m128 is only availble on gcc8+. issue#5072 (#5075) --- src/layer/x86/shufflechannel_x86.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/x86/shufflechannel_x86.cpp b/src/layer/x86/shufflechannel_x86.cpp index f4326289b8b0..8afb22b2e2ed 100644 --- a/src/layer/x86/shufflechannel_x86.cpp +++ b/src/layer/x86/shufflechannel_x86.cpp @@ -343,9 +343,9 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt for (int i = 0; i < size; i++) { __m256 _p0 = _mm256_loadu_ps(ptr0); - // macro `_mm256_loadu2_m128` is declared in IntelĀ® Intrinsics Guide but somehow missed in - // __m256 _p1 = _mm256_loadu2_m128(ptr2, ptr1); - __m256 _p1 = _mm256_set_m128(_mm_loadu_ps(ptr2), _mm_loadu_ps(ptr1)); + + __m256 _p1 = _mm256_castps128_ps256(_mm_loadu_ps(ptr1)); + _p1 = _mm256_insertf128_ps(_p1, _mm_loadu_ps(ptr2), 1); __m256 _lo = _mm256_unpacklo_ps(_p0, _p1); __m256 _hi = _mm256_unpackhi_ps(_p0, _p1); From bedbe599ff6042dc15dd16fd32e41505fdb2a59d Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 10 Oct 2023 16:29:30 +0800 Subject: [PATCH 06/16] pnnx support torch-2.1 (#5074) --- .ci/pnnx.yml | 4 +++ tools/pnnx/CMakeLists.txt | 5 ++++ tools/pnnx/src/main.cpp | 5 ++++ .../F_scaled_dot_product_attention.cpp | 27 +++++++++++++++++++ 4 files changed, 41 insertions(+) diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index 267a0afa289b..3f116a4fa2e4 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -48,6 +48,10 @@ jobs: torchvision-version: 0.15.1 torchvision-cache-key: '0_15_1' + - torch-version: 2.1.0 + torchvision-version: 0.16.0 + torchvision-cache-key: '0_16_0' + runs-on: pool-name: docker container: diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index 3a08cbc249e8..0c8326fc942f 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -70,6 +70,11 @@ if(Torch_VERSION VERSION_LESS "1.8") message(FATAL_ERROR "pnnx only supports PyTorch >= 1.8") endif() +if(Torch_VERSION VERSION_GREATER_EQUAL "2.1") + # c++17 is required for using torch 2.1+ headers + set(CMAKE_CXX_STANDARD 17) +endif() + if(TorchVision_FOUND) message(STATUS "Building with TorchVision") add_definitions(-DPNNX_TORCHVISION) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index dc8ca72dc7e8..f745ef03473e 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -300,6 +300,11 @@ int main(int argc, char** argv) fprintf(stderr, "\n"); } +#ifdef PNNX_TORCHVISION + // call some vision api to register vision ops :P + (void)vision::cuda_version(); +#endif + for (auto m : customop_modules) { fprintf(stderr, "load custom module %s\n", m.c_str()); diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp index e7ca7bbf8243..8dcfafaf12b4 100644 --- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp +++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp @@ -42,4 +42,31 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10) +class F_scaled_dot_product_attention_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 query +pnnx.Input input_1 0 1 key +pnnx.Input input_2 0 1 value +pnnx.Input input_3 0 1 attn_mask +prim::Constant op_0 0 1 dropout_p value=%dropout_p +prim::Constant op_1 0 1 is_causal value=%is_causal +prim::Constant op_2 0 1 scale value=%scale +aten::scaled_dot_product_attention op_3 7 1 query key value attn_mask dropout_p is_causal scale out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.scaled_dot_product_attention"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10) + } // namespace pnnx From d1289fb12da11cc451ac67a47e24c91add5a1989 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E8=8F=9C=E8=90=9D=E5=8D=9C=E5=86=AC=E7=93=9C?= Date: Tue, 10 Oct 2023 16:36:55 +0800 Subject: [PATCH 07/16] benchmark: add RTX A3000 6G benchmark (#5070) --- benchmark/README.md | 55 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/benchmark/README.md b/benchmark/README.md index 93feb65041a4..ebe977bd46e0 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -5164,6 +5164,61 @@ cooling_down = 0 mobilenetv2_yolov3 min = 3.69 max = 5.14 avg = 3.91 ``` +### nVIDIA RTX A3000 of Notebook (6GB) +``` +cx@HP-ZBook-Fury-15-6-inch-G8-Mobile-Workstation-PC:~/ncnn/build/benchmark$ ./benchncnn 10 1 0 1 +[0 Intel(R) UHD Graphics (TGL GT1)] queueC=0[1] queueG=0[1] queueT=0[1] +[0 Intel(R) UHD Graphics (TGL GT1)] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 Intel(R) UHD Graphics (TGL GT1)] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 Intel(R) UHD Graphics (TGL GT1)] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[0 Intel(R) UHD Graphics (TGL GT1)] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +[1 NVIDIA RTX A3000 Laptop GPU] queueC=2[8] queueG=0[16] queueT=1[2] +[1 NVIDIA RTX A3000 Laptop GPU] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[1 NVIDIA RTX A3000 Laptop GPU] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[1 NVIDIA RTX A3000 Laptop GPU] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[1 NVIDIA RTX A3000 Laptop GPU] fp16-matrix-16_8_8/16_8_16/16_16_16=1/1/1 +loop_count = 10 +num_threads = 1 +powersave = 0 +gpu_device = 1 +cooling_down = 1 + squeezenet min = 1.49 max = 1.94 avg = 1.74 + squeezenet_int8 min = 6.13 max = 6.20 avg = 6.16 + mobilenet min = 4.05 max = 4.82 avg = 4.65 + mobilenet_int8 min = 10.24 max = 10.29 avg = 10.26 + mobilenet_v2 min = 0.98 max = 1.14 avg = 1.03 + mobilenet_v3 min = 1.74 max = 1.82 avg = 1.77 + shufflenet min = 1.43 max = 30.51 avg = 9.51 + shufflenet_v2 min = 3.43 max = 3.89 avg = 3.77 + mnasnet min = 6.50 max = 6.75 avg = 6.62 + proxylessnasnet min = 6.46 max = 7.28 avg = 7.00 + efficientnet_b0 min = 3.14 max = 15.11 avg = 7.29 + efficientnetv2_b0 min = 18.50 max = 20.13 avg = 19.17 + regnety_400m min = 2.16 max = 3.57 avg = 2.70 + blazeface min = 2.52 max = 2.76 avg = 2.65 + googlenet min = 2.67 max = 14.67 avg = 9.85 + googlenet_int8 min = 19.08 max = 19.40 avg = 19.19 + resnet18 min = 5.19 max = 9.44 avg = 8.48 + resnet18_int8 min = 16.57 max = 17.69 avg = 16.96 + alexnet min = 1.98 max = 3.24 avg = 2.23 + vgg16 min = 3.59 max = 12.34 avg = 10.99 + vgg16_int8 min = 110.63 max = 124.31 avg = 118.16 + resnet50 min = 3.01 max = 4.93 avg = 3.77 + resnet50_int8 min = 41.58 max = 44.80 avg = 43.24 + squeezenet_ssd min = 4.08 max = 4.70 avg = 4.32 + squeezenet_ssd_int8 min = 17.32 max = 17.92 avg = 17.46 + mobilenet_ssd min = 2.26 max = 8.23 avg = 5.57 + mobilenet_ssd_int8 min = 20.35 max = 21.89 avg = 20.76 + mobilenet_yolo min = 2.14 max = 16.94 avg = 6.44 + mobilenetv2_yolov3 min = 3.64 max = 5.09 avg = 4.02 + yolov4-tiny min = 10.94 max = 17.46 avg = 13.58 + nanodet_m min = 6.57 max = 13.91 avg = 9.82 + yolo-fastest-1.1 min = 5.40 max = 14.22 avg = 10.78 + yolo-fastestv2 min = 7.49 max = 9.43 avg = 7.99 + vision_transformer min = 76.04 max = 76.96 avg = 76.43 + FastestDet min = 6.31 max = 6.60 avg = 6.43 +``` + ### nVIDIA RTX2080 of Desktop ``` E:\projects\framework\ncnn\benchmark>benchncnn.exe 4096 1 0 0 0 From 7b024252460d53e9b12ef847db99c67a9fb7df30 Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 10 Oct 2023 21:29:32 +0800 Subject: [PATCH 08/16] x86 optimization for convolution int8 winograd unified elempack (#5054) --- src/layer/x86/convolution_3x3_int8.h | 827 --- src/layer/x86/convolution_3x3_pack8to1_int8.h | 1125 --- src/layer/x86/convolution_3x3_pack8to4_int8.h | 945 --- src/layer/x86/convolution_3x3_winograd_int8.h | 6407 +++++++++++++++++ src/layer/x86/convolution_im2col_gemm_int8.h | 2950 ++++++-- src/layer/x86/convolution_x86.cpp | 59 +- src/layer/x86/convolution_x86_avx2.cpp | 19 +- src/layer/x86/convolution_x86_avx512vnni.cpp | 21 +- src/layer/x86/convolution_x86_avxvnni.cpp | 21 +- src/layer/x86/convolution_x86_xop.cpp | 21 +- src/layer/x86/x86_usability.h | 196 +- tests/test_convolution_3.cpp | 25 + 12 files changed, 8972 insertions(+), 3644 deletions(-) delete mode 100644 src/layer/x86/convolution_3x3_pack8to1_int8.h delete mode 100644 src/layer/x86/convolution_3x3_pack8to4_int8.h create mode 100644 src/layer/x86/convolution_3x3_winograd_int8.h diff --git a/src/layer/x86/convolution_3x3_int8.h b/src/layer/x86/convolution_3x3_int8.h index a5c5dfe4e71d..ceaf75b92e1f 100644 --- a/src/layer/x86/convolution_3x3_int8.h +++ b/src/layer/x86/convolution_3x3_int8.h @@ -78,833 +78,6 @@ static void conv3x3s1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& } } -static void conv3x3s1_winograd23_transform_kernel_int8_sse(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - kernel_tm.create(4 * 4, inch, outch, (size_t)2u); - - // G - const short ktm[4][3] = { - {2, 0, 0}, - {1, 1, 1}, - {1, -1, 1}, - {0, 0, 2} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[4][3]; - for (int i = 0; i < 4; i++) - { - tmp[i][0] = (short)k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = (short)k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = (short)k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 4; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 4; i++) - { - kernel_tm0[j * 4 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } -} - -static void conv3x3s1_winograd23_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 2n+2, winograd F(2,3) - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 1) / 2 * 2; - outh = (outh + 1) / 2 * 2; - - w = outw + 2; - h = outh + 2; - Option opt_b = opt; - opt_b.blob_allocator = opt.workspace_allocator; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - const int tiles = nColBlocks * nRowBlocks; - - bottom_blob_tm.create(4 * 4, tiles, inch, 2u, opt.workspace_allocator); - - // BT - // const float itm[4][4] = { - // {1.0f, 0.0f, -1.0f, 0.0f}, - // {0.0f, 1.0f, 1.00f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 0.0f}, - // {0.0f, -1.0f, 0.00f, 1.0f} - // }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const signed char* img = bottom_blob_bordered.channel(q); - short* out_tm0 = bottom_blob_tm.channel(q); - - for (int j = 0; j < nColBlocks; j++) - { - const signed char* r0 = img + w * j * 2; - const signed char* r1 = r0 + w; - const signed char* r2 = r1 + w; - const signed char* r3 = r2 + w; - - for (int i = 0; i < nRowBlocks; i++) - { - short d0[4], d1[4], d2[4], d3[4]; - short w0[4], w1[4], w2[4], w3[4]; - short t0[4], t1[4], t2[4], t3[4]; - // load - for (int n = 0; n < 4; n++) - { - d0[n] = r0[n]; - d1[n] = r1[n]; - d2[n] = r2[n]; - d3[n] = r3[n]; - } - // w = B_t * d - for (int n = 0; n < 4; n++) - { - w0[n] = d0[n] - d2[n]; - w1[n] = d1[n] + d2[n]; - w2[n] = d2[n] - d1[n]; - w3[n] = d3[n] - d1[n]; - } - // transpose d to d_t - { - t0[0] = w0[0]; - t1[0] = w0[1]; - t2[0] = w0[2]; - t3[0] = w0[3]; - t0[1] = w1[0]; - t1[1] = w1[1]; - t2[1] = w1[2]; - t3[1] = w1[3]; - t0[2] = w2[0]; - t1[2] = w2[1]; - t2[2] = w2[2]; - t3[2] = w2[3]; - t0[3] = w3[0]; - t1[3] = w3[1]; - t2[3] = w3[2]; - t3[3] = w3[3]; - } - // U = B_t * d_t - for (int n = 0; n < 4; n++) - { - d0[n] = t0[n] - t2[n]; - d1[n] = t1[n] + t2[n]; - d2[n] = t2[n] - t1[n]; - d3[n] = t3[n] - t1[n]; - } - // save to out_tm - for (int n = 0; n < 4; n++) - { - out_tm0[n] = d0[n]; - out_tm0[n + 4] = d1[n]; - out_tm0[n + 8] = d2[n]; - out_tm0[n + 12] = d3[n]; - } - - r0 += 2; - r1 += 2; - r2 += 2; - r3 += 2; - - out_tm0 += 16; - } - } - } - } - bottom_blob_bordered = Mat(); - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - const int tiles = nColBlocks * nRowBlocks; - - top_blob_tm.create(16, tiles, outch, 4u, opt.workspace_allocator); - - int nn_outch = outch >> 2; - int remain_outch_start = nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 4; - - Mat out0_tm = top_blob_tm.channel(p); - Mat out1_tm = top_blob_tm.channel(p + 1); - Mat out2_tm = top_blob_tm.channel(p + 2); - Mat out3_tm = top_blob_tm.channel(p + 3); - - const Mat kernel0_tm = kernel_tm.channel(p); - const Mat kernel1_tm = kernel_tm.channel(p + 1); - const Mat kernel2_tm = kernel_tm.channel(p + 2); - const Mat kernel3_tm = kernel_tm.channel(p + 3); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - int* output1_tm = out1_tm.row(i); - int* output2_tm = out2_tm.row(i); - int* output3_tm = out3_tm.row(i); - - int sum0[16] = {0}; - int sum1[16] = {0}; - int sum2[16] = {0}; - int sum3[16] = {0}; - - int q = 0; - for (; q + 3 < inch; q += 4) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* r1 = bottom_blob_tm.channel(q + 1).row(i); - const short* r2 = bottom_blob_tm.channel(q + 2).row(i); - const short* r3 = bottom_blob_tm.channel(q + 3).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel1_tm.row(q); - const short* k2 = kernel2_tm.row(q); - const short* k3 = kernel3_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r1[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r2[n] * k0[n]; - k0 += 16; - sum0[n] += (int)r3[n] * k0[n]; - k0 -= 16 * 3; - - sum1[n] += (int)r0[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r1[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r2[n] * k1[n]; - k1 += 16; - sum1[n] += (int)r3[n] * k1[n]; - k1 -= 16 * 3; - - sum2[n] += (int)r0[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r1[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r2[n] * k2[n]; - k2 += 16; - sum2[n] += (int)r3[n] * k2[n]; - k2 -= 16 * 3; - - sum3[n] += (int)r0[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r1[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r2[n] * k3[n]; - k3 += 16; - sum3[n] += (int)r3[n] * k3[n]; - k3 -= 16 * 3; - } - } - - for (; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel1_tm.row(q); - const short* k2 = kernel2_tm.row(q); - const short* k3 = kernel3_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - sum1[n] += (int)r0[n] * k1[n]; - sum2[n] += (int)r0[n] * k2[n]; - sum3[n] += (int)r0[n] * k3[n]; - } - } - - for (int n = 0; n < 16; n++) - { - output0_tm[n] = sum0[n]; - output1_tm[n] = sum1[n]; - output2_tm[n] = sum2[n]; - output3_tm[n] = sum3[n]; - } - } - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - Mat out0_tm = top_blob_tm.channel(p); - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - - int sum0[16] = {0}; - - int q = 0; - for (; q + 3 < inch; q += 4) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* r1 = bottom_blob_tm.channel(q + 1).row(i); - const short* r2 = bottom_blob_tm.channel(q + 2).row(i); - const short* r3 = bottom_blob_tm.channel(q + 3).row(i); - - const short* k0 = kernel0_tm.row(q); - const short* k1 = kernel0_tm.row(q + 1); - const short* k2 = kernel0_tm.row(q + 2); - const short* k3 = kernel0_tm.row(q + 3); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - sum0[n] += (int)r1[n] * k1[n]; - sum0[n] += (int)r2[n] * k2[n]; - sum0[n] += (int)r3[n] * k3[n]; - } - } - - for (; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* k0 = kernel0_tm.row(q); - - for (int n = 0; n < 16; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - } - } - - for (int n = 0; n < 16; n++) - { - output0_tm[n] = sum0[n]; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator); - { - // AT - // const float itm[2][4] = { - // {1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 1.0f} - // }; - - int w_tm = outw / 2 * 4; - int h_tm = outh / 2 * 4; - - int nColBlocks = h_tm / 4; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 4; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out_tm = top_blob_tm.channel(p); - Mat out = top_blob_bordered.channel(p); - - for (int j = 0; j < nColBlocks; j++) - { - int* outRow0 = out.row(j * 2); - int* outRow1 = out.row(j * 2 + 1); - - for (int i = 0; i < nRowBlocks; i++) - { - int* out_tile = out_tm.row(j * nRowBlocks + i); - - int s0[4], s1[4], s2[4], s3[4]; - int w0[4], w1[4]; - int d0[2], d1[2], d2[2], d3[2]; - int o0[2], o1[2]; - // load - for (int n = 0; n < 4; n++) - { - s0[n] = out_tile[n]; - s1[n] = out_tile[n + 4]; - s2[n] = out_tile[n + 8]; - s3[n] = out_tile[n + 12]; - } - // w = A_T * W - for (int n = 0; n < 4; n++) - { - w0[n] = s0[n] + s1[n] + s2[n]; - w1[n] = s1[n] - s2[n] + s3[n]; - } - // transpose w to w_t - { - d0[0] = w0[0]; - d0[1] = w1[0]; - d1[0] = w0[1]; - d1[1] = w1[1]; - d2[0] = w0[2]; - d2[1] = w1[2]; - d3[0] = w0[3]; - d3[1] = w1[3]; - } - // Y = A_T * w_t - for (int n = 0; n < 2; n++) - { - o0[n] = d0[n] + d1[n] + d2[n]; - o1[n] = d1[n] - d2[n] + d3[n]; - } - // save to top blob tm,why right 2,because the G' = G*2 - outRow0[0] = o0[0] >> 2; - outRow0[1] = o0[1] >> 2; - outRow1[0] = o1[0] >> 2; - outRow1[1] = o1[1] >> 2; - - outRow0 += 2; - outRow1 += 2; - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - -static void conv3x3s1_winograd43_transform_kernel_int8_sse(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - kernel_tm.create(6 * 6, inch, outch, (size_t)2u); - - // G - // const float ktm[6][3] = { - // { 1.0f/4, 0.0f, 0.0f}, - // { -1.0f/6, -1.0f/6, -1.0f/6}, - // { -1.0f/6, 1.0f/6, -1.0f/6}, - // { 1.0f/24, 1.0f/12, 1.0f/6}, - // { 1.0f/24, -1.0f/12, 1.0f/6}, - // { 0.0f, 0.0f, 1.0f} - // }; - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 24} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } -} - -static void conv3x3s1_winograd43_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2, winograd F(4,3) - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - Option opt_b = opt; - opt_b.blob_allocator = opt.workspace_allocator; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - const int tiles = nColBlocks * nRowBlocks; - - bottom_blob_tm.create(6 * 6, tiles, inch, 2u, opt.workspace_allocator); - - // BT - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r03 + r04 - // 2 = 4 * (r01 - r02) - r03 + r04 - // 3 = -2 * r01 - r02 + 2 * r03 + r04 - // 4 = 2 * r01 - r02 - 2 * r03 + r04 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const signed char* img = bottom_blob_bordered.channel(q); - short* out_tm0 = bottom_blob_tm.channel(q); - - for (int j = 0; j < nColBlocks; j++) - { - const signed char* r0 = img + w * j * 4; - const signed char* r1 = r0 + w; - const signed char* r2 = r1 + w; - const signed char* r3 = r2 + w; - const signed char* r4 = r3 + w; - const signed char* r5 = r4 + w; - - for (int i = 0; i < nRowBlocks; i++) - { - short d0[6], d1[6], d2[6], d3[6], d4[6], d5[6]; - short w0[6], w1[6], w2[6], w3[6], w4[6], w5[6]; - short t0[6], t1[6], t2[6], t3[6], t4[6], t5[6]; - - // load - for (int n = 0; n < 6; n++) - { - d0[n] = r0[n]; - d1[n] = r1[n]; - d2[n] = r2[n]; - d3[n] = r3[n]; - d4[n] = r4[n]; - d5[n] = r5[n]; - } - // w = B_t * d - for (int n = 0; n < 6; n++) - { - w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n]; - w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n]; - w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n]; - w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n]; - w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n]; - w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n]; - } - // transpose d to d_t - { - t0[0] = w0[0]; - t1[0] = w0[1]; - t2[0] = w0[2]; - t3[0] = w0[3]; - t4[0] = w0[4]; - t5[0] = w0[5]; - t0[1] = w1[0]; - t1[1] = w1[1]; - t2[1] = w1[2]; - t3[1] = w1[3]; - t4[1] = w1[4]; - t5[1] = w1[5]; - t0[2] = w2[0]; - t1[2] = w2[1]; - t2[2] = w2[2]; - t3[2] = w2[3]; - t4[2] = w2[4]; - t5[2] = w2[5]; - t0[3] = w3[0]; - t1[3] = w3[1]; - t2[3] = w3[2]; - t3[3] = w3[3]; - t4[3] = w3[4]; - t5[3] = w3[5]; - t0[4] = w4[0]; - t1[4] = w4[1]; - t2[4] = w4[2]; - t3[4] = w4[3]; - t4[4] = w4[4]; - t5[4] = w4[5]; - t0[5] = w5[0]; - t1[5] = w5[1]; - t2[5] = w5[2]; - t3[5] = w5[3]; - t4[5] = w5[4]; - t5[5] = w5[5]; - } - // d = B_t * d_t - for (int n = 0; n < 6; n++) - { - d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n]; - d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n]; - d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n]; - d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n]; - d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n]; - d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n]; - } - // save to out_tm - for (int n = 0; n < 6; n++) - { - out_tm0[n] = d0[n]; - out_tm0[n + 6] = d1[n]; - out_tm0[n + 12] = d2[n]; - out_tm0[n + 18] = d3[n]; - out_tm0[n + 24] = d4[n]; - out_tm0[n + 30] = d5[n]; - } - - r0 += 4; - r1 += 4; - r2 += 4; - r3 += 4; - r4 += 4; - r5 += 4; - - out_tm0 += 36; - } - } - } - } - bottom_blob_bordered = Mat(); - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - const int tiles = nColBlocks * nRowBlocks; - - top_blob_tm.create(36, tiles, outch, 4u, opt.workspace_allocator); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out0_tm = top_blob_tm.channel(p); - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int i = 0; i < tiles; i++) - { - int* output0_tm = out0_tm.row(i); - - int sum0[36] = {0}; - - for (int q = 0; q < inch; q++) - { - const short* r0 = bottom_blob_tm.channel(q).row(i); - const short* k0 = kernel0_tm.row(q); - - for (int n = 0; n < 36; n++) - { - sum0[n] += (int)r0[n] * k0[n]; - } - } - - for (int n = 0; n < 36; n++) - { - output0_tm[n] = sum0[n]; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator); - { - // AT - // const float itm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + r01 + r02 + r03 + r04 - // 1 = r01 - r02 + 2 * (r03 - r04) - // 2 = r01 + r02 + 4 * (r03 + r04) - // 3 = r01 - r02 + 8 * (r03 - r04) + r05 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - int nColBlocks = h_tm / 6; // may be the block num in Feathercnn - int nRowBlocks = w_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - Mat out_tm = top_blob_tm.channel(p); - Mat out = top_blob_bordered.channel(p); - - for (int j = 0; j < nColBlocks; j++) - { - int* outRow0 = out.row(j * 4); - int* outRow1 = out.row(j * 4 + 1); - int* outRow2 = out.row(j * 4 + 2); - int* outRow3 = out.row(j * 4 + 3); - - for (int i = 0; i < nRowBlocks; i++) - { - int* out_tile = out_tm.row(j * nRowBlocks + i); - - int s0[6], s1[6], s2[6], s3[6], s4[6], s5[6]; - int w0[6], w1[6], w2[6], w3[6]; - int d0[4], d1[4], d2[4], d3[4], d4[4], d5[4]; - int o0[4], o1[4], o2[4], o3[4]; - // load - for (int n = 0; n < 6; n++) - { - s0[n] = out_tile[n]; - s1[n] = out_tile[n + 6]; - s2[n] = out_tile[n + 12]; - s3[n] = out_tile[n + 18]; - s4[n] = out_tile[n + 24]; - s5[n] = out_tile[n + 30]; - } - // w = A_T * W - for (int n = 0; n < 6; n++) - { - w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n]; - w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n]; - w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n]; - w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + s5[n]; - } - // transpose w to w_t - { - d0[0] = w0[0]; - d0[1] = w1[0]; - d0[2] = w2[0]; - d0[3] = w3[0]; - d1[0] = w0[1]; - d1[1] = w1[1]; - d1[2] = w2[1]; - d1[3] = w3[1]; - d2[0] = w0[2]; - d2[1] = w1[2]; - d2[2] = w2[2]; - d2[3] = w3[2]; - d3[0] = w0[3]; - d3[1] = w1[3]; - d3[2] = w2[3]; - d3[3] = w3[3]; - d4[0] = w0[4]; - d4[1] = w1[4]; - d4[2] = w2[4]; - d4[3] = w3[4]; - d5[0] = w0[5]; - d5[1] = w1[5]; - d5[2] = w2[5]; - d5[3] = w3[5]; - } - // Y = A_T * w_t - for (int n = 0; n < 4; n++) - { - o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n]; - o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n]; - o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n]; - o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n]; - } - // save to top blob tm - for (int n = 0; n < 4; n++) - { - outRow0[n] = o0[n] / 576; - outRow1[n] = o1[n] / 576; - outRow2[n] = o2[n] / 576; - outRow3[n] = o3[n] / 576; - } - - outRow0 += 4; - outRow1 += 4; - outRow2 += 4; - outRow3 += 4; - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} - static void conv3x3s2_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt) { int w = bottom_blob.w; diff --git a/src/layer/x86/convolution_3x3_pack8to1_int8.h b/src/layer/x86/convolution_3x3_pack8to1_int8.h deleted file mode 100644 index d5957faf6d89..000000000000 --- a/src/layer/x86/convolution_3x3_pack8to1_int8.h +++ /dev/null @@ -1,1125 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to1_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif -#endif - -static void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(kernel, kernel_tm_pack8to1, inch, outch, opt); - return; - } -#endif -#endif - - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8to1.create(8 * inch / 8, 36, outch / 4 + outch % 4, (size_t)2u * 4, 4); - - int p = 0; - for (; p + 3 < outch; p += 4) - { - Mat g0 = kernel_tm_pack8to1.channel(p / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 4; i++) - { - for (int j = 0; j < 8; j++) - { - const short* k00 = kernel_tm.channel(p + i).row(q + j); - g00[0] = k00[k]; - g00 += 1; - } - } - } - } - } - for (; p < outch; p++) - { - const Mat k0 = kernel_tm.channel(p); - - Mat g0 = kernel_tm_pack8to1.channel(p / 4 + p % 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int q = 0; q + 7 < inch; q += 8) - { - for (int i = 0; i < 8; i++) - { - g00[0] = k0.row(q + i)[k]; - - g00 += 1; - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to1_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_avx2(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_pack8to1_int8_sse_xop(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif -#endif - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - // TODO use _mm_cvtepi8_epi16 on sse4.1 - __m128i _r00_01 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r02_03 = _mm_loadu_si128((const __m128i*)(r0 + 16)); - __m128i _r04_05 = _mm_loadu_si128((const __m128i*)(r0 + 32)); - __m128i _extr0001 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r00_01); - __m128i _extr0203 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r02_03); - __m128i _extr0405 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r04_05); - __m128i _r00 = _mm_unpacklo_epi8(_r00_01, _extr0001); - __m128i _r01 = _mm_unpackhi_epi8(_r00_01, _extr0001); - __m128i _r02 = _mm_unpacklo_epi8(_r02_03, _extr0203); - __m128i _r03 = _mm_unpackhi_epi8(_r02_03, _extr0203); - __m128i _r04 = _mm_unpacklo_epi8(_r04_05, _extr0405); - __m128i _r05 = _mm_unpackhi_epi8(_r04_05, _extr0405); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _tmp0m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r00, 2), _r04), _mm_mullo_epi16(_r02, _v5)); - __m128i _tmp1m = _mm_sub_epi16(_mm_add_epi16(_r04, _r03), _mm_slli_epi16(_mm_add_epi16(_r01, _r02), 2)); - __m128i _tmp2m = _mm_add_epi16(_mm_sub_epi16(_r04, _r03), _mm_slli_epi16(_mm_sub_epi16(_r01, _r02), 2)); - __m128i _tmp3m = _mm_sub_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp4m = _mm_add_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp5m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r01, 2), _r05), _mm_mullo_epi16(_r03, _v5)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4m); - _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tm / 6 + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _r0tm0 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp00, 2), _tmp04), _mm_mullo_epi16(_tmp02, _v5)); - __m128i _r0tm1 = _mm_sub_epi16(_mm_add_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_add_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm2 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm3 = _mm_sub_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm4 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm5 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp01, 2), _tmp05), _mm_mullo_epi16(_tmp03, _v5)); - - _mm_storeu_si128((__m128i*)r0_tm_0, _r0tm0); - _mm_storeu_si128((__m128i*)r0_tm_1, _r0tm1); - _mm_storeu_si128((__m128i*)r0_tm_2, _r0tm2); - _mm_storeu_si128((__m128i*)r0_tm_3, _r0tm3); - _mm_storeu_si128((__m128i*)r0_tm_4, _r0tm4); - _mm_storeu_si128((__m128i*)r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = h_tm / 6 * w_tm / 6; - - // permute - // bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); - Mat bottom_blob_tm2; -#if __AVX2__ - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#else - if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < 36; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - short* tmpptr = tm2.row(i / 4); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - _mm256_storeu_si256((__m256i*)tmpptr, _r0); - _mm256_storeu_si256((__m256i*)(tmpptr + 16), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 32; - } - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#else - short* tmpptr = tm2.row(i / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - _mm_storeu_si128((__m128i*)(tmpptr + 8), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 16; - } - } - for (; i < tiles; i++) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, 36, outch, 4u, 1, opt.workspace_allocator); - - int nn_outch = 0; - int remain_outch_start = 0; - - nn_outch = outch >> 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int pp = 0; pp < nn_outch; pp++) - { - int p = pp * 4; - - int* output0_tm = top_blob_tm.channel(p); - int* output1_tm = top_blob_tm.channel(p + 1); - int* output2_tm = top_blob_tm.channel(p + 2); - int* output3_tm = top_blob_tm.channel(p + 3); - - const Mat kernel0_tm = kernel_tm.channel(p / 4); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); - - __m256i _sum04_15 = _mm256_setzero_si256(); - __m256i _sum14_05 = _mm256_setzero_si256(); - __m256i _sum06_17 = _mm256_setzero_si256(); - __m256i _sum16_07 = _mm256_setzero_si256(); - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif - - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m256i _val32 = _mm256_permute4x64_epi64(_val23, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum04_15 = _mm256_dpwssd_epi32(_sum04_15, _val23, _w01); - _sum14_05 = _mm256_dpwssd_epi32(_sum14_05, _val32, _w01); - _sum06_17 = _mm256_dpwssd_epi32(_sum06_17, _val23, _w23); - _sum16_07 = _mm256_dpwssd_epi32(_sum16_07, _val32, _w23); -#else - _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_madd_epi16(_val23, _w01)); - _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_madd_epi16(_val32, _w01)); - _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_madd_epi16(_val23, _w23)); - _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_madd_epi16(_val32, _w23)); -#endif - - r0 += 32; - k0 += 32; - } - - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); - _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); - _tmp2 = _mm256_unpackhi_epi32(_sum04_15, _sum14_05); - _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); - _sum06_17 = _mm256_add_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); - - int sum[16]; - _mm256_storeu_si256((__m256i*)sum, _sum00_11); - _mm256_storeu_si256((__m256i*)(sum + 8), _sum04_15); - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm[1] = sum[4]; - output1_tm[1] = sum[5]; - output2_tm[1] = sum[6]; - output3_tm[1] = sum[7]; - output0_tm[2] = sum[8]; - output1_tm[2] = sum[9]; - output2_tm[2] = sum[10]; - output3_tm[2] = sum[11]; - output0_tm[3] = sum[12]; - output1_tm[3] = sum[13]; - output2_tm[3] = sum[14]; - output3_tm[3] = sum[15]; - output0_tm += 4; - output1_tm += 4; - output2_tm += 4; - output3_tm += 4; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); -#else - __m128i _sum00 = _mm_setzero_si128(); - __m128i _sum01 = _mm_setzero_si128(); - __m128i _sum02 = _mm_setzero_si128(); - __m128i _sum03 = _mm_setzero_si128(); - __m128i _sum10 = _mm_setzero_si128(); - __m128i _sum11 = _mm_setzero_si128(); - __m128i _sum12 = _mm_setzero_si128(); - __m128i _sum13 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { -#if __AVX2__ - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif -#else - // 0 1 2 3 4 5 6 7 - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum00 = _mm_maddd_epi16(_val0, _w0, _sum00); - _sum01 = _mm_maddd_epi16(_val0, _w1, _sum01); - _sum02 = _mm_maddd_epi16(_val0, _w2, _sum02); - _sum03 = _mm_maddd_epi16(_val0, _w3, _sum03); - _sum10 = _mm_maddd_epi16(_val1, _w0, _sum10); - _sum11 = _mm_maddd_epi16(_val1, _w1, _sum11); - _sum12 = _mm_maddd_epi16(_val1, _w2, _sum12); - _sum13 = _mm_maddd_epi16(_val1, _w3, _sum13); -#else - _sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00); - _sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01); - _sum02 = _mm_add_epi32(_mm_madd_epi16(_val0, _w2), _sum02); - _sum03 = _mm_add_epi32(_mm_madd_epi16(_val0, _w3), _sum03); - _sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10); - _sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11); - _sum12 = _mm_add_epi32(_mm_madd_epi16(_val1, _w2), _sum12); - _sum13 = _mm_add_epi32(_mm_madd_epi16(_val1, _w3), _sum13); -#endif -#endif - - r0 += 16; - k0 += 32; - } - -#if __AVX2__ - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - - int sum[8]; - _mm256_storeu_si256((__m256i*)sum, _sum00_11); -#else - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); - _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); - _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); - _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); - _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); - _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); - _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); - _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); - _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00 = _mm_add_epi32(_sum00, _sum01); - _sum02 = _mm_add_epi32(_sum02, _sum03); - _sum10 = _mm_add_epi32(_sum10, _sum11); - _sum12 = _mm_add_epi32(_sum12, _sum13); - - _sum00 = _mm_add_epi32(_sum00, _sum02); - _sum10 = _mm_add_epi32(_sum10, _sum12); - - int sum[8]; - _mm_storeu_si128((__m128i*)sum, _sum00); - _mm_storeu_si128((__m128i*)(sum + 4), _sum10); -#endif - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm[1] = sum[4]; - output1_tm[1] = sum[5]; - output2_tm[1] = sum[6]; - output3_tm[1] = sum[7]; - output0_tm += 2; - output1_tm += 2; - output2_tm += 2; - output3_tm += 2; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum0_1 = _mm256_setzero_si256(); - __m256i _sum2_3 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); - __m128i _sum2 = _mm_setzero_si128(); - __m128i _sum3 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 - __m128i _val = _mm_loadu_si128((const __m128i*)r0); - -#if __AVX2__ - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _valval = _mm256_inserti128_si256(_mm256_castsi128_si256(_val), _val, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0_1 = _mm256_dpwssd_epi32(_sum0_1, _valval, _w01); - _sum2_3 = _mm256_dpwssd_epi32(_sum2_3, _valval, _w23); -#else - _sum0_1 = _mm256_add_epi32(_sum0_1, _mm256_madd_epi16(_valval, _w01)); - _sum2_3 = _mm256_add_epi32(_sum2_3, _mm256_madd_epi16(_valval, _w23)); -#endif -#else - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val, _w1, _sum1); - _sum2 = _mm_maddd_epi16(_val, _w2, _sum2); - _sum3 = _mm_maddd_epi16(_val, _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val, _w1), _sum1); - _sum2 = _mm_add_epi32(_mm_madd_epi16(_val, _w2), _sum2); - _sum3 = _mm_add_epi32(_mm_madd_epi16(_val, _w3), _sum3); -#endif -#endif - - r0 += 8; - k0 += 32; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum0_1, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum0_1, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum2_3, 0); - __m128i _sum3 = _mm256_extracti128_si256(_sum2_3, 1); -#endif - - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); - _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); - _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); - _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); - _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum0 = _mm_add_epi32(_sum0, _sum1); - _sum2 = _mm_add_epi32(_sum2, _sum3); - - _sum0 = _mm_add_epi32(_sum0, _sum2); - - int sum[4]; - _mm_storeu_si128((__m128i*)sum, _sum0); - - output0_tm[0] = sum[0]; - output1_tm[0] = sum[1]; - output2_tm[0] = sum[2]; - output3_tm[0] = sum[3]; - output0_tm += 1; - output1_tm += 1; - output2_tm += 1; - output3_tm += 1; - } - } - } - - remain_outch_start += nn_outch << 2; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = remain_outch_start; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p / 4 + p % 4); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - __m256i _sum01 = _mm256_setzero_si256(); - __m256i _sum23 = _mm256_setzero_si256(); - - for (int q = 0; q < inch; q++) - { - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m256i _w01 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w0), _w0, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum01 = _mm256_dpwssd_epi32(_sum01, _val01, _w01); - _sum23 = _mm256_dpwssd_epi32(_sum23, _val23, _w01); -#else - _sum01 = _mm256_add_epi32(_sum01, _mm256_madd_epi16(_val01, _w01)); - _sum23 = _mm256_add_epi32(_sum23, _mm256_madd_epi16(_val23, _w01)); -#endif - - k0 += 8; - r0 += 32; - } - - __m128i _sum0 = _mm256_extracti128_si256(_sum01, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum01, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum23, 0); - __m128i _sum3 = _mm256_extracti128_si256(_sum23, 1); - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm[1] = _mm_reduce_add_epi32(_sum1); - output0_tm[2] = _mm_reduce_add_epi32(_sum2); - output0_tm[3] = _mm_reduce_add_epi32(_sum3); - output0_tm += 4; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - -#if __AVX2__ - __m256i _sum01 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); -#endif - - for (int q = 0; q < inch; q++) - { -#if __AVX2__ - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m256i _w01 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w0), _w0, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum01 = _mm256_dpwssd_epi32(_sum01, _val01, _w01); -#else - _sum01 = _mm256_add_epi32(_sum01, _mm256_madd_epi16(_val01, _w01)); -#endif -#else - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val0, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val1, _w0, _sum1); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum1); -#endif -#endif - - k0 += 8; - r0 += 16; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum01, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum01, 1); -#endif - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm[1] = _mm_reduce_add_epi32(_sum1); - output0_tm += 2; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - __m128i _sum0 = _mm_setzero_si128(); - - for (int q = 0; q < inch; q++) - { - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val0, _w0, _sum0); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum0); -#endif - - k0 += 8; - r0 += 8; - } - - output0_tm[0] = _mm_reduce_add_epi32(_sum0); - output0_tm++; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); - } - { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - int tmp[4][6]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, 4u, 1, opt.workspace_allocator); - - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tm / 6 + j) * 1; - const int* output0_tm_1 = output0_tm_0 + tiles * 1; - const int* output0_tm_2 = output0_tm_0 + tiles * 2; - const int* output0_tm_3 = output0_tm_0 + tiles * 3; - const int* output0_tm_4 = output0_tm_0 + tiles * 4; - const int* output0_tm_5 = output0_tm_0 + tiles * 5; - - int* output0 = out0.row(i * 4) + j * 4; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - // TODO sse optimize - for (int m = 0; m < 5; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = output0_tm_0[0] + tmp02a + tmp02b; - tmp[1][m] = tmp13a + tmp13b * 2; - tmp[2][m] = tmp02a + tmp02b * 4; - tmp[3][m] = output0_tm_5[0] * 4 + tmp13a + tmp13b * 8; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - for (int m = 5; m < 6; m++) - { - int tmp02a = output0_tm_1[0] + output0_tm_2[0]; - int tmp13a = output0_tm_1[0] - output0_tm_2[0]; - - int tmp02b = output0_tm_3[0] + output0_tm_4[0]; - int tmp13b = output0_tm_3[0] - output0_tm_4[0]; - - tmp[0][m] = (output0_tm_0[0] + tmp02a + tmp02b) * 4; - tmp[1][m] = (tmp13a + tmp13b * 2) * 4; - tmp[2][m] = (tmp02a + tmp02b * 4) * 4; - tmp[3][m] = (output0_tm_5[0] * 4 + tmp13a + tmp13b * 8) * 4; - - output0_tm_0 += tiles * 6; - output0_tm_1 += tiles * 6; - output0_tm_2 += tiles * 6; - output0_tm_3 += tiles * 6; - output0_tm_4 += tiles * 6; - output0_tm_5 += tiles * 6; - } - - for (int m = 0; m < 4; m++) - { - const int* tmp0 = tmp[m]; - - int tmp02a = tmp0[1] + tmp0[2]; - int tmp13a = tmp0[1] - tmp0[2]; - - int tmp02b = tmp0[3] + tmp0[4]; - int tmp13b = tmp0[3] - tmp0[4]; - - output0[0] = (tmp0[0] + tmp02a + tmp02b) / 576; - output0[1] = (tmp13a + tmp13b * 2) / 576; - output0[2] = (tmp02a + tmp02b * 4) / 576; - output0[3] = (tmp0[5] + tmp13a + tmp13b * 8) / 576; - - output0 += outw; - } - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/x86/convolution_3x3_pack8to4_int8.h b/src/layer/x86/convolution_3x3_pack8to4_int8.h deleted file mode 100644 index 2bb48ce1903a..000000000000 --- a/src/layer/x86/convolution_3x3_pack8to4_int8.h +++ /dev/null @@ -1,945 +0,0 @@ -// Tencent is pleased to support the open source community by making ncnn available. -// -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt); -void conv3x3s1_winograd43_pack8to4_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt); -#endif -#endif - -static void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(kernel, kernel_tm_pack8, inch, outch, opt); - return; - } -#endif -#endif - - // winograd43 transform kernel - Mat kernel_tm(6 * 6, inch, outch, (size_t)2u); - - const short ktm[6][3] = { - {6, 0, 0}, - {-4, -4, -4}, - {-4, 4, -4}, - {1, 2, 4}, - {1, -2, 4}, - {0, 0, 6} - }; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - for (int q = 0; q < inch; q++) - { - const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9; - short* kernel_tm0 = kernel_tm.channel(p).row(q); - - // transform kernel - const signed char* k0 = kernel0; - const signed char* k1 = kernel0 + 3; - const signed char* k2 = kernel0 + 6; - - // h - short tmp[6][3]; - for (int i = 0; i < 6; i++) - { - tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2]; - tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2]; - tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2]; - } - - // U - for (int j = 0; j < 6; j++) - { - short* tmpp = &tmp[j][0]; - - for (int i = 0; i < 6; i++) - { - kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2]; - } - } - } - } - - // interleave - // src = 36-inch-outch - // dst = 4b-8a-inch/8a-36-outch/4b - kernel_tm_pack8.create(inch / 8, 36, outch / 4, (size_t)2u * 32, 32); - - int q = 0; - for (; q + 3 < outch; q += 4) - { - Mat g0 = kernel_tm_pack8.channel(q / 4); - - for (int k = 0; k < 36; k++) - { - short* g00 = g0.row(k); - - for (int p = 0; p + 7 < inch; p += 8) - { - for (int i = 0; i < 4; i++) - { - for (int j = 0; j < 8; j++) - { - const short* k00 = kernel_tm.channel(q + i).row(p + j); - g00[0] = k00[k]; - g00 += 1; - } - } - } - } - } -} - -static void conv3x3s1_winograd43_pack8to4_int8_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Option& opt) -{ -#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) -#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ - if (ncnn::cpu_support_x86_avx512_vnni()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ - if (ncnn::cpu_support_x86_avx_vnni()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ - if (ncnn::cpu_support_x86_avx2()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_avx2(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif - -#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ - if (ncnn::cpu_support_x86_xop()) - { - conv3x3s1_winograd43_pack8to4_int8_sse_xop(bottom_blob, top_blob, kernel_tm, opt); - return; - } -#endif -#endif - - int w = bottom_blob.w; - int h = bottom_blob.h; - int inch = bottom_blob.c; - // size_t elemsize = bottom_blob.elemsize; - int elempack = bottom_blob.elempack; - - int outw = top_blob.w; - int outh = top_blob.h; - int outch = top_blob.c; - - // pad to 4n+2 - Mat bottom_blob_bordered = bottom_blob; - - outw = (outw + 3) / 4 * 4; - outh = (outh + 3) / 4 * 4; - - w = outw + 2; - h = outh + 2; - copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - - // BEGIN transform input - Mat bottom_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - short tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const signed char* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - // TODO use _mm_cvtepi8_epi16 on sse4.1 - __m128i _r00_01 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r02_03 = _mm_loadu_si128((const __m128i*)(r0 + 16)); - __m128i _r04_05 = _mm_loadu_si128((const __m128i*)(r0 + 32)); - __m128i _extr0001 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r00_01); - __m128i _extr0203 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r02_03); - __m128i _extr0405 = _mm_cmpgt_epi8(_mm_setzero_si128(), _r04_05); - __m128i _r00 = _mm_unpacklo_epi8(_r00_01, _extr0001); - __m128i _r01 = _mm_unpackhi_epi8(_r00_01, _extr0001); - __m128i _r02 = _mm_unpacklo_epi8(_r02_03, _extr0203); - __m128i _r03 = _mm_unpackhi_epi8(_r02_03, _extr0203); - __m128i _r04 = _mm_unpacklo_epi8(_r04_05, _extr0405); - __m128i _r05 = _mm_unpackhi_epi8(_r04_05, _extr0405); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _tmp0m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r00, 2), _r04), _mm_mullo_epi16(_r02, _v5)); - __m128i _tmp1m = _mm_sub_epi16(_mm_add_epi16(_r04, _r03), _mm_slli_epi16(_mm_add_epi16(_r01, _r02), 2)); - __m128i _tmp2m = _mm_add_epi16(_mm_sub_epi16(_r04, _r03), _mm_slli_epi16(_mm_sub_epi16(_r01, _r02), 2)); - __m128i _tmp3m = _mm_sub_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp4m = _mm_add_epi16(_mm_sub_epi16(_r04, _r02), _mm_slli_epi16(_mm_sub_epi16(_r01, _r03), 1)); - __m128i _tmp5m = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_r01, 2), _r05), _mm_mullo_epi16(_r03, _v5)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4m); - _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - short* r0_tm_0 = (short*)img0_tm + (i * w_tm / 6 + j) * 8; - short* r0_tm_1 = r0_tm_0 + tiles * 8; - short* r0_tm_2 = r0_tm_0 + tiles * 16; - short* r0_tm_3 = r0_tm_0 + tiles * 24; - short* r0_tm_4 = r0_tm_0 + tiles * 32; - short* r0_tm_5 = r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _v5 = _mm_set1_epi16(5); - - __m128i _r0tm0 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp00, 2), _tmp04), _mm_mullo_epi16(_tmp02, _v5)); - __m128i _r0tm1 = _mm_sub_epi16(_mm_add_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_add_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm2 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp03), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp02), 2)); - __m128i _r0tm3 = _mm_sub_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm4 = _mm_add_epi16(_mm_sub_epi16(_tmp04, _tmp02), _mm_slli_epi16(_mm_sub_epi16(_tmp01, _tmp03), 1)); - __m128i _r0tm5 = _mm_sub_epi16(_mm_add_epi16(_mm_slli_epi16(_tmp01, 2), _tmp05), _mm_mullo_epi16(_tmp03, _v5)); - - _mm_storeu_si128((__m128i*)r0_tm_0, _r0tm0); - _mm_storeu_si128((__m128i*)r0_tm_1, _r0tm1); - _mm_storeu_si128((__m128i*)r0_tm_2, _r0tm2); - _mm_storeu_si128((__m128i*)r0_tm_3, _r0tm3); - _mm_storeu_si128((__m128i*)r0_tm_4, _r0tm4); - _mm_storeu_si128((__m128i*)r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } - } - bottom_blob_bordered = Mat(); - // END transform input - - // BEGIN dot - Mat top_blob_tm; - { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = h_tm / 6 * w_tm / 6; - - // permute - // bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); - Mat bottom_blob_tm2; -#if __AVX2__ - if (tiles >= 4) - bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#else - if (tiles >= 2) - bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 2u * elempack, elempack, opt.workspace_allocator); - else // if (tiles >= 1) - bottom_blob_tm2.create(1 * inch, tiles, 36, 2u * elempack, elempack, opt.workspace_allocator); -#endif - - #pragma omp parallel for num_threads(opt.num_threads) - for (int r = 0; r < 36; r++) - { - Mat tm2 = bottom_blob_tm2.channel(r); - - // tile - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - short* tmpptr = tm2.row(i / 4); - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)r0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - _mm256_storeu_si256((__m256i*)tmpptr, _r0); - _mm256_storeu_si256((__m256i*)(tmpptr + 16), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 32; - } - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2); -#else - short* tmpptr = tm2.row(i / 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - _mm_storeu_si128((__m128i*)(tmpptr + 8), _r1); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 16; - } - } - for (; i < tiles; i++) - { -#if __AVX2__ - short* tmpptr = tm2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - short* tmpptr = tm2.row(i / 2 + i % 2); -#endif - - const short* r0 = bottom_blob_tm; - - r0 += (r * tiles + i) * 8; - - for (int q = 0; q < inch; q++) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)r0); - _mm_storeu_si128((__m128i*)tmpptr, _r0); - r0 += bottom_blob_tm.cstep * 8; - tmpptr += 8; - } - } - } - - bottom_blob_tm = Mat(); - // permute end - - top_blob_tm.create(tiles, 36, outch, 4u * 4, 4, opt.workspace_allocator); - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - int* output0_tm = top_blob_tm.channel(p); - - const Mat kernel0_tm = kernel_tm.channel(p); - - for (int r = 0; r < 36; r++) - { - const Mat bb2 = bottom_blob_tm2.channel(r); - - int i = 0; -#if __AVX2__ - for (; i + 3 < tiles; i += 4) - { - const short* r0 = bb2.row(i / 4); - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); - - __m256i _sum04_15 = _mm256_setzero_si256(); - __m256i _sum14_05 = _mm256_setzero_si256(); - __m256i _sum06_17 = _mm256_setzero_si256(); - __m256i _sum16_07 = _mm256_setzero_si256(); - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif - - __m256i _val23 = _mm256_loadu_si256((const __m256i*)(r0 + 16)); - - __m256i _val32 = _mm256_permute4x64_epi64(_val23, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum04_15 = _mm256_dpwssd_epi32(_sum04_15, _val23, _w01); - _sum14_05 = _mm256_dpwssd_epi32(_sum14_05, _val32, _w01); - _sum06_17 = _mm256_dpwssd_epi32(_sum06_17, _val23, _w23); - _sum16_07 = _mm256_dpwssd_epi32(_sum16_07, _val32, _w23); -#else - _sum04_15 = _mm256_add_epi32(_sum04_15, _mm256_madd_epi16(_val23, _w01)); - _sum14_05 = _mm256_add_epi32(_sum14_05, _mm256_madd_epi16(_val32, _w01)); - _sum06_17 = _mm256_add_epi32(_sum06_17, _mm256_madd_epi16(_val23, _w23)); - _sum16_07 = _mm256_add_epi32(_sum16_07, _mm256_madd_epi16(_val32, _w23)); -#endif - - r0 += 32; - k0 += 32; - } - - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum04_15, _sum14_05); - _tmp1 = _mm256_unpacklo_epi32(_sum06_17, _sum16_07); - _tmp2 = _mm256_unpackhi_epi32(_sum04_15, _sum14_05); - _tmp3 = _mm256_unpackhi_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum14_05 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum06_17 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum16_07 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum14_05); - _sum06_17 = _mm256_add_epi32(_sum06_17, _sum16_07); - _sum04_15 = _mm256_add_epi32(_sum04_15, _sum06_17); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - _sum04_15 = _mm256_permutevar8x32_epi32(_sum04_15, _perm_mask); - - _mm256_storeu_si256((__m256i*)output0_tm, _sum00_11); - _mm256_storeu_si256((__m256i*)(output0_tm + 8), _sum04_15); - output0_tm += 16; - } -#endif - for (; i + 1 < tiles; i += 2) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2); -#else - const short* r0 = bb2.row(i / 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum00_11 = _mm256_setzero_si256(); - __m256i _sum10_01 = _mm256_setzero_si256(); - __m256i _sum02_13 = _mm256_setzero_si256(); - __m256i _sum12_03 = _mm256_setzero_si256(); -#else - __m128i _sum00 = _mm_setzero_si128(); - __m128i _sum01 = _mm_setzero_si128(); - __m128i _sum02 = _mm_setzero_si128(); - __m128i _sum03 = _mm_setzero_si128(); - __m128i _sum10 = _mm_setzero_si128(); - __m128i _sum11 = _mm_setzero_si128(); - __m128i _sum12 = _mm_setzero_si128(); - __m128i _sum13 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { -#if __AVX2__ - // 0 1 2 3 4 5 6 7 8 9 a b c d e f - __m256i _val01 = _mm256_loadu_si256((const __m256i*)r0); - - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _val10 = _mm256_permute4x64_epi64(_val01, 78); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum00_11 = _mm256_dpwssd_epi32(_sum00_11, _val01, _w01); - _sum10_01 = _mm256_dpwssd_epi32(_sum10_01, _val10, _w01); - _sum02_13 = _mm256_dpwssd_epi32(_sum02_13, _val01, _w23); - _sum12_03 = _mm256_dpwssd_epi32(_sum12_03, _val10, _w23); -#else - _sum00_11 = _mm256_add_epi32(_sum00_11, _mm256_madd_epi16(_val01, _w01)); - _sum10_01 = _mm256_add_epi32(_sum10_01, _mm256_madd_epi16(_val10, _w01)); - _sum02_13 = _mm256_add_epi32(_sum02_13, _mm256_madd_epi16(_val01, _w23)); - _sum12_03 = _mm256_add_epi32(_sum12_03, _mm256_madd_epi16(_val10, _w23)); -#endif -#else - // 0 1 2 3 4 5 6 7 - __m128i _val0 = _mm_loadu_si128((const __m128i*)r0); - __m128i _val1 = _mm_loadu_si128((const __m128i*)(r0 + 8)); - - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum00 = _mm_maddd_epi16(_val0, _w0, _sum00); - _sum01 = _mm_maddd_epi16(_val0, _w1, _sum01); - _sum02 = _mm_maddd_epi16(_val0, _w2, _sum02); - _sum03 = _mm_maddd_epi16(_val0, _w3, _sum03); - _sum10 = _mm_maddd_epi16(_val1, _w0, _sum10); - _sum11 = _mm_maddd_epi16(_val1, _w1, _sum11); - _sum12 = _mm_maddd_epi16(_val1, _w2, _sum12); - _sum13 = _mm_maddd_epi16(_val1, _w3, _sum13); -#else - _sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00); - _sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01); - _sum02 = _mm_add_epi32(_mm_madd_epi16(_val0, _w2), _sum02); - _sum03 = _mm_add_epi32(_mm_madd_epi16(_val0, _w3), _sum03); - _sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10); - _sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11); - _sum12 = _mm_add_epi32(_mm_madd_epi16(_val1, _w2), _sum12); - _sum13 = _mm_add_epi32(_mm_madd_epi16(_val1, _w3), _sum13); -#endif -#endif - - r0 += 16; - k0 += 32; - } - -#if __AVX2__ - // transpose 4x8 - { - __m256i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm256_unpacklo_epi32(_sum00_11, _sum10_01); - _tmp1 = _mm256_unpacklo_epi32(_sum02_13, _sum12_03); - _tmp2 = _mm256_unpackhi_epi32(_sum00_11, _sum10_01); - _tmp3 = _mm256_unpackhi_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _sum10_01 = _mm256_unpackhi_epi64(_tmp0, _tmp1); - _sum02_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3); - _sum12_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum10_01); - _sum02_13 = _mm256_add_epi32(_sum02_13, _sum12_03); - _sum00_11 = _mm256_add_epi32(_sum00_11, _sum02_13); - - __m256i _perm_mask = _mm256_set_epi32(6, 3, 4, 1, 7, 2, 5, 0); - _sum00_11 = _mm256_permutevar8x32_epi32(_sum00_11, _perm_mask); - - _mm256_storeu_si256((__m256i*)output0_tm, _sum00_11); -#else - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum00, _sum01); - _tmp1 = _mm_unpacklo_epi32(_sum02, _sum03); - _tmp2 = _mm_unpackhi_epi32(_sum00, _sum01); - _tmp3 = _mm_unpackhi_epi32(_sum02, _sum03); - _sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum10, _sum11); - _tmp1 = _mm_unpacklo_epi32(_sum12, _sum13); - _tmp2 = _mm_unpackhi_epi32(_sum10, _sum11); - _tmp3 = _mm_unpackhi_epi32(_sum12, _sum13); - _sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum00 = _mm_add_epi32(_sum00, _sum01); - _sum02 = _mm_add_epi32(_sum02, _sum03); - _sum10 = _mm_add_epi32(_sum10, _sum11); - _sum12 = _mm_add_epi32(_sum12, _sum13); - - _sum00 = _mm_add_epi32(_sum00, _sum02); - _sum10 = _mm_add_epi32(_sum10, _sum12); - - _mm_storeu_si128((__m128i*)output0_tm, _sum00); - _mm_storeu_si128((__m128i*)(output0_tm + 4), _sum10); -#endif - output0_tm += 8; - } - for (; i < tiles; i++) - { -#if __AVX2__ - const short* r0 = bb2.row(i / 4 + (i % 4) / 2 + i % 2); -#else - const short* r0 = bb2.row(i / 2 + i % 2); -#endif - const short* k0 = kernel0_tm.row(r); - - int nn = inch; // inch always > 0 - -#if __AVX2__ - __m256i _sum0_1 = _mm256_setzero_si256(); - __m256i _sum2_3 = _mm256_setzero_si256(); -#else - __m128i _sum0 = _mm_setzero_si128(); - __m128i _sum1 = _mm_setzero_si128(); - __m128i _sum2 = _mm_setzero_si128(); - __m128i _sum3 = _mm_setzero_si128(); -#endif - - for (int j = 0; j < nn; j++) - { - // 0 1 2 3 4 5 6 7 - __m128i _val = _mm_loadu_si128((const __m128i*)r0); -#if __AVX2__ - __m256i _w01 = _mm256_loadu_si256((const __m256i*)k0); - __m256i _w23 = _mm256_loadu_si256((const __m256i*)(k0 + 16)); - - __m256i _valval = _mm256_inserti128_si256(_mm256_castsi128_si256(_val), _val, 1); - -#if __AVXVNNI__ || __AVX512VNNI__ - _sum0_1 = _mm256_dpwssd_epi32(_sum0_1, _valval, _w01); - _sum2_3 = _mm256_dpwssd_epi32(_sum2_3, _valval, _w23); -#else - _sum0_1 = _mm256_add_epi32(_sum0_1, _mm256_madd_epi16(_valval, _w01)); - _sum2_3 = _mm256_add_epi32(_sum2_3, _mm256_madd_epi16(_valval, _w23)); -#endif -#else - __m128i _w0 = _mm_loadu_si128((const __m128i*)k0); - __m128i _w1 = _mm_loadu_si128((const __m128i*)(k0 + 8)); - __m128i _w2 = _mm_loadu_si128((const __m128i*)(k0 + 16)); - __m128i _w3 = _mm_loadu_si128((const __m128i*)(k0 + 24)); - -#if __XOP__ - _sum0 = _mm_maddd_epi16(_val, _w0, _sum0); - _sum1 = _mm_maddd_epi16(_val, _w1, _sum1); - _sum2 = _mm_maddd_epi16(_val, _w2, _sum2); - _sum3 = _mm_maddd_epi16(_val, _w3, _sum3); -#else - _sum0 = _mm_add_epi32(_mm_madd_epi16(_val, _w0), _sum0); - _sum1 = _mm_add_epi32(_mm_madd_epi16(_val, _w1), _sum1); - _sum2 = _mm_add_epi32(_mm_madd_epi16(_val, _w2), _sum2); - _sum3 = _mm_add_epi32(_mm_madd_epi16(_val, _w3), _sum3); -#endif -#endif - - r0 += 8; - k0 += 32; - } - -#if __AVX2__ - __m128i _sum0 = _mm256_extracti128_si256(_sum0_1, 0); - __m128i _sum1 = _mm256_extracti128_si256(_sum0_1, 1); - __m128i _sum2 = _mm256_extracti128_si256(_sum2_3, 0); - __m128i _sum3 = _mm256_extracti128_si256(_sum2_3, 1); -#endif - - // transpose 4x4 - { - __m128i _tmp0, _tmp1, _tmp2, _tmp3; - _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); - _tmp1 = _mm_unpacklo_epi32(_sum2, _sum3); - _tmp2 = _mm_unpackhi_epi32(_sum0, _sum1); - _tmp3 = _mm_unpackhi_epi32(_sum2, _sum3); - _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp1); - _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp1); - _sum2 = _mm_unpacklo_epi64(_tmp2, _tmp3); - _sum3 = _mm_unpackhi_epi64(_tmp2, _tmp3); - } - - _sum0 = _mm_add_epi32(_sum0, _sum1); - _sum2 = _mm_add_epi32(_sum2, _sum3); - - _sum0 = _mm_add_epi32(_sum0, _sum2); - - _mm_storeu_si128((__m128i*)output0_tm, _sum0); - output0_tm += 4; - } - } - } - } - bottom_blob_tm = Mat(); - // END dot - - // BEGIN transform output - Mat top_blob_bordered; - if (outw == top_blob.w && outh == top_blob.h) - { - top_blob_bordered = top_blob; - } - else - { - top_blob_bordered.create(outw, outh, outch, 4u * 4, 4, opt.workspace_allocator); - } - { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - int tmp[4][6][4]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const int* output0_tm_0 = (const int*)out0_tm + (i * w_tm / 6 + j) * 4; - const int* output0_tm_1 = output0_tm_0 + tiles * 4; - const int* output0_tm_2 = output0_tm_0 + tiles * 8; - const int* output0_tm_3 = output0_tm_0 + tiles * 12; - const int* output0_tm_4 = output0_tm_0 + tiles * 16; - const int* output0_tm_5 = output0_tm_0 + tiles * 20; - - int* output0 = out0.row(i * 4) + (j * 4) * 4; - - // TODO sse optimize - for (int m = 0; m < 5; m++) - { - __m128i _out0tm0 = _mm_loadu_si128((const __m128i*)output0_tm_0); - __m128i _out0tm1 = _mm_loadu_si128((const __m128i*)output0_tm_1); - __m128i _out0tm2 = _mm_loadu_si128((const __m128i*)output0_tm_2); - __m128i _out0tm3 = _mm_loadu_si128((const __m128i*)output0_tm_3); - __m128i _out0tm4 = _mm_loadu_si128((const __m128i*)output0_tm_4); - __m128i _out0tm5 = _mm_loadu_si128((const __m128i*)output0_tm_5); - - __m128i _tmp02a = _mm_add_epi32(_out0tm1, _out0tm2); - __m128i _tmp13a = _mm_sub_epi32(_out0tm1, _out0tm2); - - __m128i _tmp02b = _mm_add_epi32(_out0tm3, _out0tm4); - __m128i _tmp13b = _mm_sub_epi32(_out0tm3, _out0tm4); - - __m128i _tmp0m = _mm_add_epi32(_mm_add_epi32(_out0tm0, _tmp02a), _tmp02b); - __m128i _tmp1m = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _tmp2m = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _tmp3m = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_out0tm5, 2)), _mm_slli_epi32(_tmp13b, 3)); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - for (int m = 5; m < 6; m++) - { - __m128i _out0tm0 = _mm_loadu_si128((const __m128i*)output0_tm_0); - __m128i _out0tm1 = _mm_loadu_si128((const __m128i*)output0_tm_1); - __m128i _out0tm2 = _mm_loadu_si128((const __m128i*)output0_tm_2); - __m128i _out0tm3 = _mm_loadu_si128((const __m128i*)output0_tm_3); - __m128i _out0tm4 = _mm_loadu_si128((const __m128i*)output0_tm_4); - __m128i _out0tm5 = _mm_loadu_si128((const __m128i*)output0_tm_5); - - __m128i _tmp02a = _mm_add_epi32(_out0tm1, _out0tm2); - __m128i _tmp13a = _mm_sub_epi32(_out0tm1, _out0tm2); - - __m128i _tmp02b = _mm_add_epi32(_out0tm3, _out0tm4); - __m128i _tmp13b = _mm_sub_epi32(_out0tm3, _out0tm4); - - __m128i _tmp0m = _mm_add_epi32(_mm_add_epi32(_out0tm0, _tmp02a), _tmp02b); - __m128i _tmp1m = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _tmp2m = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _tmp3m = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_out0tm5, 2)), _mm_slli_epi32(_tmp13b, 3)); - - _tmp0m = _mm_slli_epi32(_tmp0m, 2); - _tmp1m = _mm_slli_epi32(_tmp1m, 2); - _tmp2m = _mm_slli_epi32(_tmp2m, 2); - _tmp3m = _mm_slli_epi32(_tmp3m, 2); - - _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0m); - _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1m); - _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2m); - _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - - for (int m = 0; m < 4; m++) - { - __m128i _tmp00 = _mm_loadu_si128((const __m128i*)tmp[m][0]); - __m128i _tmp01 = _mm_loadu_si128((const __m128i*)tmp[m][1]); - __m128i _tmp02 = _mm_loadu_si128((const __m128i*)tmp[m][2]); - __m128i _tmp03 = _mm_loadu_si128((const __m128i*)tmp[m][3]); - __m128i _tmp04 = _mm_loadu_si128((const __m128i*)tmp[m][4]); - __m128i _tmp05 = _mm_loadu_si128((const __m128i*)tmp[m][5]); - - __m128i _tmp02a = _mm_add_epi32(_tmp01, _tmp02); - __m128i _tmp13a = _mm_sub_epi32(_tmp01, _tmp02); - - __m128i _tmp02b = _mm_add_epi32(_tmp03, _tmp04); - __m128i _tmp13b = _mm_sub_epi32(_tmp03, _tmp04); - - __m128i _out00 = _mm_add_epi32(_mm_add_epi32(_tmp00, _tmp02a), _tmp02b); - __m128i _out01 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); - __m128i _out02 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); - __m128i _out03 = _mm_add_epi32(_mm_add_epi32(_tmp05, _tmp13a), _mm_slli_epi32(_tmp13b, 3)); - - // TODO use integer trick for division by 576 - __m128 _v576 = _mm_set1_ps(1.0 / 576); - _out00 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out00), _v576)); - _out01 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out01), _v576)); - _out02 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out02), _v576)); - _out03 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_out03), _v576)); - - _mm_storeu_si128((__m128i*)output0, _out00); - _mm_storeu_si128((__m128i*)(output0 + 4), _out01); - _mm_storeu_si128((__m128i*)(output0 + 8), _out02); - _mm_storeu_si128((__m128i*)(output0 + 12), _out03); - - output0 += outw * 4; - } - } - } - } - } - // END transform output - - // cut result pad - copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); -} diff --git a/src/layer/x86/convolution_3x3_winograd_int8.h b/src/layer/x86/convolution_3x3_winograd_int8.h new file mode 100644 index 000000000000..8c7b891b0dda --- /dev/null +++ b/src/layer/x86/convolution_3x3_winograd_int8.h @@ -0,0 +1,6407 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ +void conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ +void conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +void conv3x3s1_winograd23_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); +void conv3x3s1_winograd23_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt); +void conv3x3s1_winograd43_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ +void conv3x3s1_winograd23_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +void conv3x3s1_winograd43_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt); +#endif +#endif + +static void pack_A_tile_int8(const Mat& A, Mat& AT, int batch, int max_ii, int max_kk) +{ + const int N = max_kk * batch; + + for (int b = 0; b < batch; b++) + { + short* pp = AT.row(b); + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + pp[8] = p0[N * 4]; + pp[9] = p0[N * 4 + batch]; + pp[10] = p0[N * 5]; + pp[11] = p0[N * 5 + batch]; + pp[12] = p0[N * 6]; + pp[13] = p0[N * 6 + batch]; + pp[14] = p0[N * 7]; + pp[15] = p0[N * 7 + batch]; + pp[16] = p0[N * 8]; + pp[17] = p0[N * 8 + batch]; + pp[18] = p0[N * 9]; + pp[19] = p0[N * 9 + batch]; + pp[20] = p0[N * 10]; + pp[21] = p0[N * 10 + batch]; + pp[22] = p0[N * 11]; + pp[23] = p0[N * 11 + batch]; + pp[24] = p0[N * 12]; + pp[25] = p0[N * 12 + batch]; + pp[26] = p0[N * 13]; + pp[27] = p0[N * 13 + batch]; + pp[28] = p0[N * 14]; + pp[29] = p0[N * 14 + batch]; + pp[30] = p0[N * 15]; + pp[31] = p0[N * 15 + batch]; + p0 += batch * 2; + pp += 32; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + pp[8] = p0[N * 8]; + pp[9] = p0[N * 9]; + pp[10] = p0[N * 10]; + pp[11] = p0[N * 11]; + pp[12] = p0[N * 12]; + pp[13] = p0[N * 13]; + pp[14] = p0[N * 14]; + pp[15] = p0[N * 15]; + p0 += batch; + pp += 16; + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + pp[8] = p0[N * 4]; + pp[9] = p0[N * 4 + batch]; + pp[10] = p0[N * 5]; + pp[11] = p0[N * 5 + batch]; + pp[12] = p0[N * 6]; + pp[13] = p0[N * 6 + batch]; + pp[14] = p0[N * 7]; + pp[15] = p0[N * 7 + batch]; + p0 += batch * 2; + pp += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + pp[4] = p0[N * 4]; + pp[5] = p0[N * 5]; + pp[6] = p0[N * 6]; + pp[7] = p0[N * 7]; + p0 += batch; + pp += 8; + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + pp[4] = p0[N * 2]; + pp[5] = p0[N * 2 + batch]; + pp[6] = p0[N * 3]; + pp[7] = p0[N * 3 + batch]; + p0 += batch * 2; + pp += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + pp[2] = p0[N * 2]; + pp[3] = p0[N * 3]; + p0 += batch; + pp += 4; + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[batch]; + pp[2] = p0[N]; + pp[3] = p0[N + batch]; + p0 += batch * 2; + pp += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[N]; + p0 += batch; + pp += 2; + } + } + for (; ii < max_ii; ii++) + { + const short* p0 = (const short*)A + ii * N + b; + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += batch; + pp += 1; + } + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int batch, int max_jj, int max_kk, int nT) +{ + #pragma omp parallel for num_threads(nT) + for (int b = 0; b < batch; b++) + { + short* pp = BT.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* p0 = B; + + int kk = 0; + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _r4 = _mm512_loadu_si512((const __m512i*)(p0 + 128)); + __m512i _r5 = _mm512_loadu_si512((const __m512i*)(p0 + 160)); + __m512i _r6 = _mm512_loadu_si512((const __m512i*)(p0 + 192)); + __m512i _r7 = _mm512_loadu_si512((const __m512i*)(p0 + 224)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_r4, _r6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_r4, _r6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_r5, _r7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_r5, _r7, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _r4 = _mm512_unpacklo_epi32(_tmp4, _tmp5); + _r5 = _mm512_unpackhi_epi32(_tmp4, _tmp5); + _r6 = _mm512_unpacklo_epi32(_tmp6, _tmp7); + _r7 = _mm512_unpackhi_epi32(_tmp6, _tmp7); + _tmp0 = _mm512_unpacklo_epi64(_r0, _r2); + _tmp1 = _mm512_unpackhi_epi64(_r0, _r2); + _tmp2 = _mm512_unpacklo_epi64(_r1, _r3); + _tmp3 = _mm512_unpackhi_epi64(_r1, _r3); + _tmp4 = _mm512_unpacklo_epi64(_r4, _r6); + _tmp5 = _mm512_unpackhi_epi64(_r4, _r6); + _tmp6 = _mm512_unpacklo_epi64(_r5, _r7); + _tmp7 = _mm512_unpackhi_epi64(_r5, _r7); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp4, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp1, _tmp5, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_i32x4(_tmp2, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_i32x4(_tmp3, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _r4 = _mm512_shuffle_i32x4(_tmp0, _tmp4, _MM_SHUFFLE(3, 1, 3, 1)); + _r5 = _mm512_shuffle_i32x4(_tmp1, _tmp5, _MM_SHUFFLE(3, 1, 3, 1)); + _r6 = _mm512_shuffle_i32x4(_tmp2, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _r7 = _mm512_shuffle_i32x4(_tmp3, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + _mm512_storeu_si512((__m512i*)(pp + 128), _r4); + _mm512_storeu_si512((__m512i*)(pp + 160), _r5); + _mm512_storeu_si512((__m512i*)(pp + 192), _r6); + _mm512_storeu_si512((__m512i*)(pp + 224), _r7); + p0 += max_jj * batch * 16; + pp += 256; + } + p0 -= (b * max_jj + jj) * 16; + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r2, _r3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r2, _r3, _MM_SHUFFLE(3, 1, 3, 1)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _tmp0 = _mm512_permutex_epi64(_r0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm512_permutex_epi64(_r1, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp2 = _mm512_permutex_epi64(_r2, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp3 = _mm512_permutex_epi64(_r3, _MM_SHUFFLE(3, 1, 2, 0)); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 3, 1)); + _r2 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r3 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + p0 += max_jj * batch * 8; + pp += 128; + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + _mm512_storeu_si512((__m512i*)pp, _r0); + p0 += max_jj * batch * 2; + pp += 32; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)p0); + _mm256_store_si256((__m256i*)pp, _r0); + p0 += max_jj * batch; + pp += 16; + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* p0 = B; + + int kk = 0; +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _r2 = _mm512_loadu_si512((const __m512i*)(p0 + 64)); + __m512i _r3 = _mm512_loadu_si512((const __m512i*)(p0 + 96)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_r1, _r3, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r2 = _mm512_unpacklo_epi32(_tmp2, _tmp3); + _r3 = _mm512_unpackhi_epi32(_tmp2, _tmp3); + _tmp0 = _mm512_unpacklo_epi64(_r0, _r2); + _tmp1 = _mm512_unpackhi_epi64(_r0, _r2); + _tmp2 = _mm512_unpacklo_epi64(_r1, _r3); + _tmp3 = _mm512_unpackhi_epi64(_r1, _r3); + _r0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _r2 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _r3 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + _mm512_storeu_si512((__m512i*)(pp + 64), _r2); + _mm512_storeu_si512((__m512i*)(pp + 96), _r3); + p0 += max_jj * batch * 16; + pp += 128; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { +#if __AVX__ + __m256 _r0 = _mm256_loadu_ps((const float*)p0); + __m256 _r1 = _mm256_loadu_ps((const float*)(p0 + 16)); + __m256 _r2 = _mm256_loadu_ps((const float*)(p0 + 32)); + __m256 _r3 = _mm256_loadu_ps((const float*)(p0 + 48)); + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp2 = _mm256_permute2f128_ps(_r1, _r3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp3 = _mm256_permute2f128_ps(_r1, _r3, _MM_SHUFFLE(0, 3, 0, 1)); + _r0 = _mm256_unpacklo_ps(_tmp0, _tmp1); + _r1 = _mm256_unpackhi_ps(_tmp0, _tmp1); + _r2 = _mm256_unpacklo_ps(_tmp2, _tmp3); + _r3 = _mm256_unpackhi_ps(_tmp2, _tmp3); + _tmp0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_r0), _mm256_castps_pd(_r2))); + _tmp1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_r0), _mm256_castps_pd(_r2))); + _tmp2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(_r1), _mm256_castps_pd(_r3))); + _tmp3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(_r1), _mm256_castps_pd(_r3))); + _mm256_storeu_ps((float*)pp, _tmp0); + _mm256_storeu_ps((float*)(pp + 16), _tmp1); + _mm256_storeu_ps((float*)(pp + 32), _tmp2); + _mm256_storeu_ps((float*)(pp + 48), _tmp3); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_load_si128((const __m128i*)(p0 + 8 * 2)); + __m128i _r3 = _mm_load_si128((const __m128i*)(p0 + 8 * 3)); + __m128i _r4 = _mm_load_si128((const __m128i*)(p0 + 8 * 4)); + __m128i _r5 = _mm_load_si128((const __m128i*)(p0 + 8 * 5)); + __m128i _r6 = _mm_load_si128((const __m128i*)(p0 + 8 * 6)); + __m128i _r7 = _mm_load_si128((const __m128i*)(p0 + 8 * 7)); + transpose4x8_epi32(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); + _mm_store_si128((__m128i*)pp, _r0); + _mm_store_si128((__m128i*)(pp + 8), _r1); + _mm_store_si128((__m128i*)(pp + 8 * 2), _r2); + _mm_store_si128((__m128i*)(pp + 8 * 3), _r3); + _mm_store_si128((__m128i*)(pp + 8 * 4), _r4); + _mm_store_si128((__m128i*)(pp + 8 * 5), _r5); + _mm_store_si128((__m128i*)(pp + 8 * 6), _r6); + _mm_store_si128((__m128i*)(pp + 8 * 7), _r7); +#endif // __AVX__ + p0 += max_jj * batch * 8; + pp += 64; + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX__ + __m256 _r0 = _mm256_loadu_ps((const float*)p0); + _mm256_storeu_ps((float*)pp, _r0); +#else + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)(p0 + 8)); + _mm_store_si128((__m128i*)pp, _r0); + _mm_store_si128((__m128i*)(pp + 8), _r1); +#endif // __AVX__ + p0 += max_jj * batch * 2; + pp += 16; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + _mm_store_si128((__m128i*)pp, _r0); + p0 += max_jj * batch; + pp += 8; + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* p0 = B; + + int kk = 0; +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m512i _r0 = _mm512_loadu_si512((const __m512i*)p0); + __m512i _r1 = _mm512_loadu_si512((const __m512i*)(p0 + 32)); + __m512i _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _r0 = _mm512_permutex_epi64(_r0, _MM_SHUFFLE(3, 1, 2, 0)); + _r1 = _mm512_permutex_epi64(_r1, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(1, 0, 1, 0)); + _tmp1 = _mm512_shuffle_i32x4(_r0, _r1, _MM_SHUFFLE(3, 2, 3, 2)); + _r0 = _mm512_unpacklo_epi64(_tmp0, _tmp1); + _r1 = _mm512_unpackhi_epi64(_tmp0, _tmp1); + _mm512_storeu_si512((__m512i*)pp, _r0); + _mm512_storeu_si512((__m512i*)(pp + 32), _r1); + p0 += max_jj * batch * 16; + pp += 64; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + __m128i _r2 = _mm_load_si128((const __m128i*)(p0 + 8 * 2)); + __m128i _r3 = _mm_load_si128((const __m128i*)(p0 + 8 * 3)); + transpose4x4_epi32(_r0, _r1, _r2, _r3); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 8), _r1); + _mm_storeu_si128((__m128i*)(pp + 8 * 2), _r2); + _mm_storeu_si128((__m128i*)(pp + 8 * 3), _r3); + p0 += max_jj * batch * 8; + pp += 32; + } + p0 -= (b * max_jj + jj) * 8; + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _r0); + p0 += max_jj * batch * 2; + pp += 8; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + p0 += max_jj * batch; + pp += 4; + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* p0 = B; + + int kk = 0; +#if __SSE2__ +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)p0); + __m256i _r1 = _mm256_load_si256((const __m256i*)(p0 + 16)); + transpose8x2_epi32(_r0, _r1); + _mm256_storeu_si256((__m256i*)pp, _r0); + _mm256_storeu_si256((__m256i*)(pp + 16), _r1); + p0 += max_jj * batch * 16; + pp += 32; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + __m128i _r1 = _mm_load_si128((const __m128i*)(p0 + 8)); + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _tmp0); + _mm_storeu_si128((__m128i*)(pp + 8), _tmp1); + p0 += max_jj * batch * 8; + pp += 16; + } + p0 -= (b * max_jj + jj) * 8; +#endif // __SSE2__ + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + p0 += max_jj * batch * 2; + pp += 4; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch; + pp += 2; + } + } + for (; jj < max_jj; jj++) + { + const short* p0 = B; + + int kk = 0; +#if __SSE2__ +#if __AVX512F__ + p0 += (b * max_jj + jj) * 16; + for (; kk + 15 < max_kk; kk += 16) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)p0); + _mm256_storeu_si256((__m256i*)pp, _r0); + p0 += max_jj * batch * 16; + pp += 16; + } + p0 -= (b * max_jj + jj) * 16; +#endif // __AVX512F__ + p0 += (b * max_jj + jj) * 8; + for (; kk + 7 < max_kk; kk += 8) + { + __m128i _r0 = _mm_load_si128((const __m128i*)p0); + _mm_storeu_si128((__m128i*)pp, _r0); + p0 += max_jj * batch * 8; + pp += 8; + } + p0 -= (b * max_jj + jj) * 8; +#endif // __SSE2__ + p0 += (b * max_jj + jj) * 2; + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + p0 += max_jj * batch * 2; + pp += 2; + } + p0 -= (b * max_jj + jj) * 2; + p0 += (b * max_jj + jj); + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + p0 += max_jj * batch; + pp += 1; + } + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& top_blob, int batch, int max_ii, int max_jj, int k, int max_kk, bool k_end) +{ + int* outptr = top_blob; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + __m512i _sum8; + __m512i _sum9; + __m512i _suma; + __m512i _sumb; + __m512i _sumc; + __m512i _sumd; + __m512i _sume; + __m512i _sumf; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + _sum8 = _mm512_setzero_si512(); + _sum9 = _mm512_setzero_si512(); + _suma = _mm512_setzero_si512(); + _sumb = _mm512_setzero_si512(); + _sumc = _mm512_setzero_si512(); + _sumd = _mm512_setzero_si512(); + _sume = _mm512_setzero_si512(); + _sumf = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 112)); + _sum8 = _mm512_load_si512((const __m512i*)(outptr + 128)); + _sum9 = _mm512_load_si512((const __m512i*)(outptr + 128 + 16)); + _suma = _mm512_load_si512((const __m512i*)(outptr + 128 + 32)); + _sumb = _mm512_load_si512((const __m512i*)(outptr + 128 + 48)); + _sumc = _mm512_load_si512((const __m512i*)(outptr + 128 + 64)); + _sumd = _mm512_load_si512((const __m512i*)(outptr + 128 + 80)); + _sume = _mm512_load_si512((const __m512i*)(outptr + 128 + 96)); + _sumf = _mm512_load_si512((const __m512i*)(outptr + 128 + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m512i _pA1 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 1, 0, 3)); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pA3 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); + _sum8 = _mm512_dpwssd_epi32(_sum8, _pA2, _pB0); + _sum9 = _mm512_dpwssd_epi32(_sum9, _pA2, _pB1); + _suma = _mm512_dpwssd_epi32(_suma, _pA2, _pB2); + _sumb = _mm512_dpwssd_epi32(_sumb, _pA2, _pB3); + _sumc = _mm512_dpwssd_epi32(_sumc, _pA3, _pB0); + _sumd = _mm512_dpwssd_epi32(_sumd, _pA3, _pB1); + _sume = _mm512_dpwssd_epi32(_sume, _pA3, _pB2); + _sumf = _mm512_dpwssd_epi32(_sumf, _pA3, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); + _sum8 = _mm512_add_epi32(_sum8, _mm512_madd_epi16(_pA2, _pB0)); + _sum9 = _mm512_add_epi32(_sum9, _mm512_madd_epi16(_pA2, _pB1)); + _suma = _mm512_add_epi32(_suma, _mm512_madd_epi16(_pA2, _pB2)); + _sumb = _mm512_add_epi32(_sumb, _mm512_madd_epi16(_pA2, _pB3)); + _sumc = _mm512_add_epi32(_sumc, _mm512_madd_epi16(_pA3, _pB0)); + _sumd = _mm512_add_epi32(_sumd, _mm512_madd_epi16(_pA3, _pB1)); + _sume = _mm512_add_epi32(_sume, _mm512_madd_epi16(_pA3, _pB2)); + _sumf = _mm512_add_epi32(_sumf, _mm512_madd_epi16(_pA3, _pB3)); +#endif + + pA += 32; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pB)); + + __m512i _pA1 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(2, 1, 0, 3)); + __m512i _pA2 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pA3 = _mm512_shuffle_i32x4(_pA0, _pA0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA0, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA0, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA1, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA1, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA1, _pB3); + __m512i _s8 = _mm512_mullo_epi32(_pA2, _pB0); + __m512i _s9 = _mm512_mullo_epi32(_pA2, _pB1); + __m512i _sa = _mm512_mullo_epi32(_pA2, _pB2); + __m512i _sb = _mm512_mullo_epi32(_pA2, _pB3); + __m512i _sc = _mm512_mullo_epi32(_pA3, _pB0); + __m512i _sd = _mm512_mullo_epi32(_pA3, _pB1); + __m512i _se = _mm512_mullo_epi32(_pA3, _pB2); + __m512i _sf = _mm512_mullo_epi32(_pA3, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + _sum8 = _mm512_add_epi32(_sum8, _s8); + _sum9 = _mm512_add_epi32(_sum9, _s9); + _suma = _mm512_add_epi32(_suma, _sa); + _sumb = _mm512_add_epi32(_sumb, _sb); + _sumc = _mm512_add_epi32(_sumc, _sc); + _sumd = _mm512_add_epi32(_sumd, _sd); + _sume = _mm512_add_epi32(_sume, _se); + _sumf = _mm512_add_epi32(_sumf, _sf); + + pA += 16; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 88 99 aa bb cc dd ee ff + // 01 12 23 30 45 56 67 74 89 9a ab b8 cd de ef fc + // 02 13 20 31 46 57 64 75 8a 9b a8 b9 ce df ec fd + // 03 10 21 32 47 54 65 76 8b 98 a9 ba cf dc ed fe + // c0 d1 e2 f3 04 15 26 37 48 59 6a 7b 8c 9d ae bf + // c1 d2 e3 f0 05 16 27 34 49 5a 6b 78 8d 9e af bc + // c2 d3 e0 f1 06 17 24 35 4a 5b 68 79 8e 9f ac bd + // c3 d0 e1 f2 07 14 25 36 4b 58 69 7a 8f 9c ad be + // 80 91 a2 b3 c4 d5 e6 f7 08 19 2a 3b 4c 5d 6e 7f + // 81 92 a3 b0 c5 d6 e7 f4 09 1a 2b 38 4d 5e 6f 7c + // 82 93 a0 b1 c6 d7 e4 f5 0a 1b 28 39 4e 5f 6c 7d + // 83 90 a1 b2 c7 d4 e5 f6 0b 18 29 3a 4f 5c 6d 7e + // 40 51 62 73 84 95 a6 b7 c8 d9 ea fb 0c 1d 2e 3f + // 41 52 63 70 85 96 a7 b4 c9 da eb f8 0d 1e 2f 3c + // 42 53 60 71 86 97 a4 b5 ca db e8 f9 0e 1f 2c 3d + // 43 50 61 72 87 94 a5 b6 cb d8 e9 fa 0f 1c 2d 3e + // to + // 00 10 20 30 44 54 64 74 88 98 a8 b8 cc dc ec fc + // 01 11 21 31 45 55 65 75 89 99 a9 b9 cd dd ed fd + // 02 12 22 32 46 56 66 76 8a 9a aa ba ce de ee fe + // 03 13 23 33 47 57 67 77 8b 9b ab bb cf df ef ff + // c0 d0 e0 f0 04 14 24 34 48 58 68 78 8c 9c ac bc + // c1 d1 e1 f1 05 15 25 35 49 59 69 79 8d 9d ad bd + // c2 d2 e2 f2 06 16 26 36 4a 5a 6a 7a 8e 9e ae be + // c3 d3 e3 f3 07 17 27 37 4b 5b 6b 7b 8f 9f af bf + // 80 90 a0 b0 c4 d4 e4 f4 08 18 28 38 4c 5c 6c 7c + // 81 91 a1 b1 c5 d5 e5 f5 09 19 29 39 4d 5d 6d 7d + // 82 92 a2 b2 c6 d6 e6 f6 0a 1a 2a 3a 4e 5e 6e 7e + // 83 93 a3 b3 c7 d7 e7 f7 0b 1b 2b 3b 4f 5f 6f 7f + // 40 50 60 70 84 94 a4 b4 c8 d8 e8 f8 0c 1c 2c 3c + // 41 51 61 71 85 95 a5 b5 c9 d9 e9 f9 0d 1d 2d 3d + // 42 52 62 72 86 96 a6 b6 ca da ea fa 0e 1e 2e 3e + // 43 53 63 73 87 97 a7 b7 cb db eb fb 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + _sum9 = _mm512_shuffle_epi32(_sum9, _MM_PERM_CBAD); + _suma = _mm512_shuffle_epi32(_suma, _MM_PERM_BADC); + _sumb = _mm512_shuffle_epi32(_sumb, _MM_PERM_ADCB); + _sumd = _mm512_shuffle_epi32(_sumd, _MM_PERM_CBAD); + _sume = _mm512_shuffle_epi32(_sume, _MM_PERM_BADC); + _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + __m512i _tmp8 = _mm512_unpacklo_epi32(_sum8, _sumb); + __m512i _tmp9 = _mm512_unpackhi_epi32(_sum8, _sumb); + __m512i _tmpa = _mm512_unpacklo_epi32(_suma, _sum9); + __m512i _tmpb = _mm512_unpackhi_epi32(_suma, _sum9); + __m512i _tmpc = _mm512_unpacklo_epi32(_sumc, _sumf); + __m512i _tmpd = _mm512_unpackhi_epi32(_sumc, _sumf); + __m512i _tmpe = _mm512_unpacklo_epi32(_sume, _sumd); + __m512i _tmpf = _mm512_unpackhi_epi32(_sume, _sumd); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum8 = _mm512_unpacklo_epi64(_tmp8, _tmpa); + _sum9 = _mm512_unpackhi_epi64(_tmp8, _tmpa); + _suma = _mm512_unpacklo_epi64(_tmpb, _tmp9); + _sumb = _mm512_unpackhi_epi64(_tmpb, _tmp9); + _sumc = _mm512_unpacklo_epi64(_tmpc, _tmpe); + _sumd = _mm512_unpackhi_epi64(_tmpc, _tmpe); + _sume = _mm512_unpacklo_epi64(_tmpf, _tmpd); + _sumf = _mm512_unpackhi_epi64(_tmpf, _tmpd); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + _sum9 = _mm512_shuffle_epi32(_sum9, _MM_PERM_CBAD); + _sumb = _mm512_shuffle_epi32(_sumb, _MM_PERM_CBAD); + _sumd = _mm512_shuffle_epi32(_sumd, _MM_PERM_CBAD); + _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sumc, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum8, _sum4, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sumc, _sum8, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum1, _sumd, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum5, _sum1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum9, _sum5, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sumd, _sum9, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmp8 = _mm512_shuffle_i32x4(_sum2, _sume, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp9 = _mm512_shuffle_i32x4(_sum6, _sum2, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmpa = _mm512_shuffle_i32x4(_suma, _sum6, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmpb = _mm512_shuffle_i32x4(_sume, _suma, _MM_SHUFFLE(1, 3, 1, 3)); + __m512i _tmpc = _mm512_shuffle_i32x4(_sum3, _sumf, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmpd = _mm512_shuffle_i32x4(_sum7, _sum3, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmpe = _mm512_shuffle_i32x4(_sumb, _sum7, _MM_SHUFFLE(0, 2, 0, 2)); + __m512i _tmpf = _mm512_shuffle_i32x4(_sumf, _sumb, _MM_SHUFFLE(1, 3, 1, 3)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp8, _tmpa, _MM_SHUFFLE(3, 1, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmpc, _tmpe, _MM_SHUFFLE(3, 1, 2, 0)); + _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp9, _tmpb, _MM_SHUFFLE(3, 1, 2, 0)); + _sum7 = _mm512_shuffle_i32x4(_tmpd, _tmpf, _MM_SHUFFLE(3, 1, 2, 0)); + _sum8 = _mm512_shuffle_i32x4(_tmp2, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum9 = _mm512_shuffle_i32x4(_tmp6, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _suma = _mm512_shuffle_i32x4(_tmpa, _tmp8, _MM_SHUFFLE(3, 1, 2, 0)); + _sumb = _mm512_shuffle_i32x4(_tmpe, _tmpc, _MM_SHUFFLE(3, 1, 2, 0)); + _sumc = _mm512_shuffle_i32x4(_tmp3, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _sumd = _mm512_shuffle_i32x4(_tmp7, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _sume = _mm512_shuffle_i32x4(_tmpb, _tmp9, _MM_SHUFFLE(3, 1, 2, 0)); + _sumf = _mm512_shuffle_i32x4(_tmpf, _tmpd, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 32), _sum2); + _mm512_store_si512((__m512i*)(outptr + 48), _sum3); + _mm512_store_si512((__m512i*)(outptr + 64), _sum4); + _mm512_store_si512((__m512i*)(outptr + 80), _sum5); + _mm512_store_si512((__m512i*)(outptr + 96), _sum6); + _mm512_store_si512((__m512i*)(outptr + 112), _sum7); + _mm512_store_si512((__m512i*)(outptr + 128), _sum8); + _mm512_store_si512((__m512i*)(outptr + 128 + 16), _sum9); + _mm512_store_si512((__m512i*)(outptr + 128 + 32), _suma); + _mm512_store_si512((__m512i*)(outptr + 128 + 48), _sumb); + _mm512_store_si512((__m512i*)(outptr + 128 + 64), _sumc); + _mm512_store_si512((__m512i*)(outptr + 128 + 80), _sumd); + _mm512_store_si512((__m512i*)(outptr + 128 + 96), _sume); + _mm512_store_si512((__m512i*)(outptr + 128 + 112), _sumf); + outptr += 256; + } + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 16 * 2)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 16 * 3)); + _sum4 = _mm512_load_si512((const __m512i*)(outptr + 16 * 4)); + _sum5 = _mm512_load_si512((const __m512i*)(outptr + 16 * 5)); + _sum6 = _mm512_load_si512((const __m512i*)(outptr + 16 * 6)); + _sum7 = _mm512_load_si512((const __m512i*)(outptr + 16 * 7)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA1, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA0, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA0, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA1, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA1, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA1, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA1, _pB3)); +#endif + + pA += 32; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m256i _pB = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pB)); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA0, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA0, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA1, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA1, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA1, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 16; + pB += 8; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 80 91 a2 b3 c4 d5 e6 f7 + // 01 12 23 30 45 56 67 74 81 92 a3 b0 c5 d6 e7 f4 + // 02 13 20 31 46 57 64 75 82 93 a0 b1 c6 d7 e4 f5 + // 03 10 21 32 47 54 65 76 83 90 a1 b2 c7 d4 e5 f6 + // 40 51 62 73 04 15 26 37 c0 d1 e2 f3 84 95 a6 b7 + // 41 52 63 70 05 16 27 34 c1 d2 e3 f0 85 96 a7 b4 + // 42 53 60 71 06 17 24 35 c2 d3 e0 f1 86 97 a4 b5 + // 43 50 61 72 07 14 25 36 c3 d0 e1 f2 87 94 a5 b6 + // to + // 00 10 20 30 44 54 64 74 80 90 a0 b0 c4 d4 e4 f4 + // 01 11 21 31 45 55 65 75 81 91 a1 b1 c5 d5 e5 f5 + // 02 12 22 32 46 56 66 76 82 92 a2 b2 c6 d6 e6 f6 + // 03 13 23 33 47 57 67 77 83 93 a3 b3 c7 d7 e7 f7 + // 40 50 60 70 04 14 24 34 c0 d0 e0 f0 84 94 a4 b4 + // 41 51 61 71 05 15 25 35 c1 d1 e1 f1 85 95 a5 b5 + // 42 52 62 72 06 16 26 36 c2 d2 e2 f2 86 96 a6 b6 + // 43 53 63 73 07 17 27 37 c3 d3 e3 f3 87 97 a7 b7 + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + } + + // TODO + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(2, 0, 2, 0)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum5, _sum1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum6, _sum2, _MM_SHUFFLE(3, 1, 3, 1)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sum7, _sum3, _MM_SHUFFLE(3, 1, 3, 1)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp1, _tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp2, _tmp2, _MM_SHUFFLE(3, 1, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmp3, _tmp3, _MM_SHUFFLE(3, 1, 2, 0)); + _sum4 = _mm512_shuffle_i32x4(_tmp4, _tmp4, _MM_SHUFFLE(3, 1, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp6, _tmp6, _MM_SHUFFLE(3, 1, 2, 0)); + _sum7 = _mm512_shuffle_i32x4(_tmp7, _tmp7, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 16 * 2), _sum2); + _mm512_store_si512((__m512i*)(outptr + 16 * 3), _sum3); + _mm512_store_si512((__m512i*)(outptr + 16 * 4), _sum4); + _mm512_store_si512((__m512i*)(outptr + 16 * 5), _sum5); + _mm512_store_si512((__m512i*)(outptr + 16 * 6), _sum6); + _mm512_store_si512((__m512i*)(outptr + 16 * 7), _sum7); + outptr += 16 * 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_load_si512((const __m512i*)(outptr + 16 * 2)); + _sum3 = _mm512_load_si512((const __m512i*)(outptr + 16 * 3)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_castsi128_si512(_mm_loadu_si128((const __m128i*)pB)); + __m512i _pB0 = _mm512_shuffle_i32x4(_pB, _pB, _MM_SHUFFLE(0, 0, 0, 0)); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); +#endif + + pA += 32; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB))); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s3 = _mm512_mullo_epi32(_pA1, _pB1); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 16; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 40 51 62 73 80 91 a2 b3 c0 d1 e2 f3 + // 01 12 23 30 41 52 63 70 81 92 a3 b0 c1 d2 e3 f0 + // 20 31 02 13 60 71 42 53 a0 b1 82 93 e0 f1 c2 d3 + // 21 32 03 10 61 72 43 50 a1 b2 83 90 e1 f2 c3 d0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + // 02 12 22 32 42 52 62 72 82 92 a2 b2 c2 d2 e2 f2 + // 03 13 23 33 43 53 63 73 83 93 a3 b3 c3 d3 e3 f3 + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + _mm512_store_si512((__m512i*)(outptr + 16 * 2), _sum2); + _mm512_store_si512((__m512i*)(outptr + 16 * 3), _sum3); + outptr += 16 * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + _sum1 = _mm512_load_si512((const __m512i*)(outptr + 16)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB0 = _mm512_castpd_si512(_mm512_set1_pd(((const double*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CDAB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA, _pB1)); +#endif + + pA += 32; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_set1_epi32(((const int*)pB)[0])); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ABAB); + + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA, _pB1); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 16; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 40 51 60 71 80 91 a0 b1 c0 d1 e0 f1 + // 01 10 21 30 41 50 61 70 81 90 a1 b0 c1 d0 e1 f0 + // to + // 00 10 20 30 40 50 60 70 80 90 a0 b0 c0 d0 e0 f0 + // 01 11 21 31 41 51 61 71 81 91 a1 b1 c1 d1 e1 f1 + { + __m512i _tmp0 = _mm512_shuffle_epi32(_sum0, _MM_PERM_DBCA); + __m512i _tmp1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_ACDB); + _sum0 = _mm512_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm512_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + } + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + _mm512_store_si512((__m512i*)(outptr + 16), _sum1); + outptr += 16 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_load_si512((const __m512i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA = _mm512_loadu_si512((const __m512i*)pA); + __m512i _pB = _mm512_set1_epi32(((const int*)pB)[0]); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA, _pB); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA, _pB)); +#endif + + pA += 32; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_cvtepi16_epi32(_mm256_load_si256((const __m256i*)pA)); + __m512i _pB = _mm512_set1_epi32(pB[0]); + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB); + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 16; + pB += 1; + } + + _mm512_store_si512((__m512i*)outptr, _sum0); + outptr += 16; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + __m512i _sum4; + __m512i _sum5; + __m512i _sum6; + __m512i _sum7; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + _sum4 = _mm512_setzero_si512(); + _sum5 = _mm512_setzero_si512(); + _sum6 = _mm512_setzero_si512(); + _sum7 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + _sum4 = _mm512_loadu_si512((const __m512i*)(outptr + 64)); + _sum5 = _mm512_loadu_si512((const __m512i*)(outptr + 80)); + _sum6 = _mm512_loadu_si512((const __m512i*)(outptr + 96)); + _sum7 = _mm512_loadu_si512((const __m512i*)(outptr + 112)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_shuffle_epi32(_pB0, _MM_PERM_BADC); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA00, _pB2); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA00, _pB3); + _sum4 = _mm512_dpwssd_epi32(_sum4, _pA11, _pB0); + _sum5 = _mm512_dpwssd_epi32(_sum5, _pA11, _pB1); + _sum6 = _mm512_dpwssd_epi32(_sum6, _pA11, _pB2); + _sum7 = _mm512_dpwssd_epi32(_sum7, _pA11, _pB3); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA00, _pB2)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA00, _pB3)); + _sum4 = _mm512_add_epi32(_sum4, _mm512_madd_epi16(_pA11, _pB0)); + _sum5 = _mm512_add_epi32(_sum5, _mm512_madd_epi16(_pA11, _pB1)); + _sum6 = _mm512_add_epi32(_sum6, _mm512_madd_epi16(_pA11, _pB2)); + _sum7 = _mm512_add_epi32(_sum7, _mm512_madd_epi16(_pA11, _pB3)); +#endif // __AVX512VNNI__ + + pA += 16; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + __m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB3 = _mm512_shuffle_epi32(_pB0, _MM_PERM_CBAD); + + __m512i _s0 = _mm512_mullo_epi32(_pA00, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA00, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA00, _pB2); + __m512i _s3 = _mm512_mullo_epi32(_pA00, _pB3); + __m512i _s4 = _mm512_mullo_epi32(_pA11, _pB0); + __m512i _s5 = _mm512_mullo_epi32(_pA11, _pB1); + __m512i _s6 = _mm512_mullo_epi32(_pA11, _pB2); + __m512i _s7 = _mm512_mullo_epi32(_pA11, _pB3); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + _sum4 = _mm512_add_epi32(_sum4, _s4); + _sum5 = _mm512_add_epi32(_sum5, _s5); + _sum6 = _mm512_add_epi32(_sum6, _s6); + _sum7 = _mm512_add_epi32(_sum7, _s7); + + pA += 8; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 08 19 2a 3b 4c 5d 6e 7f + // 01 12 23 30 45 56 67 74 09 1a 2b 38 4d 5e 6f 7c + // 02 13 20 31 46 57 64 75 0a 1b 28 39 4e 5f 6c 7d + // 03 10 21 32 47 54 65 76 0b 18 29 3a 4f 5c 6d 7e + // 40 51 62 73 04 15 26 37 48 59 6a 7b 0c 1d 2e 3f + // 41 52 63 70 05 16 27 34 49 5a 6b 78 0d 1e 2f 3c + // 42 53 60 71 06 17 24 35 4a 5b 68 79 0e 1f 2c 3d + // 43 50 61 72 07 14 25 36 4b 58 69 7a 0f 1c 2d 3e + // to + // 00 10 20 30 44 54 64 74 08 18 28 38 4c 5c 6c 7c + // 01 11 21 31 45 55 65 75 09 19 29 39 4d 5d 6d 7d + // 02 12 22 32 46 56 66 76 0a 1a 2a 3a 4e 5e 6e 7e + // 03 13 23 33 47 57 67 77 0b 1b 2b 3b 4f 5f 6f 7f + // 40 50 60 70 04 14 24 34 48 58 68 78 0c 1c 2c 3c + // 41 51 61 71 05 15 25 35 49 59 69 79 0d 1d 2d 3d + // 42 52 62 72 06 16 26 36 4a 5a 6a 7a 0e 1e 2e 3e + // 43 53 63 73 07 17 27 37 4b 5b 6b 7b 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum2 = _mm512_shuffle_epi32(_sum2, _MM_PERM_BADC); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_ADCB); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum6 = _mm512_shuffle_epi32(_sum6, _MM_PERM_BADC); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_ADCB); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + __m512i _tmp4 = _mm512_unpacklo_epi32(_sum4, _sum7); + __m512i _tmp5 = _mm512_unpackhi_epi32(_sum4, _sum7); + __m512i _tmp6 = _mm512_unpacklo_epi32(_sum6, _sum5); + __m512i _tmp7 = _mm512_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm512_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm512_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm512_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm512_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + _sum5 = _mm512_shuffle_epi32(_sum5, _MM_PERM_CBAD); + _sum7 = _mm512_shuffle_epi32(_sum7, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum0, _sum4, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum1, _sum5, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp4 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp5 = _mm512_shuffle_i32x4(_sum2, _sum6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp6 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp7 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); + _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(1, 3, 1, 3)); + _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); + _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); + _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + _mm512_storeu_si512((__m512i*)(outptr + 64), _sum4); + _mm512_storeu_si512((__m512i*)(outptr + 80), _sum5); + _mm512_storeu_si512((__m512i*)(outptr + 96), _sum6); + _mm512_storeu_si512((__m512i*)(outptr + 112), _sum7); + outptr += 128; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX512F__ + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; +#else + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + __m256i _sum4; + __m256i _sum5; + __m256i _sum6; + __m256i _sum7; +#endif // __AVX512F__ + + if (k == 0) + { +#if __AVX512F__ + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); +#else + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + _sum4 = _mm256_setzero_si256(); + _sum5 = _mm256_setzero_si256(); + _sum6 = _mm256_setzero_si256(); + _sum7 = _mm256_setzero_si256(); +#endif // __AVX512F__ + } + else + { +#if __AVX512F__ + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); +#else + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + _sum4 = _mm256_load_si256((const __m256i*)(outptr + 32)); + _sum5 = _mm256_load_si256((const __m256i*)(outptr + 40)); + _sum6 = _mm256_load_si256((const __m256i*)(outptr + 48)); + _sum7 = _mm256_load_si256((const __m256i*)(outptr + 56)); +#endif // __AVX512F__ + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); +#if __AVX512F__ + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_permutex_epi64(_pA00, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); + __m512i _pB23 = _mm512_shuffle_epi32(_pB01, _MM_PERM_BADC); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA00, _pB01); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA00, _pB23); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA11, _pB01); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA11, _pB23); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA00, _pB01)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA00, _pB23)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA11, _pB01)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA11, _pB23)); +#endif // __AVX512VNNI__ +#else // __AVX512F__ + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + +#if __AVXVNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA0, _pB2); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA0, _pB3); + _sum4 = _mm256_dpwssd_epi32(_sum4, _pA1, _pB0); + _sum5 = _mm256_dpwssd_epi32(_sum5, _pA1, _pB1); + _sum6 = _mm256_dpwssd_epi32(_sum6, _pA1, _pB2); + _sum7 = _mm256_dpwssd_epi32(_sum7, _pA1, _pB3); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA0, _pB2)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA0, _pB3)); + _sum4 = _mm256_add_epi32(_sum4, _mm256_madd_epi16(_pA1, _pB0)); + _sum5 = _mm256_add_epi32(_sum5, _mm256_madd_epi16(_pA1, _pB1)); + _sum6 = _mm256_add_epi32(_sum6, _mm256_madd_epi16(_pA1, _pB2)); + _sum7 = _mm256_add_epi32(_sum7, _mm256_madd_epi16(_pA1, _pB3)); +#endif // __AVXVNNI__ +#endif // __AVX512F__ + + pA += 16; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_load_si128((const __m128i*)pA); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); +#if __AVX512F__ + __m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1); + __m512i _pA11 = _mm512_shuffle_i32x4(_pA00, _pA00, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m512i _pB01 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB0), _pB1, 1); + __m512i _pB23 = _mm512_permutex_epi64(_pB01, _MM_SHUFFLE(2, 3, 0, 1)); + + __m512i _s01 = _mm512_mullo_epi32(_pA00, _pB01); + __m512i _s23 = _mm512_mullo_epi32(_pA00, _pB23); + __m512i _s45 = _mm512_mullo_epi32(_pA11, _pB01); + __m512i _s67 = _mm512_mullo_epi32(_pA11, _pB23); + _sum0 = _mm512_add_epi32(_sum0, _s01); + _sum1 = _mm512_add_epi32(_sum1, _s23); + _sum2 = _mm512_add_epi32(_sum2, _s45); + _sum3 = _mm512_add_epi32(_sum3, _s67); +#else + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _pB2 = _mm256_permute4x64_epi64(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB3 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 1, 0, 3)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA0, _pB2); + __m256i _s3 = _mm256_mullo_epi32(_pA0, _pB3); + __m256i _s4 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s5 = _mm256_mullo_epi32(_pA1, _pB1); + __m256i _s6 = _mm256_mullo_epi32(_pA1, _pB2); + __m256i _s7 = _mm256_mullo_epi32(_pA1, _pB3); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + _sum4 = _mm256_add_epi32(_sum4, _s4); + _sum5 = _mm256_add_epi32(_sum5, _s5); + _sum6 = _mm256_add_epi32(_sum6, _s6); + _sum7 = _mm256_add_epi32(_sum7, _s7); +#endif // __AVX512F__ + + pA += 8; + pB += 8; + } + +#if __AVX512F__ + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 01 12 23 30 45 56 67 74 + // 02 13 20 31 46 57 64 75 03 10 21 32 47 54 65 76 + // 40 51 62 73 04 15 26 37 41 52 63 70 05 16 27 34 + // 42 53 60 71 06 17 24 35 43 50 61 72 07 14 25 36 + // to + // 00 10 20 30 44 54 64 74 04 14 24 34 40 50 60 70 + // 01 11 21 31 45 55 65 75 05 15 25 35 41 51 61 71 + // 02 12 22 32 46 56 66 76 06 16 26 36 42 52 62 72 + // 03 13 23 33 47 57 67 77 07 17 27 37 43 53 63 73 + { + __m512i _s0 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(0, 1, 1, 0)); + __m512i _s1 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(2, 3, 3, 2)); + __m512i _s2 = _mm512_shuffle_i32x4(_sum1, _sum3, _MM_SHUFFLE(0, 1, 1, 0)); + __m512i _s3 = _mm512_shuffle_i32x4(_sum0, _sum2, _MM_SHUFFLE(2, 3, 3, 2)); + _s1 = _mm512_shuffle_epi32(_s1, _MM_PERM_ADCB); + _s2 = _mm512_shuffle_epi32(_s2, _MM_PERM_BADC); + _s3 = _mm512_shuffle_epi32(_s3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_s0, _s1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_s0, _s1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_s2, _s3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_s2, _s3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 0, 3, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 0, 3, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 2, 1, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 2, 1, 2)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; +#else + if (k_end) + { + // from + // 00 11 22 33 44 55 66 77 + // 01 12 23 30 45 56 67 74 + // 02 13 20 31 46 57 64 75 + // 03 10 21 32 47 54 65 76 + // 40 51 62 73 04 15 26 37 + // 41 52 63 70 05 16 27 34 + // 42 53 60 71 06 17 24 35 + // 43 50 61 72 07 14 25 36 + // to + // 00 10 20 30 44 54 64 74 + // 01 11 21 31 45 55 65 75 + // 02 12 22 32 46 56 66 76 + // 03 13 23 33 47 57 67 77 + // 40 50 60 70 04 14 24 34 + // 41 51 61 71 05 15 25 35 + // 42 52 62 72 06 16 26 36 + // 43 53 63 73 07 17 27 37 + { + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum2 = _mm256_shuffle_epi32(_sum2, _MM_SHUFFLE(1, 0, 3, 2)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(0, 3, 2, 1)); + _sum5 = _mm256_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum6 = _mm256_shuffle_epi32(_sum6, _MM_SHUFFLE(1, 0, 3, 2)); + _sum7 = _mm256_shuffle_epi32(_sum7, _MM_SHUFFLE(0, 3, 2, 1)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + __m256i _tmp4 = _mm256_unpacklo_epi32(_sum4, _sum7); + __m256i _tmp5 = _mm256_unpackhi_epi32(_sum4, _sum7); + __m256i _tmp6 = _mm256_unpacklo_epi32(_sum6, _sum5); + __m256i _tmp7 = _mm256_unpackhi_epi32(_sum6, _sum5); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum4 = _mm256_unpacklo_epi64(_tmp4, _tmp6); + _sum5 = _mm256_unpackhi_epi64(_tmp4, _tmp6); + _sum6 = _mm256_unpacklo_epi64(_tmp7, _tmp5); + _sum7 = _mm256_unpackhi_epi64(_tmp7, _tmp5); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum5 = _mm256_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm256_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + } + + __m256i _tmp0 = _mm256_permute2x128_si256(_sum0, _sum4, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp1 = _mm256_permute2x128_si256(_sum1, _sum5, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp2 = _mm256_permute2x128_si256(_sum2, _sum6, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp3 = _mm256_permute2x128_si256(_sum3, _sum7, _MM_SHUFFLE(0, 2, 0, 0)); + __m256i _tmp4 = _mm256_permute2x128_si256(_sum4, _sum0, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp5 = _mm256_permute2x128_si256(_sum5, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp6 = _mm256_permute2x128_si256(_sum6, _sum2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256i _tmp7 = _mm256_permute2x128_si256(_sum7, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + _sum4 = _tmp4; + _sum5 = _tmp5; + _sum6 = _tmp6; + _sum7 = _tmp7; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 8 * 2), _sum2); + _mm256_store_si256((__m256i*)(outptr + 8 * 3), _sum3); + _mm256_store_si256((__m256i*)(outptr + 8 * 4), _sum4); + _mm256_store_si256((__m256i*)(outptr + 8 * 5), _sum5); + _mm256_store_si256((__m256i*)(outptr + 8 * 6), _sum6); + _mm256_store_si256((__m256i*)(outptr + 8 * 7), _sum7); + outptr += 8 * 8; +#endif // __AVX512F__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_load_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_load_si256((const __m256i*)(outptr + 24)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); +#endif + + pA += 16; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m256i _pA0 = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB0 = _mm256_cvtepi16_epi32(_mm_castpd_si128(_mm_load1_pd((const double*)pB))); + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + + pA += 8; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 40 51 62 73 + // 01 12 23 30 41 52 63 70 + // 20 31 02 13 60 71 42 53 + // 21 32 03 10 61 72 43 50 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + { + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + _mm256_store_si256((__m256i*)(outptr + 16), _sum2); + _mm256_store_si256((__m256i*)(outptr + 24), _sum3); + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m256i _sum0; + __m256i _sum1; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + _sum1 = _mm256_load_si256((const __m256i*)(outptr + 8)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB0 = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pB)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA, _pB1)); +#endif + + pA += 16; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB0 = _mm256_cvtepi16_epi32(_mm_castps_si128(_mm_load1_ps((const float*)pB))); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 1, 0, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA, _pB1); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + + pA += 8; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 40 51 60 71 + // 01 10 21 30 41 50 61 70 + // to + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + { + __m256i _tmp0 = _mm256_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _tmp1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1)); + _sum0 = _mm256_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm256_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + _mm256_store_si256((__m256i*)(outptr + 8), _sum1); + outptr += 16; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m256i _sum0; + + if (k == 0) + { + _sum0 = _mm256_setzero_si256(); + } + else + { + _sum0 = _mm256_load_si256((const __m256i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m256i _pA = _mm256_loadu_si256((const __m256i*)pA); + __m256i _pB = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pB)); + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA, _pB)); + + pA += 16; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_cvtepi16_epi32(_mm_load_si128((const __m128i*)pA)); + __m256i _pB = _mm256_set1_epi32(pB[0]); + __m256i _s0 = _mm256_mullo_epi32(_pA, _pB); + _sum0 = _mm256_add_epi32(_sum0, _s0); + + pA += 8; + pB += 1; + } + + _mm256_store_si256((__m256i*)outptr, _sum0); + outptr += 8; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + __m512i _sum2; + __m512i _sum3; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + _sum2 = _mm512_loadu_si512((const __m512i*)(outptr + 32)); + _sum3 = _mm512_loadu_si512((const __m512i*)(outptr + 48)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + __m256i _pAA = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1); + __m512i _pA0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pAA), _pAA, 1); + __m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm512_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm512_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA0, _pB1)); + _sum2 = _mm512_add_epi32(_sum2, _mm512_madd_epi16(_pA1, _pB0)); + _sum3 = _mm512_add_epi32(_sum3, _mm512_madd_epi16(_pA1, _pB1)); +#endif + + pA += 8; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m256i _pA = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)pA)); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + + __m512i _pA0 = _mm512_cvtepi16_epi32(_pA); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + __m512i _pA1 = _mm512_permutex_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA0, _pB1); + __m512i _s2 = _mm512_mullo_epi32(_pA1, _pB0); + __m512i _s3 = _mm512_mullo_epi32(_pA1, _pB1); + + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + pA += 4; + pB += 16; + } + + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 08 19 2a 3b 0c 1d 2e 3f + // 01 12 23 30 05 16 27 34 09 1a 2b 38 0d 1e 2f 3c + // 20 31 02 13 24 35 06 17 28 3a 0a 1b 2c 3d 0e 1f + // 21 32 03 10 25 36 07 14 29 3a 0b 18 2d 3e 0f 1c + // to + // 00 10 20 30 04 14 24 34 08 18 28 38 0c 1c 2c 3c + // 01 11 21 31 05 15 25 35 09 19 29 39 0d 1d 2d 3d + // 02 12 22 32 06 16 26 36 0a 1a 2a 3a 0e 1e 2e 3e + // 03 13 23 33 07 17 27 37 0b 1b 2b 3b 0f 1f 2f 3f + { + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum3); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum3); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum1); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm512_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm512_shuffle_epi32(_sum1, _MM_PERM_CBAD); + _sum3 = _mm512_shuffle_epi32(_sum3, _MM_PERM_CBAD); + } + + __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp1 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512i _tmp2 = _mm512_shuffle_i32x4(_sum0, _sum1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512i _tmp3 = _mm512_shuffle_i32x4(_sum2, _sum3, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + _sum3 = _mm512_shuffle_i32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + _mm512_storeu_si512((__m512i*)(outptr + 32), _sum2); + _mm512_storeu_si512((__m512i*)(outptr + 48), _sum3); + outptr += 64; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; + __m256i _sum2; + __m256i _sum3; +#else + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + __m128i _sum4; + __m128i _sum5; + __m128i _sum6; + __m128i _sum7; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + _sum4 = _mm_setzero_si128(); + _sum5 = _mm_setzero_si128(); + _sum6 = _mm_setzero_si128(); + _sum7 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); + _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8)); + _sum2 = _mm256_loadu_si256((const __m256i*)(outptr + 16)); + _sum3 = _mm256_loadu_si256((const __m256i*)(outptr + 24)); +#else + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + _sum4 = _mm_load_si128((const __m128i*)(outptr + 16)); + _sum5 = _mm_load_si128((const __m128i*)(outptr + 20)); + _sum6 = _mm_load_si128((const __m128i*)(outptr + 24)); + _sum7 = _mm_load_si128((const __m128i*)(outptr + 28)); +#endif + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); + __m256i _pA0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1); + __m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _sum0 = _mm256_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm256_dpwssd_epi32(_sum1, _pA0, _pB1); + _sum2 = _mm256_dpwssd_epi32(_sum2, _pA1, _pB0); + _sum3 = _mm256_dpwssd_epi32(_sum3, _pA1, _pB1); +#else + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA0, _pB1)); + _sum2 = _mm256_add_epi32(_sum2, _mm256_madd_epi16(_pA1, _pB0)); + _sum3 = _mm256_add_epi32(_sum3, _mm256_madd_epi16(_pA1, _pB1)); +#endif +#else // __AVX2__ + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maddd_epi16(_pA0, _pB2, _sum2); + _sum3 = _mm_maddd_epi16(_pA0, _pB3, _sum3); + _sum4 = _mm_maddd_epi16(_pA1, _pB0, _sum4); + _sum5 = _mm_maddd_epi16(_pA1, _pB1, _sum5); + _sum6 = _mm_maddd_epi16(_pA1, _pB2, _sum6); + _sum7 = _mm_maddd_epi16(_pA1, _pB3, _sum7); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA0, _pB2)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA0, _pB3)); + _sum4 = _mm_add_epi32(_sum4, _mm_madd_epi16(_pA1, _pB0)); + _sum5 = _mm_add_epi32(_sum5, _mm_madd_epi16(_pA1, _pB1)); + _sum6 = _mm_add_epi32(_sum6, _mm_madd_epi16(_pA1, _pB2)); + _sum7 = _mm_add_epi32(_sum7, _mm_madd_epi16(_pA1, _pB3)); +#endif +#endif // __AVX2__ + + pA += 8; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + +#if __AVX2__ + __m256i _pA0 = _mm256_cvtepi16_epi32(_pA); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); + __m256i _pA1 = _mm256_permute4x64_epi64(_pA0, _MM_SHUFFLE(2, 3, 0, 1)); + __m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA0, _pB1); + __m256i _s2 = _mm256_mullo_epi32(_pA1, _pB0); + __m256i _s3 = _mm256_mullo_epi32(_pA1, _pB1); + + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); +#else // __AVX2__ +#if __XOP__ + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_unpackhi_epi16(_pB, _pB); + __m128i _pB2 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _pB3 = _mm_shuffle_epi32(_pB1, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA0, _pB2, _sum2); + _sum3 = _mm_maccd_epi16(_pA0, _pB3, _sum3); + _sum4 = _mm_maccd_epi16(_pA1, _pB0, _sum4); + _sum5 = _mm_maccd_epi16(_pA1, _pB1, _sum5); + _sum6 = _mm_maccd_epi16(_pA1, _pB2, _sum6); + _sum7 = _mm_maccd_epi16(_pA1, _pB3, _sum7); +#else + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB01 = _pB; + __m128i _pB23 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA0, _pB23); + __m128i _sh1 = _mm_mulhi_epi16(_pA0, _pB23); + __m128i _sl2 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh2 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _sl3 = _mm_mullo_epi16(_pA1, _pB23); + __m128i _sh3 = _mm_mulhi_epi16(_pA1, _pB23); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + __m128i _s4 = _mm_unpacklo_epi16(_sl2, _sh2); + __m128i _s5 = _mm_unpackhi_epi16(_sl2, _sh2); + __m128i _s6 = _mm_unpacklo_epi16(_sl3, _sh3); + __m128i _s7 = _mm_unpackhi_epi16(_sl3, _sh3); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); + _sum4 = _mm_add_epi32(_sum4, _s4); + _sum5 = _mm_add_epi32(_sum5, _s5); + _sum6 = _mm_add_epi32(_sum6, _s6); + _sum7 = _mm_add_epi32(_sum7, _s7); +#endif +#endif // __AVX2__ + + pA += 4; + pB += 8; + } + +#if __AVX2__ + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + // to + // 00 10 20 30 04 14 24 34 + // 01 11 21 31 05 15 25 35 + // 02 12 22 32 06 16 26 36 + // 03 13 23 33 07 17 27 37 + { + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum3); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum3); + __m256i _tmp2 = _mm256_unpacklo_epi32(_sum2, _sum1); + __m256i _tmp3 = _mm256_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm256_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm256_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm256_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm256_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm256_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm256_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _tmp0 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 2, 0, 0)); + _tmp1 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 2, 0, 0)); + _tmp2 = _mm256_permute2x128_si256(_sum0, _sum1, _MM_SHUFFLE(0, 3, 0, 1)); + _tmp3 = _mm256_permute2x128_si256(_sum2, _sum3, _MM_SHUFFLE(0, 3, 0, 1)); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + } + + _mm256_storeu_si256((__m256i*)outptr, _sum0); + _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1); + _mm256_storeu_si256((__m256i*)(outptr + 16), _sum2); + _mm256_storeu_si256((__m256i*)(outptr + 24), _sum3); + outptr += 32; +#else // __AVX2__ + if (k_end) + { + // from + // 00 11 22 33 04 15 26 37 + // 01 12 23 30 05 16 27 34 + // 20 31 02 13 24 35 06 17 + // 21 32 03 10 25 36 07 14 + // to + // 00 10 20 30 04 14 24 34 + // 01 11 21 31 05 15 25 35 + // 02 12 22 32 06 16 26 36 + // 03 13 23 33 07 17 27 37 + { + _sum2 = _mm_shuffle_epi32(_sum2, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum6 = _mm_shuffle_epi32(_sum6, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum6); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum6); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum7); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum7); + __m128i _tmp4 = _mm_unpacklo_epi32(_sum4, _sum2); + __m128i _tmp5 = _mm_unpackhi_epi32(_sum4, _sum2); + __m128i _tmp6 = _mm_unpacklo_epi32(_sum5, _sum3); + __m128i _tmp7 = _mm_unpackhi_epi32(_sum5, _sum3); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp4); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp4); + _sum2 = _mm_unpacklo_epi64(_tmp5, _tmp1); + _sum3 = _mm_unpackhi_epi64(_tmp5, _tmp1); + _sum4 = _mm_unpacklo_epi64(_tmp2, _tmp6); + _sum5 = _mm_unpackhi_epi64(_tmp2, _tmp6); + _sum6 = _mm_unpacklo_epi64(_tmp7, _tmp3); + _sum7 = _mm_unpackhi_epi64(_tmp7, _tmp3); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + _sum5 = _mm_shuffle_epi32(_sum5, _MM_SHUFFLE(2, 1, 0, 3)); + _sum7 = _mm_shuffle_epi32(_sum7, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + _mm_store_si128((__m128i*)(outptr + 16), _sum4); + _mm_store_si128((__m128i*)(outptr + 20), _sum5); + _mm_store_si128((__m128i*)(outptr + 24), _sum6); + _mm_store_si128((__m128i*)(outptr + 28), _sum7); + outptr += 32; +#endif // __AVX2__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_load_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_load_si128((const __m128i*)(outptr + 12)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maddd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maddd_epi16(_pA1, _pB1, _sum3); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); +#endif + + pA += 8; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); +#if __XOP__ + __m128i _pA0 = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pA1 = _mm_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2)); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1)); + _sum0 = _mm_maccd_epi16(_pA0, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA0, _pB1, _sum1); + _sum2 = _mm_maccd_epi16(_pA1, _pB0, _sum2); + _sum3 = _mm_maccd_epi16(_pA1, _pB1, _sum3); +#else + __m128i _pA0 = _pA; + __m128i _pA1 = _mm_shuffle_epi32(_pA, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 3, 2, 1)); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB01); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB01); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB01); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif + + pA += 4; + pB += 4; + } + + if (k_end) + { + // from + // 00 11 22 33 + // 01 12 23 30 + // 20 31 02 13 + // 21 32 03 10 + // to + // 00 10 20 30 + // 01 11 21 31 + // 02 12 22 32 + // 03 13 23 33 + { + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum3); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum3); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum2, _sum1); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum2, _sum1); + _sum0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm_unpacklo_epi64(_tmp3, _tmp1); + _sum3 = _mm_unpackhi_epi64(_tmp3, _tmp1); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + _sum3 = _mm_shuffle_epi32(_sum3, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + _mm_store_si128((__m128i*)(outptr + 8), _sum2); + _mm_store_si128((__m128i*)(outptr + 12), _sum3); + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + _sum1 = _mm_load_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB0 = _mm_castpd_si128(_mm_load1_pd((const double*)pB)); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA, _pB0, _sum0); + _sum1 = _mm_maddd_epi16(_pA, _pB1, _sum1); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); +#endif + + pA += 8; + pB += 4; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_castpd_si128(_mm_load1_pd((const double*)pA)); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __XOP__ + _pA = _mm_unpacklo_epi16(_pA, _pA); + __m128i _pB0 = _mm_unpacklo_epi16(_pB, _pB); + __m128i _pB1 = _mm_shuffle_epi32(_pB0, _MM_SHUFFLE(2, 3, 0, 1)); + _sum0 = _mm_maccd_epi16(_pA, _pB0, _sum0); + _sum1 = _mm_maccd_epi16(_pA, _pB1, _sum1); +#else + __m128i _pB01 = _mm_shufflehi_epi16(_pB, _MM_SHUFFLE(0, 1, 0, 1)); + __m128i _sl = _mm_mullo_epi16(_pA, _pB01); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB01); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); +#endif + + pA += 4; + pB += 2; + } + + if (k_end) + { + // from + // 00 11 20 31 + // 01 10 21 30 + // to + // 00 10 20 30 + // 01 11 21 31 + { + __m128i _tmp0 = _mm_shuffle_epi32(_sum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _tmp1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(0, 2, 3, 1)); + _sum0 = _mm_unpacklo_epi32(_tmp0, _tmp1); + _sum1 = _mm_unpackhi_epi32(_tmp0, _tmp1); + _sum1 = _mm_shuffle_epi32(_sum1, _MM_SHUFFLE(2, 1, 0, 3)); + } + } + + _mm_store_si128((__m128i*)outptr, _sum0); + _mm_store_si128((__m128i*)(outptr + 4), _sum1); + outptr += 8; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_load_si128((const __m128i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_loadu_si128((const __m128i*)pA); + __m128i _pB = _mm_castps_si128(_mm_load1_ps((const float*)pB)); + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_pA, _pB, _sum0); +#else + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); +#endif + + pA += 8; + pB += 2; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_loadl_epi64((const __m128i*)pA); + __m128i _pB = _mm_set1_epi16(pB[0]); + +#if __XOP__ + _pA = _mm_unpacklo_epi16(_pA, _pA); + _sum0 = _mm_maccd_epi16(_pA, _pB, _sum0); +#else + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); +#endif + + pA += 4; + pB += 1; + } + + _mm_store_si128((__m128i*)outptr, _sum0); + outptr += 4; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + __m512i _sum1; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + _sum1 = _mm512_loadu_si512((const __m512i*)(outptr + 16)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pA1 = _mm512_set1_epi32(((const int*)pA)[1]); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); + _sum1 = _mm512_dpwssd_epi32(_sum1, _pA1, _pB0); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); + _sum1 = _mm512_add_epi32(_sum1, _mm512_madd_epi16(_pA1, _pB0)); +#endif // __AVX512VNNI__ + + pA += 4; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA0 = _mm512_set1_epi32(pA[0]); + __m512i _pA1 = _mm512_set1_epi32(pA[1]); + __m256i _pB = _mm256_loadu_si256((const __m256i*)pB); + __m512i _pB0 = _mm512_cvtepi16_epi32(_pB); + + __m512i _s0 = _mm512_mullo_epi32(_pA0, _pB0); + __m512i _s1 = _mm512_mullo_epi32(_pA1, _pB0); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + pA += 2; + pB += 16; + } + + if (k_end) + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(1, 0, 1, 0)); + _sum1 = _mm512_shuffle_i32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 2, 3, 2)); + _sum0 = _mm512_shuffle_i32x4(_sum0, _sum0, _MM_SHUFFLE(3, 1, 2, 0)); + _sum1 = _mm512_shuffle_i32x4(_sum1, _sum1, _MM_SHUFFLE(3, 1, 2, 0)); + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + _mm512_storeu_si512((__m512i*)(outptr + 16), _sum1); + outptr += 32; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + +#if __AVX2__ + __m256i _sum0; + __m256i _sum1; +#else + __m128i _sum0; + __m128i _sum1; + __m128i _sum2; + __m128i _sum3; +#endif + + if (k == 0) + { +#if __AVX2__ + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#else + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); +#endif + } + else + { +#if __AVX2__ + _sum0 = _mm256_loadu_si256((const __m256i*)outptr); + _sum1 = _mm256_loadu_si256((const __m256i*)(outptr + 8)); +#else + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + _sum2 = _mm_loadu_si128((const __m128i*)(outptr + 8)); + _sum3 = _mm_loadu_si128((const __m128i*)(outptr + 12)); +#endif + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { +#if __AVX2__ + __m256i _pA0 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)pA)); + __m256i _pA1 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(pA + 2))); + __m256i _pB0 = _mm256_loadu_si256((const __m256i*)pB); + + // vs2019 internal compiler error with avx512 vnni intrinsics here + // fallback to avx2 madd anyway as a workaround --- nihui + _sum0 = _mm256_add_epi32(_sum0, _mm256_madd_epi16(_pA0, _pB0)); + _sum1 = _mm256_add_epi32(_sum1, _mm256_madd_epi16(_pA1, _pB0)); +#else // __AVX2__ + __m128i _pA0 = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pA1 = _mm_castps_si128(_mm_load1_ps((const float*)(pA + 2))); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA0, _pB1)); + _sum2 = _mm_add_epi32(_sum2, _mm_madd_epi16(_pA1, _pB0)); + _sum3 = _mm_add_epi32(_sum3, _mm_madd_epi16(_pA1, _pB1)); +#endif // __AVX2__ + + pA += 4; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pB = _mm_load_si128((const __m128i*)pB); +#if __AVX2__ + __m256i _pA0 = _mm256_set1_epi32(pA[0]); + __m256i _pA1 = _mm256_set1_epi32(pA[1]); + __m256i _pB0 = _mm256_cvtepi16_epi32(_pB); + + __m256i _s0 = _mm256_mullo_epi32(_pA0, _pB0); + __m256i _s1 = _mm256_mullo_epi32(_pA1, _pB0); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); +#else // __AVX2__ + __m128i _pA0 = _mm_set1_epi16(pA[0]); + __m128i _pA1 = _mm_set1_epi16(pA[1]); + + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpackhi_epi16(_sl0, _sh0); + __m128i _s2 = _mm_unpacklo_epi16(_sl1, _sh1); + __m128i _s3 = _mm_unpackhi_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif // __AVX2__ + pA += 2; + pB += 8; + } + +#if __AVX2__ + if (k_end) + { + __m256i _tmp0 = _mm256_unpacklo_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_sum0, _sum1); + _sum0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _sum1 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); + } + + _mm256_storeu_si256((__m256i*)outptr, _sum0); + _mm256_storeu_si256((__m256i*)(outptr + 8), _sum1); + outptr += 16; +#else // __AVX2__ + if (k_end) + { + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum2); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum2); + __m128i _tmp2 = _mm_unpacklo_epi32(_sum1, _sum3); + __m128i _tmp3 = _mm_unpackhi_epi32(_sum1, _sum3); + _sum0 = _tmp0; + _sum1 = _tmp1; + _sum2 = _tmp2; + _sum3 = _tmp3; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + _mm_storeu_si128((__m128i*)(outptr + 8), _sum2); + _mm_storeu_si128((__m128i*)(outptr + 12), _sum3); + outptr += 16; +#endif // __AVX2__ + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA0 = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pA1 = _mm_castps_si128(_mm_load1_ps((const float*)(pA + 2))); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA0, _pB)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA1, _pB)); + + pA += 4; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA0 = _mm_set1_epi16(pA[0]); + __m128i _pA1 = _mm_set1_epi16(pA[1]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + __m128i _sl0 = _mm_mullo_epi16(_pA0, _pB); + __m128i _sh0 = _mm_mulhi_epi16(_pA0, _pB); + __m128i _sl1 = _mm_mullo_epi16(_pA1, _pB); + __m128i _sh1 = _mm_mulhi_epi16(_pA1, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl0, _sh0); + __m128i _s1 = _mm_unpacklo_epi16(_sl1, _sh1); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + pA += 2; + pB += 4; + } + + if (k_end) + { + __m128i _tmp0 = _mm_unpacklo_epi32(_sum0, _sum1); + __m128i _tmp1 = _mm_unpackhi_epi32(_sum0, _sum1); + _sum0 = _tmp0; + _sum1 = _tmp1; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + outptr += 2 * 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum00 = 0; + int sum01 = 0; + int sum10 = 0; + int sum11 = 0; + + if (k == 0) + { + sum00 = 0; + sum01 = 0; + sum10 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum00 += pA[0] * pB[0]; + sum00 += pA[1] * pB[1]; + sum01 += pA[2] * pB[0]; + sum01 += pA[3] * pB[1]; + sum10 += pA[0] * pB[2]; + sum10 += pA[1] * pB[3]; + sum11 += pA[2] * pB[2]; + sum11 += pA[3] * pB[3]; + + pA += 4; + pB += 4; + } + for (; kk < max_kk; kk++) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + outptr += 2 * 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[2] * pB[0]; + sum1 += pA[3] * pB[1]; + pA += 4; + pB += 2; + } + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + } + } + for (; ii < max_ii; ii++) + { + for (int b = 0; b < batch; b++) + { + const short* pAT = AT_tile.row(b) + max_kk * ii; + const short* pB = BT_tile.row(b); + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + const short* pA = pAT; + + __m512i _sum0; + + if (k == 0) + { + _sum0 = _mm512_setzero_si512(); + } + else + { + _sum0 = _mm512_loadu_si512((const __m512i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _pA0 = _mm512_set1_epi32(((const int*)pA)[0]); + __m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB); + +#if __AVX512VNNI__ + _sum0 = _mm512_dpwssd_epi32(_sum0, _pA0, _pB0); +#else + _sum0 = _mm512_add_epi32(_sum0, _mm512_madd_epi16(_pA0, _pB0)); +#endif + + pA += 2; + pB += 32; + } + for (; kk < max_kk; kk++) + { + __m512i _pA = _mm512_set1_epi32(pA[0]); + __m512i _pB0 = _mm512_cvtepi16_epi32(_mm256_loadu_si256((const __m256i*)pB)); + + __m512i _s0 = _mm512_mullo_epi32(_pA, _pB0); + _sum0 = _mm512_add_epi32(_sum0, _s0); + + pA += 1; + pB += 16; + } + + _mm512_storeu_si512((__m512i*)outptr, _sum0); + outptr += 16; + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + const short* pA = pAT; + + __m128i _sum0; + __m128i _sum1; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + _sum1 = _mm_loadu_si128((const __m128i*)(outptr + 4)); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB0 = _mm_loadu_si128((const __m128i*)pB); + __m128i _pB1 = _mm_loadu_si128((const __m128i*)(pB + 8)); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB0)); + _sum1 = _mm_add_epi32(_sum1, _mm_madd_epi16(_pA, _pB1)); + + pA += 2; + pB += 16; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_load_si128((const __m128i*)pB); + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + __m128i _s1 = _mm_unpackhi_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + pA += 1; + pB += 8; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + _mm_storeu_si128((__m128i*)(outptr + 4), _sum1); + outptr += 8; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + const short* pA = pAT; + + __m128i _sum0; + + if (k == 0) + { + _sum0 = _mm_setzero_si128(); + } + else + { + _sum0 = _mm_loadu_si128((const __m128i*)outptr); + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + __m128i _pA = _mm_castps_si128(_mm_load1_ps((const float*)pA)); + __m128i _pB = _mm_loadu_si128((const __m128i*)pB); + _sum0 = _mm_add_epi32(_sum0, _mm_madd_epi16(_pA, _pB)); + + pA += 2; + pB += 8; + } + for (; kk < max_kk; kk++) + { + __m128i _pA = _mm_set1_epi16(pA[0]); + __m128i _pB = _mm_loadl_epi64((const __m128i*)pB); + __m128i _sl = _mm_mullo_epi16(_pA, _pB); + __m128i _sh = _mm_mulhi_epi16(_pA, _pB); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + _sum0 = _mm_add_epi32(_sum0, _s0); + pA += 1; + pB += 4; + } + + _mm_storeu_si128((__m128i*)outptr, _sum0); + outptr += 4; + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + const short* pA = pAT; + + int sum0 = 0; + int sum1 = 0; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + int kk = 0; + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[0] * pB[2]; + sum1 += pA[1] * pB[3]; + pA += 2; + pB += 4; + } + for (; kk < max_kk; kk++) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + outptr += 2; + } + for (; jj < max_jj; jj++) + { + const short* pA = pAT; + + int sum = 0; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + int kk = 0; + for (; kk < max_kk; kk++) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + outptr += 1; + } + } + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size_int8 = (int)(get_cpu_level2_cache_size() / sizeof(short)); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + // solve M + { + int tile_size = (int)sqrt((float)l2_cache_size_int8 / 3); + +#if __AVX512F__ + TILE_M = std::max(16, tile_size / 16 * 16); +#elif __AVX2__ + TILE_M = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_M = std::max(4, tile_size / 4 * 4); +#else + TILE_M = std::max(2, tile_size / 2 * 2); +#endif + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + int nn_M = (M + TILE_M - 1) / TILE_M; +#if __AVX512F__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 15) / 16 * 16); +#elif __AVX2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 1) / 2 * 2); +#endif + + if (nT > 1) + { +#if __AVX512F__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 15) / 16 * 16); +#elif __AVX2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); +#elif __SSE2__ + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 3) / 4 * 4); +#else + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 1) / 2 * 2); +#endif + } + } + + // solve K + { + int tile_size = (int)(sqrt((float)l2_cache_size_int8) - TILE_M); + +#if __AVX512F__ + TILE_K = std::max(16, tile_size / 16 * 16); +#elif __AVX2__ + TILE_K = std::max(8, tile_size / 8 * 8); +#elif __SSE2__ + TILE_K = std::max(4, tile_size / 4 * 4); +#else + TILE_K = std::max(2, tile_size / 2 * 2); +#endif + + int nn_K = (K + TILE_K - 1) / TILE_K; +#if __AVX512F__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 15) / 16 * 16); +#elif __AVX2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); +#elif __SSE2__ + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 3) / 4 * 4); +#else + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 1) / 2 * 2); +#endif + } + + if (N > 0) + { + int tile_size = (int)((l2_cache_size_int8 - TILE_M * TILE_K) / (TILE_M * 2 + TILE_K)); + +#if __SSE2__ + TILE_N = std::max(4, tile_size / 4 * 4); +#else + TILE_N = std::max(1, tile_size); +#endif + + int nn_N = (N + TILE_N - 1) / TILE_N; +#if __SSE2__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#else + TILE_N = std::min(TILE_N, (N + nn_N - 1) / nn_N); +#endif + } +} + +static inline void conv3x3s1_winograd23_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const signed char ktm[4][3] = { + // {2, 0, 0}, + // {1, 1, 1}, + // {1, -1, 1}, + // {0, 0, 2} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[4][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 2; + tmp[1][m] = r0 + r1 + r2; + tmp[2][m] = r0 - r1 + r2; + tmp[3][m] = r2 * 2; + + k0 += 3; + } + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 2; + short z1 = r0 + r1 + r2; + short z2 = r0 - r1 + r2; + short z3 = r2 * 2; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp += 4; + } + } + } +} + +static void conv3x3s1_winograd23_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd23_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); + return; + } +#endif +#endif + + const int M = outch; + const int K = inch; + const int B = 16; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 2u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 2u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd23_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd23_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const signed char itm[4][4] = { + // {1, 0, -1, 0}, + // {0, 1, 1, 0}, + // {0, -1, 1, 0}, + // {0, -1, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w - 1) / 2; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __SSE2__ +#if __AVX512F__ + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + short tmp[4][4][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_setzero_si256(); + __m256i _r1 = _mm256_setzero_si256(); + __m256i _r2 = _mm256_setzero_si256(); + __m256i _r3 = _mm256_setzero_si256(); + + if (ti * 2 + m < h) + { + if (elempack == 16) + { + _r0 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)r0)); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 16))); + if (tj * 2 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 32))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 48))); + } + if (elempack == 8) + { + const signed char* r1 = r0 + N; + + _r0 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)r0), _mm_loadl_epi64((const __m128i*)r1))); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 8)), _mm_loadl_epi64((const __m128i*)(r1 + 8)))); + if (tj * 2 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 16)), _mm_loadl_epi64((const __m128i*)(r1 + 16)))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 24)), _mm_loadl_epi64((const __m128i*)(r1 + 24)))); + } + if (elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _r0 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)r0, sizeof(signed char)))); + if (tj * 2 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 1), sizeof(signed char)))); + if (tj * 2 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 2), sizeof(signed char)))); + if (tj * 2 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 3), sizeof(signed char)))); + } + } + + __m256i _tmp0 = _mm256_sub_epi16(_r0, _r2); + __m256i _tmp1 = _mm256_add_epi16(_r1, _r2); + __m256i _tmp2 = _mm256_sub_epi16(_r2, _r1); + __m256i _tmp3 = _mm256_sub_epi16(_r3, _r1); + + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 16; + short* p1 = p0 + max_jj * 16; + short* p2 = p0 + max_jj * 16 * 2; + short* p3 = p0 + max_jj * 16 * 3; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + + __m256i _tmp0 = _mm256_sub_epi16(_r0, _r2); + __m256i _tmp1 = _mm256_add_epi16(_r1, _r2); + __m256i _tmp2 = _mm256_sub_epi16(_r2, _r1); + __m256i _tmp3 = _mm256_sub_epi16(_r3, _r1); + + _mm256_store_si256((__m256i*)p0, _tmp0); + _mm256_store_si256((__m256i*)p1, _tmp1); + _mm256_store_si256((__m256i*)p2, _tmp2); + _mm256_store_si256((__m256i*)p3, _tmp3); + + p0 += max_jj * 4 * 16; + p1 += max_jj * 4 * 16; + p2 += max_jj * 4 * 16; + p3 += max_jj * 4 * 16; + } + } + } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) +#endif // __AVX512F__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + short tmp[4][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 2) + (tj * 2) * elempack; + + for (int m = 0; m < 4; m++) + { + __m128i _r0 = _mm_setzero_si128(); + __m128i _r1 = _mm_setzero_si128(); + __m128i _r2 = _mm_setzero_si128(); + __m128i _r3 = _mm_setzero_si128(); + + if (ti * 2 + m < h) + { + if (elempack == 8) + { + _r0 = _mm_loadl_epi64((const __m128i*)r0); + _r0 = _mm_unpacklo_epi8(_r0, _mm_cmpgt_epi8(_mm_setzero_si128(), _r0)); + if (tj * 2 + 1 < w) + { + _r1 = _mm_loadl_epi64((const __m128i*)(r0 + 8)); + _r1 = _mm_unpacklo_epi8(_r1, _mm_cmpgt_epi8(_mm_setzero_si128(), _r1)); + } + if (tj * 2 + 2 < w) + { + _r2 = _mm_loadl_epi64((const __m128i*)(r0 + 16)); + _r2 = _mm_unpacklo_epi8(_r2, _mm_cmpgt_epi8(_mm_setzero_si128(), _r2)); + } + if (tj * 2 + 3 < w) + { + _r3 = _mm_loadl_epi64((const __m128i*)(r0 + 24)); + _r3 = _mm_unpacklo_epi8(_r3, _mm_cmpgt_epi8(_mm_setzero_si128(), _r3)); + } + } + if (elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); +#if __AVX512F__ + _r0 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)))); + if (tj * 2 + 1 < w) _r1 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)))); + if (tj * 2 + 2 < w) _r2 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)))); + if (tj * 2 + 3 < w) _r3 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)))); +#else + __m128i _sindex8 = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m256i _sindex88 = _mm256_inserti128_si256(_mm256_castsi128_si256(_sindex8), _sindex8, 1); + __m256i _val0_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)), _sindex88); + _r0 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val0_32, 0), _mm256_extracti128_si256(_val0_32, 1))); + if (tj * 2 + 1 < w) + { + __m256i _val1_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)), _sindex88); + _r1 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val1_32, 0), _mm256_extracti128_si256(_val1_32, 1))); + } + if (tj * 2 + 2 < w) + { + __m256i _val2_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)), _sindex88); + _r2 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val2_32, 0), _mm256_extracti128_si256(_val2_32, 1))); + } + if (tj * 2 + 3 < w) + { + __m256i _val3_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)), _sindex88); + _r3 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val3_32, 0), _mm256_extracti128_si256(_val3_32, 1))); + } +#endif // __AVX512F__ +#else // __AVX2__ + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + __m128i _t0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _t1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _t2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _t3 = _mm_loadl_epi64((const __m128i*)r3); + __m128i _t4 = _mm_loadl_epi64((const __m128i*)r4); + __m128i _t5 = _mm_loadl_epi64((const __m128i*)r5); + __m128i _t6 = _mm_loadl_epi64((const __m128i*)r6); + __m128i _t7 = _mm_loadl_epi64((const __m128i*)r7); + + __m128i _t01 = _mm_unpacklo_epi8(_t0, _t1); + __m128i _t23 = _mm_unpacklo_epi8(_t2, _t3); + __m128i _t45 = _mm_unpacklo_epi8(_t4, _t5); + __m128i _t67 = _mm_unpacklo_epi8(_t6, _t7); + _t0 = _mm_unpacklo_epi16(_t01, _t23); + _t1 = _mm_unpacklo_epi16(_t45, _t67); + _t2 = _mm_unpacklo_epi32(_t0, _t1); + _t3 = _mm_unpackhi_epi32(_t0, _t1); + + __m128i _extt2 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t2); + __m128i _extt3 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t3); + + _r0 = _mm_unpacklo_epi8(_t2, _extt2); + if (tj * 2 + 1 < w) _r1 = _mm_unpackhi_epi8(_t2, _extt2); + if (tj * 2 + 2 < w) _r2 = _mm_unpacklo_epi8(_t3, _extt3); + if (tj * 2 + 3 < w) _r3 = _mm_unpackhi_epi8(_t3, _extt3); +#endif // __AVX2__ + } + } + + __m128i _tmp0 = _mm_sub_epi16(_r0, _r2); + __m128i _tmp1 = _mm_add_epi16(_r1, _r2); + __m128i _tmp2 = _mm_sub_epi16(_r2, _r1); + __m128i _tmp3 = _mm_sub_epi16(_r3, _r1); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + // old gcc breaks stack variable alignement + // ref https://gcc.gnu.org/bugzilla/show_bug.cgi?id=16660 + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); +#endif + + __m128i _tmp0 = _mm_sub_epi16(_r0, _r2); + __m128i _tmp1 = _mm_add_epi16(_r1, _r2); + __m128i _tmp2 = _mm_sub_epi16(_r2, _r1); + __m128i _tmp3 = _mm_sub_epi16(_r3, _r1); + + _mm_store_si128((__m128i*)p0, _tmp0); + _mm_store_si128((__m128i*)p1, _tmp1); + _mm_store_si128((__m128i*)p2, _tmp2); + _mm_store_si128((__m128i*)p3, _tmp3); + + p0 += max_jj * 4 * 8; + p1 += max_jj * 4 * 8; + p2 += max_jj * 4 * 8; + p3 += max_jj * 4 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __SSE2__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[4][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 2 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 2 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 2 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + } + } + + tmp[0][m][0] = r00 - r20; + tmp[0][m][1] = r01 - r21; + tmp[1][m][0] = r10 + r20; + tmp[1][m][1] = r11 + r21; + tmp[2][m][0] = r20 - r10; + tmp[2][m][1] = r21 - r11; + tmp[3][m][0] = r30 - r10; + tmp[3][m][1] = r31 - r11; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + + p0[0] = r00 - r20; + p0[1] = r01 - r21; + p1[0] = r10 + r20; + p1[1] = r11 + r21; + p2[0] = r20 - r10; + p2[1] = r21 - r11; + p3[0] = r30 - r10; + p3[1] = r31 - r11; + + p0 += max_jj * 4 * 2; + p1 += max_jj * 4 * 2; + p2 += max_jj * 4 * 2; + p3 += max_jj * 4 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 4; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + + if (ti * 2 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 2 + 1 < w) r1 = r0123[1]; + if (tj * 2 + 2 < w) r2 = r0123[2]; + if (tj * 2 + 3 < w) r3 = r0123[3]; + } + } + + tmp[0][m] = r0 - r2; + tmp[1][m] = r1 + r2; + tmp[2][m] = r2 - r1; + tmp[3][m] = r3 - r1; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 16 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + + p0[0] = r0 - r2; + p1[0] = r1 + r2; + p2[0] = r2 - r1; + p3[0] = r3 - r1; + + p0 += max_jj * 4; + p1 += max_jj * 4; + p2 += max_jj * 4; + p3 += max_jj * 4; + } + } + } +} + +static inline void conv3x3s1_winograd23_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[2][4] = { + // {1, 1, 1, 0}, + // {0, 1, -1, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 1) / 2; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + int tmp[2][4][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 16; + const int* r1 = r0 + max_jj * 16; + const int* r2 = r0 + max_jj * 16 * 2; + const int* r3 = r0 + max_jj * 16 * 3; + + for (int m = 0; m < 4; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_r0, _r1), _r2); + __m512i _tmp1 = _mm512_add_epi32(_mm512_sub_epi32(_r1, _r2), _r3); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + + r0 += max_jj * 4 * 16; + r1 += max_jj * 4 * 16; + r2 += max_jj * 4 * 16; + r3 += max_jj * 4 * 16; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + __m512i _r0 = _mm512_load_si512((const __m512i*)tmp[m][0]); + __m512i _r1 = _mm512_load_si512((const __m512i*)tmp[m][1]); + __m512i _r2 = _mm512_load_si512((const __m512i*)tmp[m][2]); + __m512i _r3 = _mm512_load_si512((const __m512i*)tmp[m][3]); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_r0, _r1), _r2); + __m512i _tmp1 = _mm512_add_epi32(_mm512_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm512_srai_epi32(_tmp0, 2); + _tmp1 = _mm512_srai_epi32(_tmp1, 2); + + if (out_elempack == 16) + { + _mm512_store_si512((__m512i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) + { + _mm512_store_si512((__m512i*)(outptr0 + 16), _tmp1); + } + } + if (out_elempack == 8) + { + int* outptr1 = outptr0 + N; + + _mm256_store_si256((__m256i*)outptr0, _mm512_extracti32x8_epi32(_tmp0, 0)); + _mm256_store_si256((__m256i*)outptr1, _mm512_extracti32x8_epi32(_tmp0, 1)); + if (tj * 2 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _mm512_extracti32x8_epi32(_tmp1, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 8), _mm512_extracti32x8_epi32(_tmp1, 1)); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + _mm_store_si128((__m128i*)outptr0, _mm512_extracti32x4_epi32(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm512_extracti32x4_epi32(_tmp0, 1)); + _mm_store_si128((__m128i*)outptr2, _mm512_extracti32x4_epi32(_tmp0, 2)); + _mm_store_si128((__m128i*)outptr3, _mm512_extracti32x4_epi32(_tmp0, 3)); + if (tj * 2 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm512_extracti32x4_epi32(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm512_extracti32x4_epi32(_tmp1, 1)); + _mm_store_si128((__m128i*)(outptr2 + 4), _mm512_extracti32x4_epi32(_tmp1, 2)); + _mm_store_si128((__m128i*)(outptr3 + 4), _mm512_extracti32x4_epi32(_tmp1, 3)); + } + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _mm512_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm512_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + int tmp[2][4][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + + for (int m = 0; m < 4; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_r0, _r1), _r2); + __m256i _tmp1 = _mm256_add_epi32(_mm256_sub_epi32(_r1, _r2), _r3); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); +#endif + + r0 += max_jj * 4 * 8; + r1 += max_jj * 4 * 8; + r2 += max_jj * 4 * 8; + r3 += max_jj * 4 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m256i _r0 = _mm256_loadu_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)tmp[m][3]); +#else + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); +#endif + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_r0, _r1), _r2); + __m256i _tmp1 = _mm256_add_epi32(_mm256_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm256_srai_epi32(_tmp0, 2); + _tmp1 = _mm256_srai_epi32(_tmp1, 2); + + if (out_elempack == 8) + { + _mm256_store_si256((__m256i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _tmp1); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + _mm_store_si128((__m128i*)outptr0, _mm256_extracti128_si256(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm256_extracti128_si256(_tmp0, 1)); + if (tj * 2 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm256_extracti128_si256(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm256_extracti128_si256(_tmp1, 1)); + } + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); + _mm256_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm256_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); +#else + int tmp0[8]; + int tmp1[8]; + _mm256_storeu_si256((__m256i*)tmp0, _tmp0); + _mm256_storeu_si256((__m256i*)tmp1, _tmp1); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + outptr4[0] = tmp0[4]; + outptr5[0] = tmp0[5]; + outptr6[0] = tmp0[6]; + outptr7[0] = tmp0[7]; + + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + outptr4[1] = tmp1[4]; + outptr5[1] = tmp1[5]; + outptr6[1] = tmp1[6]; + outptr7[1] = tmp1[7]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + int tmp[2][4][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + + for (int m = 0; m < 4; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_r0, _r1), _r2); + __m128i _tmp1 = _mm_add_epi32(_mm_sub_epi32(_r1, _r2), _r3); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); +#endif + + r0 += max_jj * 4 * 4; + r1 += max_jj * 4 * 4; + r2 += max_jj * 4 * 4; + r3 += max_jj * 4 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 2) + (tj * 2) * out_elempack; + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); +#endif + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_r0, _r1), _r2); + __m128i _tmp1 = _mm_add_epi32(_mm_sub_epi32(_r1, _r2), _r3); + + _tmp0 = _mm_srai_epi32(_tmp0, 2); + _tmp1 = _mm_srai_epi32(_tmp1, 2); + + if (out_elempack == 4) + { + _mm_store_si128((__m128i*)outptr0, _tmp0); + if (tj * 2 + 1 < outw) _mm_store_si128((__m128i*)(outptr0 + 4), _tmp1); + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(N)); + _mm_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 2 + 1 < outw) _mm_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); +#else + int tmp0[4]; + int tmp1[4]; + _mm_storeu_si128((__m128i*)tmp0, _tmp0); + _mm_storeu_si128((__m128i*)tmp1, _tmp1); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[2][4][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m][0] = r0[0] + r1[0] + r2[0]; + tmp[0][m][1] = r0[1] + r1[1] + r2[1]; + tmp[1][m][0] = r1[0] - r2[0] + r3[0]; + tmp[1][m][1] = r1[1] - r2[1] + r3[1]; + + r0 += max_jj * 4 * 2; + r1 += max_jj * 4 * 2; + r2 += max_jj * 4 * 2; + r3 += max_jj * 4 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp00 = tmp[m][0][0] + tmp[m][1][0] + tmp[m][2][0]; + int tmp01 = tmp[m][0][1] + tmp[m][1][1] + tmp[m][2][1]; + int tmp10 = tmp[m][1][0] - tmp[m][2][0] + tmp[m][3][0]; + int tmp11 = tmp[m][1][1] - tmp[m][2][1] + tmp[m][3][1]; + + tmp00 = tmp00 >> 2; + tmp01 = tmp01 >> 2; + tmp10 = tmp10 >> 2; + tmp11 = tmp11 >> 2; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 2 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[2][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 16 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + + for (int m = 0; m < 4; m++) + { + tmp[0][m] = r0[0] + r1[0] + r2[0]; + tmp[1][m] = r1[0] - r2[0] + r3[0]; + + r0 += max_jj * 4; + r1 += max_jj * 4; + r2 += max_jj * 4; + r3 += max_jj * 4; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 2) + (tj * 2); + + for (int m = 0; m < 2; m++) + { + if (ti * 2 + m >= outh) + continue; + + int tmp0 = tmp[m][0] + tmp[m][1] + tmp[m][2]; + int tmp1 = tmp[m][1] - tmp[m][2] + tmp[m][3]; + + tmp0 = tmp0 >> 2; + tmp1 = tmp1 >> 2; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 2 + 1 < outw) outptr0[1] = tmp1; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd23_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + conv3x3s1_winograd23_int8_avx512vnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + conv3x3s1_winograd23_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd23_int8_avx2(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ + if (ncnn::cpu_support_x86_xop()) + { + conv3x3s1_winograd23_int8_xop(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif +#endif + + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 2n+2, winograd F(2,3) + int w_tiles = (outw + 1) / 2; + int h_tiles = (outh + 1) / 2; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 16; + + // NCNN_LOGE("conv3x3s1_winograd23_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 2u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 2u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 2u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd23_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + bool k_end = k + TILE_K >= K; + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk, k_end); + } + + // transform output + conv3x3s1_winograd23_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} + +static inline void conv3x3s1_winograd43_transform_kernel_tile_int8(const Mat& kernel, Mat& A, int inch, int i, int max_ii, int k, int max_kk) +{ + // const short ktm[6][3] = { + // {6, 0, 0}, + // {-4, -4, -4}, + // {-4, 4, -4}, + // {1, 2, 4}, + // {1, -2, 4}, + // {0, 0, 6} + // }; + + short* ptmp = A; + + int ii = 0; + for (; ii < max_ii; ii++) + { + int kk = 0; + for (; kk < max_kk; kk++) + { + short tmp[6][3]; + + const signed char* k0 = (const signed char*)kernel + (i + ii) * inch * 9 + (k + kk) * 9; + + for (int m = 0; m < 3; m++) + { + signed char r0 = k0[0]; + signed char r1 = k0[1]; + signed char r2 = k0[2]; + + tmp[0][m] = r0 * 6; + tmp[1][m] = -r0 * 4 - r1 * 4 - r2 * 4; + tmp[2][m] = -r0 * 4 + r1 * 4 - r2 * 4; + tmp[3][m] = r0 + r1 * 2 + r2 * 4; + tmp[4][m] = r0 - r1 * 2 + r2 * 4; + tmp[5][m] = r2 * 6; + + k0 += 3; + } + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + + short z0 = r0 * 6; + short z1 = -r0 * 4 - r1 * 4 - r2 * 4; + short z2 = -r0 * 4 + r1 * 4 - r2 * 4; + short z3 = r0 + r1 * 2 + r2 * 4; + short z4 = r0 - r1 * 2 + r2 * 4; + short z5 = r2 * 6; + + ptmp[0] = z0; + ptmp[1] = z1; + ptmp[2] = z2; + ptmp[3] = z3; + ptmp[4] = z4; + ptmp[5] = z5; + ptmp += 6; + } + } + } +} + +static void conv3x3s1_winograd43_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd43_transform_kernel_int8_avx2(kernel, AT, inch, outch, opt); + return; + } +#endif +#endif + + const int M = outch; + const int K = inch; + const int B = 36; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, 0, K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + Mat A_tileX(B * TILE_M * TILE_K, 1, opt.num_threads, 4u, (Allocator*)0); + + AT.create(TILE_K * TILE_M, B, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 4u, (Allocator*)0); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat A_tile = A_tileX.channel(get_omp_thread_num()); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + conv3x3s1_winograd43_transform_kernel_tile_int8(kernel, A_tile, inch, i, max_ii, k, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + pack_A_tile_int8(A_tile, AT_tile, B, max_ii, max_kk); + } + } +} + +static inline void conv3x3s1_winograd43_transform_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int nT) +{ + // const float itm[4][4] = { + // {4, 0, -5, 0, 1, 0}, + // {0, -4, -4, 1, 1, 0}, + // {0, 4, -4, -1, 1, 0}, + // {0, -2, -1, 2, 1, 0}, + // {0, 2, -1, -2, 1, 0}, + // {0, 4, 0, -5, 0, 1} + // }; + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int elempack = bottom_blob.elempack; + const int N = bottom_blob.cstep * elempack; + + const int w_tiles = (w + 1) / 4; + + int nn_max_kk = 0; + int remain_max_kk_start = 0; +#if __SSE2__ +#if __AVX512F__ + nn_max_kk = max_kk / 16; + #pragma omp parallel for num_threads(nT) + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = ppkk * 16; + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + short tmp[6][6][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + __m256i _v2 = _mm256_set1_epi16(2); + __m256i _v4 = _mm256_set1_epi16(4); + __m256i _v5 = _mm256_set1_epi16(5); + + for (int m = 0; m < 6; m++) + { + __m256i _r0 = _mm256_setzero_si256(); + __m256i _r1 = _mm256_setzero_si256(); + __m256i _r2 = _mm256_setzero_si256(); + __m256i _r3 = _mm256_setzero_si256(); + __m256i _r4 = _mm256_setzero_si256(); + __m256i _r5 = _mm256_setzero_si256(); + + if (ti * 4 + m < h) + { + if (elempack == 16) + { + _r0 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)r0)); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 16))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 32))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 48))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 64))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm_load_si128((const __m128i*)(r0 + 80))); + } + if (elempack == 8) + { + const signed char* r1 = r0 + N; + + _r0 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)r0), _mm_loadl_epi64((const __m128i*)r1))); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 8)), _mm_loadl_epi64((const __m128i*)(r1 + 8)))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 16)), _mm_loadl_epi64((const __m128i*)(r1 + 16)))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 24)), _mm_loadl_epi64((const __m128i*)(r1 + 24)))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 32)), _mm_loadl_epi64((const __m128i*)(r1 + 32)))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)(r0 + 40)), _mm_loadl_epi64((const __m128i*)(r1 + 40)))); + } + if (elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _r0 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)r0, sizeof(signed char)))); + if (tj * 4 + 1 < w) _r1 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 1), sizeof(signed char)))); + if (tj * 4 + 2 < w) _r2 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 2), sizeof(signed char)))); + if (tj * 4 + 3 < w) _r3 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 3), sizeof(signed char)))); + if (tj * 4 + 4 < w) _r4 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 4), sizeof(signed char)))); + if (tj * 4 + 5 < w) _r5 = _mm256_cvtepi8_epi16(_mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, (const int*)(r0 + 5), sizeof(signed char)))); + } + } + + __m256i _tmp12a = _mm256_sub_epi16(_r3, _mm256_mullo_epi16(_r1, _v4)); + __m256i _tmp12b = _mm256_sub_epi16(_r4, _mm256_mullo_epi16(_r2, _v4)); + __m256i _tmp34a = _mm256_mullo_epi16(_mm256_sub_epi16(_r3, _r1), _v2); + __m256i _tmp34b = _mm256_sub_epi16(_r4, _r2); + + __m256i _tmp0 = _mm256_add_epi16(_r4, _mm256_sub_epi16(_mm256_mullo_epi16(_r0, _v4), _mm256_mullo_epi16(_r2, _v5))); + __m256i _tmp1 = _mm256_add_epi16(_tmp12b, _tmp12a); + __m256i _tmp2 = _mm256_sub_epi16(_tmp12b, _tmp12a); + __m256i _tmp3 = _mm256_add_epi16(_tmp34b, _tmp34a); + __m256i _tmp4 = _mm256_sub_epi16(_tmp34b, _tmp34a); + __m256i _tmp5 = _mm256_add_epi16(_r5, _mm256_sub_epi16(_mm256_mullo_epi16(_r1, _v4), _mm256_mullo_epi16(_r3, _v5))); + + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); + _mm256_store_si256((__m256i*)tmp[4][m], _tmp4); + _mm256_store_si256((__m256i*)tmp[5][m], _tmp5); + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 16; + short* p1 = p0 + max_jj * 16; + short* p2 = p0 + max_jj * 16 * 2; + short* p3 = p0 + max_jj * 16 * 3; + short* p4 = p0 + max_jj * 16 * 4; + short* p5 = p0 + max_jj * 16 * 5; + + for (int m = 0; m < 6; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_load_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_load_si256((const __m256i*)tmp[m][5]); + + __m256i _tmp12a = _mm256_sub_epi16(_r3, _mm256_mullo_epi16(_r1, _v4)); + __m256i _tmp12b = _mm256_sub_epi16(_r4, _mm256_mullo_epi16(_r2, _v4)); + __m256i _tmp34a = _mm256_mullo_epi16(_mm256_sub_epi16(_r3, _r1), _v2); + __m256i _tmp34b = _mm256_sub_epi16(_r4, _r2); + + __m256i _tmp0 = _mm256_add_epi16(_r4, _mm256_sub_epi16(_mm256_mullo_epi16(_r0, _v4), _mm256_mullo_epi16(_r2, _v5))); + __m256i _tmp1 = _mm256_add_epi16(_tmp12b, _tmp12a); + __m256i _tmp2 = _mm256_sub_epi16(_tmp12b, _tmp12a); + __m256i _tmp3 = _mm256_add_epi16(_tmp34b, _tmp34a); + __m256i _tmp4 = _mm256_sub_epi16(_tmp34b, _tmp34a); + __m256i _tmp5 = _mm256_add_epi16(_r5, _mm256_sub_epi16(_mm256_mullo_epi16(_r1, _v4), _mm256_mullo_epi16(_r3, _v5))); + + _mm256_store_si256((__m256i*)p0, _tmp0); + _mm256_store_si256((__m256i*)p1, _tmp1); + _mm256_store_si256((__m256i*)p2, _tmp2); + _mm256_store_si256((__m256i*)p3, _tmp3); + _mm256_store_si256((__m256i*)p4, _tmp4); + _mm256_store_si256((__m256i*)p5, _tmp5); + + p0 += max_jj * 6 * 16; + p1 += max_jj * 6 * 16; + p2 += max_jj * 6 * 16; + p3 += max_jj * 6 * 16; + p4 += max_jj * 6 * 16; + p5 += max_jj * 6 * 16; + } + } + } + remain_max_kk_start += nn_max_kk * 16; + nn_max_kk = (max_kk - remain_max_kk_start) / 8; +#else // __AVX512F__ + nn_max_kk = (max_kk - remain_max_kk_start) / 8; + #pragma omp parallel for num_threads(nT) +#endif // __AVX512F__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 8; + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + short tmp[6][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel((k + kk) / elempack).row(ti * 4) + (tj * 4) * elempack; + + __m128i _v2 = _mm_set1_epi16(2); + __m128i _v4 = _mm_set1_epi16(4); + __m128i _v5 = _mm_set1_epi16(5); + + for (int m = 0; m < 6; m++) + { + __m128i _r0 = _mm_setzero_si128(); + __m128i _r1 = _mm_setzero_si128(); + __m128i _r2 = _mm_setzero_si128(); + __m128i _r3 = _mm_setzero_si128(); + __m128i _r4 = _mm_setzero_si128(); + __m128i _r5 = _mm_setzero_si128(); + + if (ti * 4 + m < h) + { + if (elempack == 8) + { + _r0 = _mm_loadl_epi64((const __m128i*)r0); + _r0 = _mm_unpacklo_epi8(_r0, _mm_cmpgt_epi8(_mm_setzero_si128(), _r0)); + if (tj * 4 + 1 < w) + { + _r1 = _mm_loadl_epi64((const __m128i*)(r0 + 8)); + _r1 = _mm_unpacklo_epi8(_r1, _mm_cmpgt_epi8(_mm_setzero_si128(), _r1)); + } + if (tj * 4 + 2 < w) + { + _r2 = _mm_loadl_epi64((const __m128i*)(r0 + 16)); + _r2 = _mm_unpacklo_epi8(_r2, _mm_cmpgt_epi8(_mm_setzero_si128(), _r2)); + } + if (tj * 4 + 3 < w) + { + _r3 = _mm_loadl_epi64((const __m128i*)(r0 + 24)); + _r3 = _mm_unpacklo_epi8(_r3, _mm_cmpgt_epi8(_mm_setzero_si128(), _r3)); + } + if (tj * 4 + 4 < w) + { + _r4 = _mm_loadl_epi64((const __m128i*)(r0 + 32)); + _r4 = _mm_unpacklo_epi8(_r4, _mm_cmpgt_epi8(_mm_setzero_si128(), _r4)); + } + if (tj * 4 + 5 < w) + { + _r5 = _mm_loadl_epi64((const __m128i*)(r0 + 40)); + _r5 = _mm_unpacklo_epi8(_r5, _mm_cmpgt_epi8(_mm_setzero_si128(), _r5)); + } + } + if (elempack == 1) + { +#if __AVX2__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); +#if __AVX512F__ + _r0 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)))); + if (tj * 4 + 1 < w) _r1 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)))); + if (tj * 4 + 2 < w) _r2 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)))); + if (tj * 4 + 3 < w) _r3 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)))); + if (tj * 4 + 4 < w) _r4 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 4), _vindex, sizeof(signed char)))); + if (tj * 4 + 5 < w) _r5 = _mm_cvtepi8_epi16(_mm256_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)(r0 + 5), _vindex, sizeof(signed char)))); +#else + __m128i _sindex8 = _mm_setr_epi8(0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); + __m256i _sindex88 = _mm256_inserti128_si256(_mm256_castsi128_si256(_sindex8), _sindex8, 1); + __m256i _val0_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)r0, _vindex, sizeof(signed char)), _sindex88); + _r0 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val0_32, 0), _mm256_extracti128_si256(_val0_32, 1))); + if (tj * 4 + 1 < w) + { + __m256i _val1_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 1), _vindex, sizeof(signed char)), _sindex88); + _r1 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val1_32, 0), _mm256_extracti128_si256(_val1_32, 1))); + } + if (tj * 4 + 2 < w) + { + __m256i _val2_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 2), _vindex, sizeof(signed char)), _sindex88); + _r2 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val2_32, 0), _mm256_extracti128_si256(_val2_32, 1))); + } + if (tj * 4 + 3 < w) + { + __m256i _val3_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 3), _vindex, sizeof(signed char)), _sindex88); + _r3 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val3_32, 0), _mm256_extracti128_si256(_val3_32, 1))); + } + if (tj * 4 + 4 < w) + { + __m256i _val4_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 4), _vindex, sizeof(signed char)), _sindex88); + _r4 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val4_32, 0), _mm256_extracti128_si256(_val4_32, 1))); + } + if (tj * 4 + 5 < w) + { + __m256i _val5_32 = _mm256_shuffle_epi8(_mm256_i32gather_epi32((const int*)(r0 + 5), _vindex, sizeof(signed char)), _sindex88); + _r5 = _mm_cvtepi8_epi16(_mm_unpacklo_epi32(_mm256_extracti128_si256(_val5_32, 0), _mm256_extracti128_si256(_val5_32, 1))); + } +#endif // __AVX512F__ +#else // __AVX2__ + const signed char* r1 = r0 + N; + const signed char* r2 = r0 + N * 2; + const signed char* r3 = r0 + N * 3; + const signed char* r4 = r0 + N * 4; + const signed char* r5 = r0 + N * 5; + const signed char* r6 = r0 + N * 6; + const signed char* r7 = r0 + N * 7; + + __m128i _t0 = _mm_loadl_epi64((const __m128i*)r0); + __m128i _t1 = _mm_loadl_epi64((const __m128i*)r1); + __m128i _t2 = _mm_loadl_epi64((const __m128i*)r2); + __m128i _t3 = _mm_loadl_epi64((const __m128i*)r3); + __m128i _t4 = _mm_loadl_epi64((const __m128i*)r4); + __m128i _t5 = _mm_loadl_epi64((const __m128i*)r5); + __m128i _t6 = _mm_loadl_epi64((const __m128i*)r6); + __m128i _t7 = _mm_loadl_epi64((const __m128i*)r7); + + __m128i _t01 = _mm_unpacklo_epi8(_t0, _t1); + __m128i _t23 = _mm_unpacklo_epi8(_t2, _t3); + __m128i _t45 = _mm_unpacklo_epi8(_t4, _t5); + __m128i _t67 = _mm_unpacklo_epi8(_t6, _t7); + _t0 = _mm_unpacklo_epi16(_t01, _t23); + _t1 = _mm_unpacklo_epi16(_t45, _t67); + _t2 = _mm_unpacklo_epi32(_t0, _t1); + _t3 = _mm_unpackhi_epi32(_t0, _t1); + + __m128i _extt2 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t2); + __m128i _extt3 = _mm_cmpgt_epi8(_mm_setzero_si128(), _t3); + + _r0 = _mm_unpacklo_epi8(_t2, _extt2); + if (tj * 4 + 1 < w) _r1 = _mm_unpackhi_epi8(_t2, _extt2); + if (tj * 4 + 2 < w) _r2 = _mm_unpacklo_epi8(_t3, _extt3); + if (tj * 4 + 3 < w) _r3 = _mm_unpackhi_epi8(_t3, _extt3); + if (tj * 4 + 4 < w) _r4 = _mm_setr_epi16(r0[4], r1[4], r2[4], r3[4], r4[4], r5[4], r6[4], r7[4]); + if (tj * 4 + 5 < w) _r5 = _mm_setr_epi16(r0[5], r1[5], r2[5], r3[5], r4[5], r5[5], r6[5], r7[5]); +#endif // __AVX2__ + } + } + + __m128i _tmp12a = _mm_sub_epi16(_r3, _mm_mullo_epi16(_r1, _v4)); + __m128i _tmp12b = _mm_sub_epi16(_r4, _mm_mullo_epi16(_r2, _v4)); + __m128i _tmp34a = _mm_mullo_epi16(_mm_sub_epi16(_r3, _r1), _v2); + __m128i _tmp34b = _mm_sub_epi16(_r4, _r2); + + __m128i _tmp0 = _mm_add_epi16(_r4, _mm_sub_epi16(_mm_mullo_epi16(_r0, _v4), _mm_mullo_epi16(_r2, _v5))); + __m128i _tmp1 = _mm_add_epi16(_tmp12b, _tmp12a); + __m128i _tmp2 = _mm_sub_epi16(_tmp12b, _tmp12a); + __m128i _tmp3 = _mm_add_epi16(_tmp34b, _tmp34a); + __m128i _tmp4 = _mm_sub_epi16(_tmp34b, _tmp34a); + __m128i _tmp5 = _mm_add_epi16(_r5, _mm_sub_epi16(_mm_mullo_epi16(_r1, _v4), _mm_mullo_epi16(_r3, _v5))); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); + _mm_storeu_si128((__m128i*)tmp[4][m], _tmp4); + _mm_storeu_si128((__m128i*)tmp[5][m], _tmp5); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); + _mm_store_si128((__m128i*)tmp[4][m], _tmp4); + _mm_store_si128((__m128i*)tmp[5][m], _tmp5); +#endif + + r0 += w * elempack; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 8; + short* p1 = p0 + max_jj * 8; + short* p2 = p0 + max_jj * 8 * 2; + short* p3 = p0 + max_jj * 8 * 3; + short* p4 = p0 + max_jj * 8 * 4; + short* p5 = p0 + max_jj * 8 * 5; + + for (int m = 0; m < 6; m++) + { +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_loadu_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_loadu_si128((const __m128i*)tmp[m][5]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_load_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_load_si128((const __m128i*)tmp[m][5]); +#endif + + __m128i _tmp12a = _mm_sub_epi16(_r3, _mm_mullo_epi16(_r1, _v4)); + __m128i _tmp12b = _mm_sub_epi16(_r4, _mm_mullo_epi16(_r2, _v4)); + __m128i _tmp34a = _mm_mullo_epi16(_mm_sub_epi16(_r3, _r1), _v2); + __m128i _tmp34b = _mm_sub_epi16(_r4, _r2); + + __m128i _tmp0 = _mm_add_epi16(_r4, _mm_sub_epi16(_mm_mullo_epi16(_r0, _v4), _mm_mullo_epi16(_r2, _v5))); + __m128i _tmp1 = _mm_add_epi16(_tmp12b, _tmp12a); + __m128i _tmp2 = _mm_sub_epi16(_tmp12b, _tmp12a); + __m128i _tmp3 = _mm_add_epi16(_tmp34b, _tmp34a); + __m128i _tmp4 = _mm_sub_epi16(_tmp34b, _tmp34a); + __m128i _tmp5 = _mm_add_epi16(_r5, _mm_sub_epi16(_mm_mullo_epi16(_r1, _v4), _mm_mullo_epi16(_r3, _v5))); + + _mm_store_si128((__m128i*)p0, _tmp0); + _mm_store_si128((__m128i*)p1, _tmp1); + _mm_store_si128((__m128i*)p2, _tmp2); + _mm_store_si128((__m128i*)p3, _tmp3); + _mm_store_si128((__m128i*)p4, _tmp4); + _mm_store_si128((__m128i*)p5, _tmp5); + + p0 += max_jj * 6 * 8; + p1 += max_jj * 6 * 8; + p2 += max_jj * 6 * 8; + p3 += max_jj * 6 * 8; + p4 += max_jj * 6 * 8; + p5 += max_jj * 6 * 8; + } + } + } + remain_max_kk_start += nn_max_kk * 8; + nn_max_kk = (max_kk - remain_max_kk_start) / 2; +#else // __SSE2__ + nn_max_kk = (max_kk - remain_max_kk_start) / 2; + #pragma omp parallel for num_threads(nT) +#endif // __SSE2__ + for (int ppkk = 0; ppkk < nn_max_kk; ppkk++) + { + const int kk = remain_max_kk_start + ppkk * 2; + + short tmp[6][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r00 = 0; + signed char r01 = 0; + signed char r10 = 0; + signed char r11 = 0; + signed char r20 = 0; + signed char r21 = 0; + signed char r30 = 0; + signed char r31 = 0; + signed char r40 = 0; + signed char r41 = 0; + signed char r50 = 0; + signed char r51 = 0; + + if (ti * 4 + m < h) + { + // if (elempack == 1) + { + const signed char* r1 = r0 + N; + + r00 = r0[0]; + r01 = r1[0]; + if (tj * 4 + 1 < w) + { + r10 = r0[1]; + r11 = r1[1]; + } + if (tj * 4 + 2 < w) + { + r20 = r0[2]; + r21 = r1[2]; + } + if (tj * 4 + 3 < w) + { + r30 = r0[3]; + r31 = r1[3]; + } + if (tj * 4 + 4 < w) + { + r40 = r0[4]; + r41 = r1[4]; + } + if (tj * 4 + 5 < w) + { + r50 = r0[5]; + r51 = r1[5]; + } + } + } + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + tmp[0][m][0] = r40 + r00 * 4 - r20 * 5; + tmp[0][m][1] = r41 + r01 * 4 - r21 * 5; + tmp[1][m][0] = tmp120b + tmp120a; + tmp[1][m][1] = tmp121b + tmp121a; + tmp[2][m][0] = tmp120b - tmp120a; + tmp[2][m][1] = tmp121b - tmp121a; + tmp[3][m][0] = tmp340b + tmp340a; + tmp[3][m][1] = tmp341b + tmp341a; + tmp[4][m][0] = tmp340b - tmp340a; + tmp[4][m][1] = tmp341b - tmp341a; + tmp[5][m][0] = r50 + r10 * 4 - r30 * 5; + tmp[5][m][1] = r51 + r11 * 4 - r31 * 5; + + r0 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj * 2; + short* p1 = p0 + max_jj * 2; + short* p2 = p0 + max_jj * 2 * 2; + short* p3 = p0 + max_jj * 2 * 3; + short* p4 = p0 + max_jj * 2 * 4; + short* p5 = p0 + max_jj * 2 * 5; + + for (int m = 0; m < 6; m++) + { + short r00 = tmp[m][0][0]; + short r01 = tmp[m][0][1]; + short r10 = tmp[m][1][0]; + short r11 = tmp[m][1][1]; + short r20 = tmp[m][2][0]; + short r21 = tmp[m][2][1]; + short r30 = tmp[m][3][0]; + short r31 = tmp[m][3][1]; + short r40 = tmp[m][4][0]; + short r41 = tmp[m][4][1]; + short r50 = tmp[m][5][0]; + short r51 = tmp[m][5][1]; + + short tmp120a = r30 - r10 * 4; + short tmp121a = r31 - r11 * 4; + short tmp120b = r40 - r20 * 4; + short tmp121b = r41 - r21 * 4; + short tmp340a = (r30 - r10) * 2; + short tmp341a = (r31 - r11) * 2; + short tmp340b = r40 - r20; + short tmp341b = r41 - r21; + + p0[0] = r40 + r00 * 4 - r20 * 5; + p0[1] = r41 + r01 * 4 - r21 * 5; + p1[0] = tmp120b + tmp120a; + p1[1] = tmp121b + tmp121a; + p2[0] = tmp120b - tmp120a; + p2[1] = tmp121b - tmp121a; + p3[0] = tmp340b + tmp340a; + p3[1] = tmp341b + tmp341a; + p4[0] = tmp340b - tmp340a; + p4[1] = tmp341b - tmp341a; + p5[0] = r50 + r10 * 4 - r30 * 5; + p5[1] = r51 + r11 * 4 - r31 * 5; + + p0 += max_jj * 6 * 2; + p1 += max_jj * 6 * 2; + p2 += max_jj * 6 * 2; + p3 += max_jj * 6 * 2; + p4 += max_jj * 6 * 2; + p5 += max_jj * 6 * 2; + } + } + } + remain_max_kk_start += nn_max_kk * 2; + for (int kk = remain_max_kk_start; kk < max_kk; kk++) + { + short tmp[6][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const signed char* r0123 = bottom_blob.channel(k + kk).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 6; m++) + { + signed char r0 = 0; + signed char r1 = 0; + signed char r2 = 0; + signed char r3 = 0; + signed char r4 = 0; + signed char r5 = 0; + + if (ti * 4 + m < h) + { + // if (elempack == 1) + { + r0 = r0123[0]; + if (tj * 4 + 1 < w) r1 = r0123[1]; + if (tj * 4 + 2 < w) r2 = r0123[2]; + if (tj * 4 + 3 < w) r3 = r0123[3]; + if (tj * 4 + 4 < w) r4 = r0123[4]; + if (tj * 4 + 5 < w) r5 = r0123[5]; + } + } + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + tmp[0][m] = r4 + r0 * 4 - r2 * 5; + tmp[1][m] = tmp12b + tmp12a; + tmp[2][m] = tmp12b - tmp12a; + tmp[3][m] = tmp34b + tmp34a; + tmp[4][m] = tmp34b - tmp34a; + tmp[5][m] = r5 + r1 * 4 - r3 * 5; + + r0123 += w; + } + + short* p0 = (short*)B + kk * max_jj * 36 + jj; + short* p1 = p0 + max_jj; + short* p2 = p0 + max_jj * 2; + short* p3 = p0 + max_jj * 3; + short* p4 = p0 + max_jj * 4; + short* p5 = p0 + max_jj * 5; + + for (int m = 0; m < 6; m++) + { + short r0 = tmp[m][0]; + short r1 = tmp[m][1]; + short r2 = tmp[m][2]; + short r3 = tmp[m][3]; + short r4 = tmp[m][4]; + short r5 = tmp[m][5]; + + short tmp12a = r3 - r1 * 4; + short tmp12b = r4 - r2 * 4; + short tmp34a = (r3 - r1) * 2; + short tmp34b = r4 - r2; + + p0[0] = r4 + r0 * 4 - r2 * 5; + p1[0] = tmp12b + tmp12a; + p2[0] = tmp12b - tmp12a; + p3[0] = tmp34b + tmp34a; + p4[0] = tmp34b - tmp34a; + p5[0] = r5 + r1 * 4 - r3 * 5; + + p0 += max_jj * 6; + p1 += max_jj * 6; + p2 += max_jj * 6; + p3 += max_jj * 6; + p4 += max_jj * 6; + p5 += max_jj * 6; + } + } + } +} + +static inline void conv3x3s1_winograd43_transform_output_tile_int8(const Mat& top_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj) +{ + // const int otm[4][6] = { + // {1, 1, 1, 1, 1, 0}, + // {0, 1, -1, 2, -2, 0}, + // {0, 1, 1, 4, 4, 0}, + // {0, 1, -1, 8, -8, 1} + // }; + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int out_elempack = top_blob.elempack; + const int N = top_blob.cstep * out_elempack; + + const int w_tiles = (outw + 3) / 4; + + int ii = 0; +#if __SSE2__ +#if __AVX2__ +#if __AVX512F__ + for (; ii + 15 < max_ii; ii += 16) + { +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + int tmp[4][6][16]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 16; + const int* r1 = r0 + max_jj * 16; + const int* r2 = r0 + max_jj * 16 * 2; + const int* r3 = r0 + max_jj * 16 * 3; + const int* r4 = r0 + max_jj * 16 * 4; + const int* r5 = r0 + max_jj * 16 * 5; + + for (int m = 0; m < 5; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + __m512i _r4 = _mm512_load_si512((const __m512i*)r4); + __m512i _r5 = _mm512_load_si512((const __m512i*)r5); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _mm512_slli_epi32(_r5, 2)); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + _mm512_store_si512((__m512i*)tmp[2][m], _tmp2); + _mm512_store_si512((__m512i*)tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 16; + r1 += max_jj * 6 * 16; + r2 += max_jj * 6 * 16; + r3 += max_jj * 6 * 16; + r4 += max_jj * 6 * 16; + r5 += max_jj * 6 * 16; + } + for (int m = 5; m < 6; m++) + { + __m512i _r0 = _mm512_load_si512((const __m512i*)r0); + __m512i _r1 = _mm512_load_si512((const __m512i*)r1); + __m512i _r2 = _mm512_load_si512((const __m512i*)r2); + __m512i _r3 = _mm512_load_si512((const __m512i*)r3); + __m512i _r4 = _mm512_load_si512((const __m512i*)r4); + __m512i _r5 = _mm512_load_si512((const __m512i*)r5); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _mm512_slli_epi32(_r5, 2)); + + _tmp0 = _mm512_slli_epi32(_tmp0, 2); + _tmp1 = _mm512_slli_epi32(_tmp1, 2); + _tmp2 = _mm512_slli_epi32(_tmp2, 2); + _tmp3 = _mm512_slli_epi32(_tmp3, 2); + + _mm512_store_si512((__m512i*)tmp[0][m], _tmp0); + _mm512_store_si512((__m512i*)tmp[1][m], _tmp1); + _mm512_store_si512((__m512i*)tmp[2][m], _tmp2); + _mm512_store_si512((__m512i*)tmp[3][m], _tmp3); + + r0 += max_jj * 6 * 16; + r1 += max_jj * 6 * 16; + r2 += max_jj * 6 * 16; + r3 += max_jj * 6 * 16; + r4 += max_jj * 6 * 16; + r5 += max_jj * 6 * 16; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + __m512i _r0 = _mm512_load_si512((const __m512i*)tmp[m][0]); + __m512i _r1 = _mm512_load_si512((const __m512i*)tmp[m][1]); + __m512i _r2 = _mm512_load_si512((const __m512i*)tmp[m][2]); + __m512i _r3 = _mm512_load_si512((const __m512i*)tmp[m][3]); + __m512i _r4 = _mm512_load_si512((const __m512i*)tmp[m][4]); + __m512i _r5 = _mm512_load_si512((const __m512i*)tmp[m][5]); + + __m512i _tmp02a = _mm512_add_epi32(_r1, _r2); + __m512i _tmp02b = _mm512_add_epi32(_r3, _r4); + __m512i _tmp13a = _mm512_sub_epi32(_r1, _r2); + __m512i _tmp13b = _mm512_sub_epi32(_r3, _r4); + + __m512i _tmp0 = _mm512_add_epi32(_mm512_add_epi32(_tmp02a, _tmp02b), _r0); + __m512i _tmp1 = _mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 1)); + __m512i _tmp2 = _mm512_add_epi32(_tmp02a, _mm512_slli_epi32(_tmp02b, 2)); + __m512i _tmp3 = _mm512_add_epi32(_mm512_add_epi32(_tmp13a, _mm512_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m512 _v576 = _mm512_set1_ps(1.0 / 576); + _tmp0 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm512_cvttps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 16) + { + _mm512_store_si512((__m512i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm512_store_si512((__m512i*)(outptr0 + 16), _tmp1); + if (tj * 4 + 2 < outw) _mm512_store_si512((__m512i*)(outptr0 + 32), _tmp2); + if (tj * 4 + 3 < outw) _mm512_store_si512((__m512i*)(outptr0 + 48), _tmp3); + } + if (out_elempack == 8) + { + int* outptr1 = outptr0 + N; + + _mm256_store_si256((__m256i*)outptr0, _mm512_extracti32x8_epi32(_tmp0, 0)); + _mm256_store_si256((__m256i*)outptr1, _mm512_extracti32x8_epi32(_tmp0, 1)); + if (tj * 4 + 1 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 8), _mm512_extracti32x8_epi32(_tmp1, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 8), _mm512_extracti32x8_epi32(_tmp1, 1)); + } + if (tj * 4 + 2 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 16), _mm512_extracti32x8_epi32(_tmp2, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 16), _mm512_extracti32x8_epi32(_tmp2, 1)); + } + if (tj * 4 + 3 < outw) + { + _mm256_store_si256((__m256i*)(outptr0 + 24), _mm512_extracti32x8_epi32(_tmp3, 0)); + _mm256_store_si256((__m256i*)(outptr1 + 24), _mm512_extracti32x8_epi32(_tmp3, 1)); + } + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + _mm_store_si128((__m128i*)outptr0, _mm512_extracti32x4_epi32(_tmp0, 0)); + _mm_store_si128((__m128i*)outptr1, _mm512_extracti32x4_epi32(_tmp0, 1)); + _mm_store_si128((__m128i*)outptr2, _mm512_extracti32x4_epi32(_tmp0, 2)); + _mm_store_si128((__m128i*)outptr3, _mm512_extracti32x4_epi32(_tmp0, 3)); + if (tj * 4 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm512_extracti32x4_epi32(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm512_extracti32x4_epi32(_tmp1, 1)); + _mm_store_si128((__m128i*)(outptr2 + 4), _mm512_extracti32x4_epi32(_tmp1, 2)); + _mm_store_si128((__m128i*)(outptr3 + 4), _mm512_extracti32x4_epi32(_tmp1, 3)); + } + if (tj * 4 + 2 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 8), _mm512_extracti32x4_epi32(_tmp2, 0)); + _mm_store_si128((__m128i*)(outptr1 + 8), _mm512_extracti32x4_epi32(_tmp2, 1)); + _mm_store_si128((__m128i*)(outptr2 + 8), _mm512_extracti32x4_epi32(_tmp2, 2)); + _mm_store_si128((__m128i*)(outptr3 + 8), _mm512_extracti32x4_epi32(_tmp2, 3)); + } + if (tj * 4 + 3 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 12), _mm512_extracti32x4_epi32(_tmp3, 0)); + _mm_store_si128((__m128i*)(outptr1 + 12), _mm512_extracti32x4_epi32(_tmp3, 1)); + _mm_store_si128((__m128i*)(outptr2 + 12), _mm512_extracti32x4_epi32(_tmp3, 2)); + _mm_store_si128((__m128i*)(outptr3 + 12), _mm512_extracti32x4_epi32(_tmp3, 3)); + } + } + if (out_elempack == 1) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(N)); + _mm512_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm512_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm512_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm512_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX512F__ + for (; ii + 7 < max_ii; ii += 8) + { +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + int tmp[4][6][8]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 8; + const int* r1 = r0 + max_jj * 8; + const int* r2 = r0 + max_jj * 8 * 2; + const int* r3 = r0 + max_jj * 8 * 3; + const int* r4 = r0 + max_jj * 8 * 4; + const int* r5 = r0 + max_jj * 8 * 5; + + for (int m = 0; m < 5; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + __m256i _r4 = _mm256_load_si256((const __m256i*)r4); + __m256i _r5 = _mm256_load_si256((const __m256i*)r5); + + __m256i _tmp02a = _mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = _mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _mm256_slli_epi32(_r5, 2)); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_storeu_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_storeu_si256((__m256i*)tmp[3][m], _tmp3); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + for (int m = 5; m < 6; m++) + { + __m256i _r0 = _mm256_load_si256((const __m256i*)r0); + __m256i _r1 = _mm256_load_si256((const __m256i*)r1); + __m256i _r2 = _mm256_load_si256((const __m256i*)r2); + __m256i _r3 = _mm256_load_si256((const __m256i*)r3); + __m256i _r4 = _mm256_load_si256((const __m256i*)r4); + __m256i _r5 = _mm256_load_si256((const __m256i*)r5); + + __m256i _tmp02a = _mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = _mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _mm256_slli_epi32(_r5, 2)); + + _tmp0 = _mm256_slli_epi32(_tmp0, 2); + _tmp1 = _mm256_slli_epi32(_tmp1, 2); + _tmp2 = _mm256_slli_epi32(_tmp2, 2); + _tmp3 = _mm256_slli_epi32(_tmp3, 2); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm256_storeu_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_storeu_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_storeu_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_storeu_si256((__m256i*)tmp[3][m], _tmp3); +#else + _mm256_store_si256((__m256i*)tmp[0][m], _tmp0); + _mm256_store_si256((__m256i*)tmp[1][m], _tmp1); + _mm256_store_si256((__m256i*)tmp[2][m], _tmp2); + _mm256_store_si256((__m256i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 8; + r1 += max_jj * 6 * 8; + r2 += max_jj * 6 * 8; + r3 += max_jj * 6 * 8; + r4 += max_jj * 6 * 8; + r5 += max_jj * 6 * 8; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m256i _r0 = _mm256_loadu_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_loadu_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_loadu_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_loadu_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_loadu_si256((const __m256i*)tmp[m][5]); +#else + __m256i _r0 = _mm256_load_si256((const __m256i*)tmp[m][0]); + __m256i _r1 = _mm256_load_si256((const __m256i*)tmp[m][1]); + __m256i _r2 = _mm256_load_si256((const __m256i*)tmp[m][2]); + __m256i _r3 = _mm256_load_si256((const __m256i*)tmp[m][3]); + __m256i _r4 = _mm256_load_si256((const __m256i*)tmp[m][4]); + __m256i _r5 = _mm256_load_si256((const __m256i*)tmp[m][5]); +#endif + + __m256i _tmp02a = _mm256_add_epi32(_r1, _r2); + __m256i _tmp02b = _mm256_add_epi32(_r3, _r4); + __m256i _tmp13a = _mm256_sub_epi32(_r1, _r2); + __m256i _tmp13b = _mm256_sub_epi32(_r3, _r4); + + __m256i _tmp0 = _mm256_add_epi32(_mm256_add_epi32(_tmp02a, _tmp02b), _r0); + __m256i _tmp1 = _mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 1)); + __m256i _tmp2 = _mm256_add_epi32(_tmp02a, _mm256_slli_epi32(_tmp02b, 2)); + __m256i _tmp3 = _mm256_add_epi32(_mm256_add_epi32(_tmp13a, _mm256_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m256 _v576 = _mm256_set1_ps(1.0 / 576); + _tmp0 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm256_cvttps_epi32(_mm256_mul_ps(_mm256_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 8) + { + _mm256_store_si256((__m256i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm256_store_si256((__m256i*)(outptr0 + 8), _tmp1); + if (tj * 4 + 2 < outw) _mm256_store_si256((__m256i*)(outptr0 + 16), _tmp2); + if (tj * 4 + 3 < outw) _mm256_store_si256((__m256i*)(outptr0 + 24), _tmp3); + } + if (out_elempack == 4) + { + int* outptr1 = outptr0 + N; + + _mm_store_si128((__m128i*)(outptr0), _mm256_extracti128_si256(_tmp0, 0)); + _mm_store_si128((__m128i*)(outptr1), _mm256_extracti128_si256(_tmp0, 1)); + if (tj * 4 + 1 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 4), _mm256_extracti128_si256(_tmp1, 0)); + _mm_store_si128((__m128i*)(outptr1 + 4), _mm256_extracti128_si256(_tmp1, 1)); + } + if (tj * 4 + 2 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 8), _mm256_extracti128_si256(_tmp2, 0)); + _mm_store_si128((__m128i*)(outptr1 + 8), _mm256_extracti128_si256(_tmp2, 1)); + } + if (tj * 4 + 3 < outw) + { + _mm_store_si128((__m128i*)(outptr0 + 12), _mm256_extracti128_si256(_tmp3, 0)); + _mm_store_si128((__m128i*)(outptr1 + 12), _mm256_extracti128_si256(_tmp3, 1)); + } + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(N)); + _mm256_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm256_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm256_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm256_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); +#else + int tmp0[8]; + int tmp1[8]; + int tmp2[8]; + int tmp3[8]; + _mm256_storeu_si256((__m256i*)tmp0, _tmp0); + _mm256_storeu_si256((__m256i*)tmp1, _tmp1); + _mm256_storeu_si256((__m256i*)tmp2, _tmp2); + _mm256_storeu_si256((__m256i*)tmp3, _tmp3); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + int* outptr4 = outptr0 + N * 4; + int* outptr5 = outptr0 + N * 5; + int* outptr6 = outptr0 + N * 6; + int* outptr7 = outptr0 + N * 7; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + outptr4[0] = tmp0[4]; + outptr5[0] = tmp0[5]; + outptr6[0] = tmp0[6]; + outptr7[0] = tmp0[7]; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + outptr4[1] = tmp1[4]; + outptr5[1] = tmp1[5]; + outptr6[1] = tmp1[6]; + outptr7[1] = tmp1[7]; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp2[0]; + outptr1[2] = tmp2[1]; + outptr2[2] = tmp2[2]; + outptr3[2] = tmp2[3]; + outptr4[2] = tmp2[4]; + outptr5[2] = tmp2[5]; + outptr6[2] = tmp2[6]; + outptr7[2] = tmp2[7]; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp3[0]; + outptr1[3] = tmp3[1]; + outptr2[3] = tmp3[2]; + outptr3[3] = tmp3[3]; + outptr4[3] = tmp3[4]; + outptr5[3] = tmp3[5]; + outptr6[3] = tmp3[6]; + outptr7[3] = tmp3[7]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __AVX2__ + for (; ii + 3 < max_ii; ii += 4) + { +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + int tmp[4][6][4]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 4; + const int* r1 = r0 + max_jj * 4; + const int* r2 = r0 + max_jj * 4 * 2; + const int* r3 = r0 + max_jj * 4 * 3; + const int* r4 = r0 + max_jj * 4 * 4; + const int* r5 = r0 + max_jj * 4 * 5; + + for (int m = 0; m < 5; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + __m128i _r4 = _mm_load_si128((const __m128i*)r4); + __m128i _r5 = _mm_load_si128((const __m128i*)r5); + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _mm_slli_epi32(_r5, 2)); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + for (int m = 5; m < 6; m++) + { + __m128i _r0 = _mm_load_si128((const __m128i*)r0); + __m128i _r1 = _mm_load_si128((const __m128i*)r1); + __m128i _r2 = _mm_load_si128((const __m128i*)r2); + __m128i _r3 = _mm_load_si128((const __m128i*)r3); + __m128i _r4 = _mm_load_si128((const __m128i*)r4); + __m128i _r5 = _mm_load_si128((const __m128i*)r5); + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _mm_slli_epi32(_r5, 2)); + + _tmp0 = _mm_slli_epi32(_tmp0, 2); + _tmp1 = _mm_slli_epi32(_tmp1, 2); + _tmp2 = _mm_slli_epi32(_tmp2, 2); + _tmp3 = _mm_slli_epi32(_tmp3, 2); + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + _mm_storeu_si128((__m128i*)tmp[0][m], _tmp0); + _mm_storeu_si128((__m128i*)tmp[1][m], _tmp1); + _mm_storeu_si128((__m128i*)tmp[2][m], _tmp2); + _mm_storeu_si128((__m128i*)tmp[3][m], _tmp3); +#else + _mm_store_si128((__m128i*)tmp[0][m], _tmp0); + _mm_store_si128((__m128i*)tmp[1][m], _tmp1); + _mm_store_si128((__m128i*)tmp[2][m], _tmp2); + _mm_store_si128((__m128i*)tmp[3][m], _tmp3); +#endif + + r0 += max_jj * 6 * 4; + r1 += max_jj * 6 * 4; + r2 += max_jj * 6 * 4; + r3 += max_jj * 6 * 4; + r4 += max_jj * 6 * 4; + r5 += max_jj * 6 * 4; + } + + int* outptr0 = top_blob.channel((i + ii) / out_elempack).row(ti * 4) + (tj * 4) * out_elempack; + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + +#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 6) + __m128i _r0 = _mm_loadu_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_loadu_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_loadu_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_loadu_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_loadu_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_loadu_si128((const __m128i*)tmp[m][5]); +#else + __m128i _r0 = _mm_load_si128((const __m128i*)tmp[m][0]); + __m128i _r1 = _mm_load_si128((const __m128i*)tmp[m][1]); + __m128i _r2 = _mm_load_si128((const __m128i*)tmp[m][2]); + __m128i _r3 = _mm_load_si128((const __m128i*)tmp[m][3]); + __m128i _r4 = _mm_load_si128((const __m128i*)tmp[m][4]); + __m128i _r5 = _mm_load_si128((const __m128i*)tmp[m][5]); +#endif + + __m128i _tmp02a = _mm_add_epi32(_r1, _r2); + __m128i _tmp02b = _mm_add_epi32(_r3, _r4); + __m128i _tmp13a = _mm_sub_epi32(_r1, _r2); + __m128i _tmp13b = _mm_sub_epi32(_r3, _r4); + + __m128i _tmp0 = _mm_add_epi32(_mm_add_epi32(_tmp02a, _tmp02b), _r0); + __m128i _tmp1 = _mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 1)); + __m128i _tmp2 = _mm_add_epi32(_tmp02a, _mm_slli_epi32(_tmp02b, 2)); + __m128i _tmp3 = _mm_add_epi32(_mm_add_epi32(_tmp13a, _mm_slli_epi32(_tmp13b, 3)), _r5); + + // TODO use integer trick for division by 576 + __m128 _v576 = _mm_set1_ps(1.0 / 576); + _tmp0 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp0), _v576)); + _tmp1 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp1), _v576)); + _tmp2 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp2), _v576)); + _tmp3 = _mm_cvttps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(_tmp3), _v576)); + + if (out_elempack == 4) + { + _mm_store_si128((__m128i*)outptr0, _tmp0); + if (tj * 4 + 1 < outw) _mm_store_si128((__m128i*)(outptr0 + 4), _tmp1); + if (tj * 4 + 2 < outw) _mm_store_si128((__m128i*)(outptr0 + 8), _tmp2); + if (tj * 4 + 3 < outw) _mm_store_si128((__m128i*)(outptr0 + 12), _tmp3); + } + if (out_elempack == 1) + { +#if __AVX512F__ + __m128i _vindex = _mm_setr_epi32(0, 1, 2, 3); + _vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(N)); + _mm_i32scatter_epi32(outptr0, _vindex, _tmp0, sizeof(int)); + if (tj * 4 + 1 < outw) _mm_i32scatter_epi32(outptr0 + 1, _vindex, _tmp1, sizeof(int)); + if (tj * 4 + 2 < outw) _mm_i32scatter_epi32(outptr0 + 2, _vindex, _tmp2, sizeof(int)); + if (tj * 4 + 3 < outw) _mm_i32scatter_epi32(outptr0 + 3, _vindex, _tmp3, sizeof(int)); +#else + int tmp0[4]; + int tmp1[4]; + int tmp2[4]; + int tmp3[4]; + _mm_storeu_si128((__m128i*)tmp0, _tmp0); + _mm_storeu_si128((__m128i*)tmp1, _tmp1); + _mm_storeu_si128((__m128i*)tmp2, _tmp2); + _mm_storeu_si128((__m128i*)tmp3, _tmp3); + + int* outptr1 = outptr0 + N; + int* outptr2 = outptr0 + N * 2; + int* outptr3 = outptr0 + N * 3; + + outptr0[0] = tmp0[0]; + outptr1[0] = tmp0[1]; + outptr2[0] = tmp0[2]; + outptr3[0] = tmp0[3]; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp1[0]; + outptr1[1] = tmp1[1]; + outptr2[1] = tmp1[2]; + outptr3[1] = tmp1[3]; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp2[0]; + outptr1[2] = tmp2[1]; + outptr2[2] = tmp2[2]; + outptr3[2] = tmp2[3]; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp3[0]; + outptr1[3] = tmp3[1]; + outptr2[3] = tmp3[2]; + outptr3[3] = tmp3[3]; + } +#endif // __AVX512F__ + } + + outptr0 += outw * out_elempack; + } + } + } +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) + { + int tmp[4][6][2]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj * 2; + const int* r1 = r0 + max_jj * 2; + const int* r2 = r0 + max_jj * 2 * 2; + const int* r3 = r0 + max_jj * 2 * 3; + const int* r4 = r0 + max_jj * 2 * 4; + const int* r5 = r0 + max_jj * 2 * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + for (int m = 5; m < 6; m++) + { + int tmp02a0 = r1[0] + r2[0]; + int tmp02a1 = r1[1] + r2[1]; + int tmp02b0 = r3[0] + r4[0]; + int tmp02b1 = r3[1] + r4[1]; + int tmp13a0 = r1[0] - r2[0]; + int tmp13a1 = r1[1] - r2[1]; + int tmp13b0 = r3[0] - r4[0]; + int tmp13b1 = r3[1] - r4[1]; + + int tmp00 = tmp02a0 + tmp02b0 + r0[0]; + int tmp01 = tmp02a1 + tmp02b1 + r0[1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + r5[0] * 4; + int tmp31 = tmp13a1 + tmp13b1 * 8 + r5[1] * 4; + + tmp00 = tmp00 * 4; + tmp01 = tmp01 * 4; + tmp10 = tmp10 * 4; + tmp11 = tmp11 * 4; + tmp20 = tmp20 * 4; + tmp21 = tmp21 * 4; + tmp30 = tmp30 * 4; + tmp31 = tmp31 * 4; + + tmp[0][m][0] = tmp00; + tmp[0][m][1] = tmp01; + tmp[1][m][0] = tmp10; + tmp[1][m][1] = tmp11; + tmp[2][m][0] = tmp20; + tmp[2][m][1] = tmp21; + tmp[3][m][0] = tmp30; + tmp[3][m][1] = tmp31; + + r0 += max_jj * 6 * 2; + r1 += max_jj * 6 * 2; + r2 += max_jj * 6 * 2; + r3 += max_jj * 6 * 2; + r4 += max_jj * 6 * 2; + r5 += max_jj * 6 * 2; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a0 = tmp[m][1][0] + tmp[m][2][0]; + int tmp02a1 = tmp[m][1][1] + tmp[m][2][1]; + int tmp02b0 = tmp[m][3][0] + tmp[m][4][0]; + int tmp02b1 = tmp[m][3][1] + tmp[m][4][1]; + int tmp13a0 = tmp[m][1][0] - tmp[m][2][0]; + int tmp13a1 = tmp[m][1][1] - tmp[m][2][1]; + int tmp13b0 = tmp[m][3][0] - tmp[m][4][0]; + int tmp13b1 = tmp[m][3][1] - tmp[m][4][1]; + + int tmp00 = tmp02a0 + tmp02b0 + tmp[m][0][0]; + int tmp01 = tmp02a1 + tmp02b1 + tmp[m][0][1]; + int tmp10 = tmp13a0 + tmp13b0 * 2; + int tmp11 = tmp13a1 + tmp13b1 * 2; + int tmp20 = tmp02a0 + tmp02b0 * 4; + int tmp21 = tmp02a1 + tmp02b1 * 4; + int tmp30 = tmp13a0 + tmp13b0 * 8 + tmp[m][5][0]; + int tmp31 = tmp13a1 + tmp13b1 * 8 + tmp[m][5][1]; + + tmp00 = tmp00 / 576; + tmp01 = tmp01 / 576; + tmp10 = tmp10 / 576; + tmp11 = tmp11 / 576; + tmp20 = tmp20 / 576; + tmp21 = tmp21 / 576; + tmp30 = tmp30 / 576; + tmp31 = tmp31 / 576; + + // if (out_elempack == 1) + { + int* outptr1 = outptr0 + N; + + outptr0[0] = tmp00; + outptr1[0] = tmp01; + if (tj * 4 + 1 < outw) + { + outptr0[1] = tmp10; + outptr1[1] = tmp11; + } + if (tj * 4 + 2 < outw) + { + outptr0[2] = tmp20; + outptr1[2] = tmp21; + } + if (tj * 4 + 3 < outw) + { + outptr0[3] = tmp30; + outptr1[3] = tmp31; + } + } + + outptr0 += outw; + } + } + } + for (; ii < max_ii; ii++) + { + int tmp[4][6]; + + int jj = 0; + for (; jj < max_jj; jj++) + { + int ti = (j + jj) / w_tiles; + int tj = (j + jj) % w_tiles; + + const int* r0 = (const int*)top_tile + ii * max_jj * 36 + jj; + const int* r1 = r0 + max_jj; + const int* r2 = r0 + max_jj * 2; + const int* r3 = r0 + max_jj * 3; + const int* r4 = r0 + max_jj * 4; + const int* r5 = r0 + max_jj * 5; + + for (int m = 0; m < 5; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + for (int m = 5; m < 6; m++) + { + int tmp02a = r1[0] + r2[0]; + int tmp02b = r3[0] + r4[0]; + int tmp13a = r1[0] - r2[0]; + int tmp13b = r3[0] - r4[0]; + + int tmp0 = tmp02a + tmp02b + r0[0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + r5[0] * 4; + + tmp0 = tmp0 * 4; + tmp1 = tmp1 * 4; + tmp2 = tmp2 * 4; + tmp3 = tmp3 * 4; + + tmp[0][m] = tmp0; + tmp[1][m] = tmp1; + tmp[2][m] = tmp2; + tmp[3][m] = tmp3; + + r0 += max_jj * 6; + r1 += max_jj * 6; + r2 += max_jj * 6; + r3 += max_jj * 6; + r4 += max_jj * 6; + r5 += max_jj * 6; + } + + int* outptr0 = top_blob.channel(i + ii).row(ti * 4) + (tj * 4); + + for (int m = 0; m < 4; m++) + { + if (ti * 4 + m >= outh) + continue; + + int tmp02a = tmp[m][1] + tmp[m][2]; + int tmp02b = tmp[m][3] + tmp[m][4]; + int tmp13a = tmp[m][1] - tmp[m][2]; + int tmp13b = tmp[m][3] - tmp[m][4]; + + int tmp0 = tmp02a + tmp02b + tmp[m][0]; + int tmp1 = tmp13a + tmp13b * 2; + int tmp2 = tmp02a + tmp02b * 4; + int tmp3 = tmp13a + tmp13b * 8 + tmp[m][5]; + + tmp0 = tmp0 / 576; + tmp1 = tmp1 / 576; + tmp2 = tmp2 / 576; + tmp3 = tmp3 / 576; + + // if (out_elempack == 1) + { + outptr0[0] = tmp0; + if (tj * 4 + 1 < outw) outptr0[1] = tmp1; + if (tj * 4 + 2 < outw) outptr0[2] = tmp2; + if (tj * 4 + 3 < outw) outptr0[3] = tmp3; + } + + outptr0 += outw; + } + } + } +} + +static void conv3x3s1_winograd43_int8(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) +{ +#if !(__AVX512VNNI__ || __AVXVNNI__ || __AVX2__ || __XOP__) +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + conv3x3s1_winograd43_int8_avx512vnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + conv3x3s1_winograd43_int8_avxvnni(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + conv3x3s1_winograd43_int8_avx2(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ + if (ncnn::cpu_support_x86_xop()) + { + conv3x3s1_winograd43_int8_xop(bottom_blob, top_blob, AT, nT, opt); + return; + } +#endif +#endif + + int outw = top_blob.w; + int outh = top_blob.h; + + // pad to 4n+2, winograd F(4,3) + int w_tiles = (outw + 3) / 4; + int h_tiles = (outh + 3) / 4; + int tiles = w_tiles * h_tiles; + + const int M = top_blob.c * top_blob.elempack; + const int N = tiles; + const int K = bottom_blob.c * bottom_blob.elempack; + const int B = 36; + + // NCNN_LOGE("conv3x3s1_winograd43_int8 %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk_int8(M, N, K, TILE_M, TILE_N, TILE_K, nT); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + const int nn_N = (N + TILE_N - 1) / TILE_N; + const int nn_K = (K + TILE_K - 1) / TILE_K; + + // NCNN_LOGE("TILE M/N/K = %d %d %d -> %d %d %d", M, N, K, TILE_M, TILE_N, TILE_K); + + Mat BT(TILE_K * TILE_N, B, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + if (nT > 1 && nn_NK < nT) + { + Mat B_tile(TILE_N * B * TILE_K, 4u, opt.workspace_allocator); + + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, nT); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, nT); + } + } + else + { + Mat B_tileX(TILE_N * B * TILE_K, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat B_tile = B_tileX.channel(get_omp_thread_num()); + + // transform input + conv3x3s1_winograd43_transform_input_tile_int8(bottom_blob, B_tile, j, max_jj, k, max_kk, 1); + + Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + transpose_pack_B_tile_int8(B_tile, BT_tile, B, max_jj, max_kk, 1); + } + } + + Mat top_tileX(TILE_N * B * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + Mat top_tile = top_tileX.channel(get_omp_thread_num()); + + const int max_ii = std::min((M - i), TILE_M); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + const Mat AT_tile = AT.channel(i / TILE_M).depth(k / TILE_K); + + const Mat BT_tile = BT.channel(j / TILE_N).depth(k / TILE_K); + + bool k_end = k + TILE_K >= K; + + gemm_transB_packed_tile_int8(AT_tile, BT_tile, top_tile, B, max_ii, max_jj, k, max_kk, k_end); + } + + // transform output + conv3x3s1_winograd43_transform_output_tile_int8(top_tile, top_blob, i, max_ii, j, max_jj); + } + } +} diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h index 56a4aa4763af..da504677a68f 100644 --- a/src/layer/x86/convolution_im2col_gemm_int8.h +++ b/src/layer/x86/convolution_im2col_gemm_int8.h @@ -934,7 +934,6 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M _sumf = _mm512_shuffle_epi32(_sumf, _MM_PERM_CBAD); } - // TODO __m512i _tmp0 = _mm512_shuffle_i32x4(_sum0, _sumc, _MM_SHUFFLE(2, 0, 2, 0)); __m512i _tmp1 = _mm512_shuffle_i32x4(_sum4, _sum0, _MM_SHUFFLE(3, 1, 3, 1)); __m512i _tmp2 = _mm512_shuffle_i32x4(_sum8, _sum4, _MM_SHUFFLE(0, 2, 0, 2)); @@ -2547,12 +2546,12 @@ static void convolution_gemm_transB_packed_tile_int8(const Mat& AT_tile, const M __m512i _tmp7 = _mm512_shuffle_i32x4(_sum3, _sum7, _MM_SHUFFLE(3, 2, 3, 2)); _sum0 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); _sum1 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(2, 0, 2, 0)); - _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 3, 1)); - _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(3, 1, 3, 1)); + _sum2 = _mm512_shuffle_i32x4(_tmp0, _tmp2, _MM_SHUFFLE(1, 3, 1, 3)); + _sum3 = _mm512_shuffle_i32x4(_tmp4, _tmp6, _MM_SHUFFLE(1, 3, 1, 3)); _sum4 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); _sum5 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(2, 0, 2, 0)); - _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); - _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(3, 1, 3, 1)); + _sum6 = _mm512_shuffle_i32x4(_tmp1, _tmp3, _MM_SHUFFLE(1, 3, 1, 3)); + _sum7 = _mm512_shuffle_i32x4(_tmp5, _tmp7, _MM_SHUFFLE(1, 3, 1, 3)); _mm512_storeu_si512((__m512i*)outptr0, _sum0); _mm512_storeu_si512((__m512i*)(outptr0 + 16), _sum1); @@ -6142,14 +6141,13 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo } } -static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +template +#if __AVX512F__ +void convolution_im2col_input_tile_int8_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#else // __AVX512F__ +void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk) +#endif // __AVX512F__ { - if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { - convolution_im2col_input_tile_conv1x1s1d1_int8(bottom_blob, B, j, max_jj, k, max_kk); - return; - } - const int w = bottom_blob.w; // const int channels = bottom_blob.c; const int elempack = bottom_blob.elempack; @@ -6206,288 +6204,468 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dxe = (j + jj + 14) % outw; int dxf = (j + jj + 15) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dyf) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int x08 = stride_w * dx8 + dilation_w * v0; - int x09 = stride_w * dx9 + dilation_w * v0; - int x0a = stride_w * dxa + dilation_w * v0; - int x0b = stride_w * dxb + dilation_w * v0; - int x0c = stride_w * dxc + dilation_w * v0; - int x0d = stride_w * dxd + dilation_w * v0; - int x0e = stride_w * dxe + dilation_w * v0; - int x0f = stride_w * dxf + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - int y08 = stride_h * dy8 + dilation_h * u0; - int y09 = stride_h * dy9 + dilation_h * u0; - int y0a = stride_h * dya + dilation_h * u0; - int y0b = stride_h * dyb + dilation_h * u0; - int y0c = stride_h * dyc + dilation_h * u0; - int y0d = stride_h * dyd + dilation_h * u0; - int y0e = stride_h * dye + dilation_h * u0; - int y0f = stride_h * dyf + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int x18 = stride_w * dx8 + dilation_w * v1; - int x19 = stride_w * dx9 + dilation_w * v1; - int x1a = stride_w * dxa + dilation_w * v1; - int x1b = stride_w * dxb + dilation_w * v1; - int x1c = stride_w * dxc + dilation_w * v1; - int x1d = stride_w * dxd + dilation_w * v1; - int x1e = stride_w * dxe + dilation_w * v1; - int x1f = stride_w * dxf + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - int y18 = stride_h * dy8 + dilation_h * u1; - int y19 = stride_h * dy9 + dilation_h * u1; - int y1a = stride_h * dya + dilation_h * u1; - int y1b = stride_h * dyb + dilation_h * u1; - int y1c = stride_h * dyc + dilation_h * u1; - int y1d = stride_h * dyd + dilation_h * u1; - int y1e = stride_h * dye + dilation_h * u1; - int y1f = stride_h * dyf + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - const signed char* sptr08 = img0.row(y08) + x08; - const signed char* sptr09 = img0.row(y09) + x09; - const signed char* sptr0a = img0.row(y0a) + x0a; - const signed char* sptr0b = img0.row(y0b) + x0b; - const signed char* sptr0c = img0.row(y0c) + x0c; - const signed char* sptr0d = img0.row(y0d) + x0d; - const signed char* sptr0e = img0.row(y0e) + x0e; - const signed char* sptr0f = img0.row(y0f) + x0f; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - const signed char* sptr18 = img1.row(y18) + x18; - const signed char* sptr19 = img1.row(y19) + x19; - const signed char* sptr1a = img1.row(y1a) + x1a; - const signed char* sptr1b = img1.row(y1b) + x1b; - const signed char* sptr1c = img1.row(y1c) + x1c; - const signed char* sptr1d = img1.row(y1d) + x1d; - const signed char* sptr1e = img1.row(y1e) + x1e; - const signed char* sptr1f = img1.row(y1f) + x1f; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp[16 + 0] = sptr08[0]; - pp[16 + 1] = sptr18[0]; - pp[16 + 2] = sptr09[0]; - pp[16 + 3] = sptr19[0]; - pp[16 + 4] = sptr0a[0]; - pp[16 + 5] = sptr1a[0]; - pp[16 + 6] = sptr0b[0]; - pp[16 + 7] = sptr1b[0]; - pp[16 + 8] = sptr0c[0]; - pp[16 + 9] = sptr1c[0]; - pp[16 + 10] = sptr0d[0]; - pp[16 + 11] = sptr1d[0]; - pp[16 + 12] = sptr0e[0]; - pp[16 + 13] = sptr1e[0]; - pp[16 + 14] = sptr0f[0]; - pp[16 + 15] = sptr1f[0]; - pp += 32; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _mm_store_si128((__m128i*)pp, _tmp0); + _mm_store_si128((__m128i*)(pp + 16), _tmp1); + pp += 32; + } + else if (stride_w == 2) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); + __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _mm256_storeu_si256((__m256i*)pp, _r01); + pp += 32; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp[16 + 0] = sptr0[stride_w * 8]; + pp[16 + 1] = sptr1[stride_w * 8]; + pp[16 + 2] = sptr0[stride_w * 9]; + pp[16 + 3] = sptr1[stride_w * 9]; + pp[16 + 4] = sptr0[stride_w * 10]; + pp[16 + 5] = sptr1[stride_w * 10]; + pp[16 + 6] = sptr0[stride_w * 11]; + pp[16 + 7] = sptr1[stride_w * 11]; + pp[16 + 8] = sptr0[stride_w * 12]; + pp[16 + 9] = sptr1[stride_w * 12]; + pp[16 + 10] = sptr0[stride_w * 13]; + pp[16 + 11] = sptr1[stride_w * 13]; + pp[16 + 12] = sptr0[stride_w * 14]; + pp[16 + 13] = sptr1[stride_w * 14]; + pp[16 + 14] = sptr0[stride_w * 15]; + pp[16 + 15] = sptr1[stride_w * 15]; + pp += 32; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int x8 = stride_w * dx8 + dilation_w * v; - int x9 = stride_w * dx9 + dilation_w * v; - int xa = stride_w * dxa + dilation_w * v; - int xb = stride_w * dxb + dilation_w * v; - int xc = stride_w * dxc + dilation_w * v; - int xd = stride_w * dxd + dilation_w * v; - int xe = stride_w * dxe + dilation_w * v; - int xf = stride_w * dxf + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - int y8 = stride_h * dy8 + dilation_h * u; - int y9 = stride_h * dy9 + dilation_h * u; - int ya = stride_h * dya + dilation_h * u; - int yb = stride_h * dyb + dilation_h * u; - int yc = stride_h * dyc + dilation_h * u; - int yd = stride_h * dyd + dilation_h * u; - int ye = stride_h * dye + dilation_h * u; - int yf = stride_h * dyf + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; - const signed char* sptr8 = img.row(y8) + x8 * elempack; - const signed char* sptr9 = img.row(y9) + x9 * elempack; - const signed char* sptra = img.row(ya) + xa * elempack; - const signed char* sptrb = img.row(yb) + xb * elempack; - const signed char* sptrc = img.row(yc) + xc * elempack; - const signed char* sptrd = img.row(yd) + xd * elempack; - const signed char* sptre = img.row(ye) + xe * elempack; - const signed char* sptrf = img.row(yf) + xf * elempack; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); - __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); - __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); - __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); - __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); - __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); - __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); - __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); - __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); - __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); - __m128i _ref = _mm_unpacklo_epi16(_re, _rf); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi32(_r89, _rab); - _r5 = _mm_unpackhi_epi32(_r89, _rab); - _r6 = _mm_unpacklo_epi32(_rcd, _ref); - _r7 = _mm_unpackhi_epi32(_rcd, _ref); - _r8 = _mm_unpacklo_epi64(_r0, _r2); - _r9 = _mm_unpacklo_epi64(_r4, _r6); - _ra = _mm_unpackhi_epi64(_r0, _r2); - _rb = _mm_unpackhi_epi64(_r4, _r6); - _rc = _mm_unpacklo_epi64(_r1, _r3); - _rd = _mm_unpacklo_epi64(_r5, _r7); - _re = _mm_unpackhi_epi64(_r1, _r3); - _rf = _mm_unpackhi_epi64(_r5, _r7); - _mm_storeu_si128((__m128i*)pp, _r8); - _mm_storeu_si128((__m128i*)(pp + 16), _r9); - _mm_storeu_si128((__m128i*)(pp + 32), _ra); - _mm_storeu_si128((__m128i*)(pp + 48), _rb); - _mm_storeu_si128((__m128i*)(pp + 64), _rc); - _mm_storeu_si128((__m128i*)(pp + 80), _rd); - _mm_storeu_si128((__m128i*)(pp + 96), _re); - _mm_storeu_si128((__m128i*)(pp + 112), _rf); - pp += 128; + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); + __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); + __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); + __m128i _rc = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 96)); + __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); + __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); + __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp[8] = sptr[stride_w * 8]; + pp[9] = sptr[stride_w * 9]; + pp[10] = sptr[stride_w * 10]; + pp[11] = sptr[stride_w * 11]; + pp[12] = sptr[stride_w * 12]; + pp[13] = sptr[stride_w * 13]; + pp[14] = sptr[stride_w * 14]; + pp[15] = sptr[stride_w * 15]; + pp += 16; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr8[0]; - pp[9] = sptr9[0]; - pp[10] = sptra[0]; - pp[11] = sptrb[0]; - pp[12] = sptrc[0]; - pp[13] = sptrd[0]; - pp[14] = sptre[0]; - pp[15] = sptrf[0]; - pp += 16; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int x08 = stride_w * dx8 + dilation_w * v0; + int x09 = stride_w * dx9 + dilation_w * v0; + int x0a = stride_w * dxa + dilation_w * v0; + int x0b = stride_w * dxb + dilation_w * v0; + int x0c = stride_w * dxc + dilation_w * v0; + int x0d = stride_w * dxd + dilation_w * v0; + int x0e = stride_w * dxe + dilation_w * v0; + int x0f = stride_w * dxf + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + int y08 = stride_h * dy8 + dilation_h * u0; + int y09 = stride_h * dy9 + dilation_h * u0; + int y0a = stride_h * dya + dilation_h * u0; + int y0b = stride_h * dyb + dilation_h * u0; + int y0c = stride_h * dyc + dilation_h * u0; + int y0d = stride_h * dyd + dilation_h * u0; + int y0e = stride_h * dye + dilation_h * u0; + int y0f = stride_h * dyf + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int x18 = stride_w * dx8 + dilation_w * v1; + int x19 = stride_w * dx9 + dilation_w * v1; + int x1a = stride_w * dxa + dilation_w * v1; + int x1b = stride_w * dxb + dilation_w * v1; + int x1c = stride_w * dxc + dilation_w * v1; + int x1d = stride_w * dxd + dilation_w * v1; + int x1e = stride_w * dxe + dilation_w * v1; + int x1f = stride_w * dxf + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + int y18 = stride_h * dy8 + dilation_h * u1; + int y19 = stride_h * dy9 + dilation_h * u1; + int y1a = stride_h * dya + dilation_h * u1; + int y1b = stride_h * dyb + dilation_h * u1; + int y1c = stride_h * dyc + dilation_h * u1; + int y1d = stride_h * dyd + dilation_h * u1; + int y1e = stride_h * dye + dilation_h * u1; + int y1f = stride_h * dyf + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + const signed char* sptr08 = img0.row(y08) + x08; + const signed char* sptr09 = img0.row(y09) + x09; + const signed char* sptr0a = img0.row(y0a) + x0a; + const signed char* sptr0b = img0.row(y0b) + x0b; + const signed char* sptr0c = img0.row(y0c) + x0c; + const signed char* sptr0d = img0.row(y0d) + x0d; + const signed char* sptr0e = img0.row(y0e) + x0e; + const signed char* sptr0f = img0.row(y0f) + x0f; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + const signed char* sptr18 = img1.row(y18) + x18; + const signed char* sptr19 = img1.row(y19) + x19; + const signed char* sptr1a = img1.row(y1a) + x1a; + const signed char* sptr1b = img1.row(y1b) + x1b; + const signed char* sptr1c = img1.row(y1c) + x1c; + const signed char* sptr1d = img1.row(y1d) + x1d; + const signed char* sptr1e = img1.row(y1e) + x1e; + const signed char* sptr1f = img1.row(y1f) + x1f; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp[16 + 0] = sptr08[0]; + pp[16 + 1] = sptr18[0]; + pp[16 + 2] = sptr09[0]; + pp[16 + 3] = sptr19[0]; + pp[16 + 4] = sptr0a[0]; + pp[16 + 5] = sptr1a[0]; + pp[16 + 6] = sptr0b[0]; + pp[16 + 7] = sptr1b[0]; + pp[16 + 8] = sptr0c[0]; + pp[16 + 9] = sptr1c[0]; + pp[16 + 10] = sptr0d[0]; + pp[16 + 11] = sptr1d[0]; + pp[16 + 12] = sptr0e[0]; + pp[16 + 13] = sptr1e[0]; + pp[16 + 14] = sptr0f[0]; + pp[16 + 15] = sptr1f[0]; + pp += 32; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int x8 = stride_w * dx8 + dilation_w * v; + int x9 = stride_w * dx9 + dilation_w * v; + int xa = stride_w * dxa + dilation_w * v; + int xb = stride_w * dxb + dilation_w * v; + int xc = stride_w * dxc + dilation_w * v; + int xd = stride_w * dxd + dilation_w * v; + int xe = stride_w * dxe + dilation_w * v; + int xf = stride_w * dxf + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + int y8 = stride_h * dy8 + dilation_h * u; + int y9 = stride_h * dy9 + dilation_h * u; + int ya = stride_h * dya + dilation_h * u; + int yb = stride_h * dyb + dilation_h * u; + int yc = stride_h * dyc + dilation_h * u; + int yd = stride_h * dyd + dilation_h * u; + int ye = stride_h * dye + dilation_h * u; + int yf = stride_h * dyf + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + const signed char* sptr8 = img.row(y8) + x8 * elempack; + const signed char* sptr9 = img.row(y9) + x9 * elempack; + const signed char* sptra = img.row(ya) + xa * elempack; + const signed char* sptrb = img.row(yb) + xb * elempack; + const signed char* sptrc = img.row(yc) + xc * elempack; + const signed char* sptrd = img.row(yd) + xd * elempack; + const signed char* sptre = img.row(ye) + xe * elempack; + const signed char* sptrf = img.row(yf) + xf * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); + __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); + __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); + __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); + __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); + __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); + __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp[8] = sptr8[0]; + pp[9] = sptr9[0]; + pp[10] = sptra[0]; + pp[11] = sptrb[0]; + pp[12] = sptrc[0]; + pp[13] = sptrd[0]; + pp[14] = sptre[0]; + pp[15] = sptrf[0]; + pp += 16; + } } } } @@ -6511,168 +6689,298 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx6 = (j + jj + 6) % outw; int dx7 = (j + jj + 7) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy7) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* sptr07 = img0.row(y07) + x07; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp += 16; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp += 16; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi64(_r0, _r2); - _r5 = _mm_unpackhi_epi64(_r0, _r2); - _r6 = _mm_unpacklo_epi64(_r1, _r3); - _r7 = _mm_unpackhi_epi64(_r1, _r3); - _mm_storeu_si128((__m128i*)pp, _r4); - _mm_storeu_si128((__m128i*)(pp + 16), _r5); - _mm_storeu_si128((__m128i*)(pp + 32), _r6); - _mm_storeu_si128((__m128i*)(pp + 48), _r7); - pp += 64; + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp += 8; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp += 8; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp += 16; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp += 8; + } } } } @@ -6688,106 +6996,206 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx2 = (j + jj + 2) % outw; int dx3 = (j + jj + 3) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy3) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp += 8; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0)); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp += 8; + } + } } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; - const Mat img = bottom_blob.channel(p); + const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr = img.row(y0) + x0 * elempack; - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _mm_storeu_si128((__m128i*)pp, _r0); - _mm_storeu_si128((__m128i*)(pp + 16), _r1); - pp += 32; + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp += 4; + } } + } + else + { + int kk = 0; if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp += 4; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp += 8; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp += 4; + } } } } @@ -6799,44 +7207,154 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i int dx0 = (j + jj) % outw; int dx1 = (j + jj + 1) % outw; - int kk = 0; - if (elempack == 1) + if (dy0 == dy1) { - for (; kk + 1 < max_kk; kk += 2) + int kk = 0; + if (elempack == 1) { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp += 4; + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp += 2; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp += 2; + } } } + } + for (; jj < max_jj; jj++) + { + int dy = (j + jj) / outw; + int dx = (j + jj) % outw; + + int kk = 0; for (; kk < max_kk / elempack; kk++) { int p = (k / elempack + kk) / maxk; @@ -6846,29 +7364,1309 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i const Mat img = bottom_blob.channel(p); - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; + int x = stride_w * dx + dilation_w * v; + int y = stride_h * dy + dilation_h * u; - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr = img.row(y) + x * elempack; #if __SSE2__ if (elempack == 8) { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; + _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)sptr)); + pp += 8; } #endif // __SSE2__ if (elempack == 1) { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp += 2; + pp[0] = sptr[0]; + pp += 1; + } + } + } +} + +#if __AVX512F__ +template void convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8_avx512<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#else // __AVX512F__ +template void convolution_im2col_input_tile_int8<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<3, 3, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<5, 5, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<5, 5, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +template void convolution_im2col_input_tile_int8<7, 7, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk); +#endif // __AVX512F__ + +static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h) +{ + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_input_tile_conv1x1s1d1_int8(bottom_blob, B, j, max_jj, k, max_kk); + return; + } + + if (kernel_w == 1 && kernel_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<1, 1, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<3, 3, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<3, 3, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<5, 5, 1, 1, 1, 1>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 5 && kernel_h == 5 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<5, 5, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + if (kernel_w == 7 && kernel_h == 7 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { +#if __AVX512F__ + convolution_im2col_input_tile_int8_avx512<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#else // __AVX512F__ + convolution_im2col_input_tile_int8<7, 7, 1, 1, 2, 2>(bottom_blob, B, j, max_jj, k, max_kk); +#endif // __AVX512F__ + return; + } + + const int w = bottom_blob.w; + // const int channels = bottom_blob.c; + const int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + const int outw = (w - kernel_extent_w) / stride_w + 1; + + // j max_jj outw*outh split w and h + + // k max_kk pa*maxk*(inch/pa) split inch + + // k/max_kk shall be multiple of maxk + + const int maxk = kernel_w * kernel_h; + + signed char* pp = B; + + int jj = 0; +#if __SSE2__ +#if defined(__x86_64__) || defined(_M_X64) +#if __AVX512F__ + for (; jj + 15 < max_jj; jj += 16) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dy8 = (j + jj + 8) / outw; + int dy9 = (j + jj + 9) / outw; + int dya = (j + jj + 10) / outw; + int dyb = (j + jj + 11) / outw; + int dyc = (j + jj + 12) / outw; + int dyd = (j + jj + 13) / outw; + int dye = (j + jj + 14) / outw; + int dyf = (j + jj + 15) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + int dx8 = (j + jj + 8) % outw; + int dx9 = (j + jj + 9) % outw; + int dxa = (j + jj + 10) % outw; + int dxb = (j + jj + 11) % outw; + int dxc = (j + jj + 12) % outw; + int dxd = (j + jj + 13) % outw; + int dxe = (j + jj + 14) % outw; + int dxf = (j + jj + 15) % outw; + + if (dy0 == dyf) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _mm_store_si128((__m128i*)pp, _tmp0); + _mm_store_si128((__m128i*)(pp + 16), _tmp1); + pp += 32; + } + else if (stride_w == 2) + { + __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); + __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); + __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); + _mm256_storeu_si256((__m256i*)pp, _r01); + pp += 32; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp[16 + 0] = sptr0[stride_w * 8]; + pp[16 + 1] = sptr1[stride_w * 8]; + pp[16 + 2] = sptr0[stride_w * 9]; + pp[16 + 3] = sptr1[stride_w * 9]; + pp[16 + 4] = sptr0[stride_w * 10]; + pp[16 + 5] = sptr1[stride_w * 10]; + pp[16 + 6] = sptr0[stride_w * 11]; + pp[16 + 7] = sptr1[stride_w * 11]; + pp[16 + 8] = sptr0[stride_w * 12]; + pp[16 + 9] = sptr1[stride_w * 12]; + pp[16 + 10] = sptr0[stride_w * 13]; + pp[16 + 11] = sptr1[stride_w * 13]; + pp[16 + 12] = sptr0[stride_w * 14]; + pp[16 + 13] = sptr1[stride_w * 14]; + pp[16 + 14] = sptr0[stride_w * 15]; + pp[16 + 15] = sptr1[stride_w * 15]; + pp += 32; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); + __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); + __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); + __m128i _rc = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 96)); + __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); + __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); + __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp[8] = sptr[stride_w * 8]; + pp[9] = sptr[stride_w * 9]; + pp[10] = sptr[stride_w * 10]; + pp[11] = sptr[stride_w * 11]; + pp[12] = sptr[stride_w * 12]; + pp[13] = sptr[stride_w * 13]; + pp[14] = sptr[stride_w * 14]; + pp[15] = sptr[stride_w * 15]; + pp += 16; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int x08 = stride_w * dx8 + dilation_w * v0; + int x09 = stride_w * dx9 + dilation_w * v0; + int x0a = stride_w * dxa + dilation_w * v0; + int x0b = stride_w * dxb + dilation_w * v0; + int x0c = stride_w * dxc + dilation_w * v0; + int x0d = stride_w * dxd + dilation_w * v0; + int x0e = stride_w * dxe + dilation_w * v0; + int x0f = stride_w * dxf + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + int y08 = stride_h * dy8 + dilation_h * u0; + int y09 = stride_h * dy9 + dilation_h * u0; + int y0a = stride_h * dya + dilation_h * u0; + int y0b = stride_h * dyb + dilation_h * u0; + int y0c = stride_h * dyc + dilation_h * u0; + int y0d = stride_h * dyd + dilation_h * u0; + int y0e = stride_h * dye + dilation_h * u0; + int y0f = stride_h * dyf + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int x18 = stride_w * dx8 + dilation_w * v1; + int x19 = stride_w * dx9 + dilation_w * v1; + int x1a = stride_w * dxa + dilation_w * v1; + int x1b = stride_w * dxb + dilation_w * v1; + int x1c = stride_w * dxc + dilation_w * v1; + int x1d = stride_w * dxd + dilation_w * v1; + int x1e = stride_w * dxe + dilation_w * v1; + int x1f = stride_w * dxf + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + int y18 = stride_h * dy8 + dilation_h * u1; + int y19 = stride_h * dy9 + dilation_h * u1; + int y1a = stride_h * dya + dilation_h * u1; + int y1b = stride_h * dyb + dilation_h * u1; + int y1c = stride_h * dyc + dilation_h * u1; + int y1d = stride_h * dyd + dilation_h * u1; + int y1e = stride_h * dye + dilation_h * u1; + int y1f = stride_h * dyf + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + const signed char* sptr08 = img0.row(y08) + x08; + const signed char* sptr09 = img0.row(y09) + x09; + const signed char* sptr0a = img0.row(y0a) + x0a; + const signed char* sptr0b = img0.row(y0b) + x0b; + const signed char* sptr0c = img0.row(y0c) + x0c; + const signed char* sptr0d = img0.row(y0d) + x0d; + const signed char* sptr0e = img0.row(y0e) + x0e; + const signed char* sptr0f = img0.row(y0f) + x0f; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + const signed char* sptr18 = img1.row(y18) + x18; + const signed char* sptr19 = img1.row(y19) + x19; + const signed char* sptr1a = img1.row(y1a) + x1a; + const signed char* sptr1b = img1.row(y1b) + x1b; + const signed char* sptr1c = img1.row(y1c) + x1c; + const signed char* sptr1d = img1.row(y1d) + x1d; + const signed char* sptr1e = img1.row(y1e) + x1e; + const signed char* sptr1f = img1.row(y1f) + x1f; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp[16 + 0] = sptr08[0]; + pp[16 + 1] = sptr18[0]; + pp[16 + 2] = sptr09[0]; + pp[16 + 3] = sptr19[0]; + pp[16 + 4] = sptr0a[0]; + pp[16 + 5] = sptr1a[0]; + pp[16 + 6] = sptr0b[0]; + pp[16 + 7] = sptr1b[0]; + pp[16 + 8] = sptr0c[0]; + pp[16 + 9] = sptr1c[0]; + pp[16 + 10] = sptr0d[0]; + pp[16 + 11] = sptr1d[0]; + pp[16 + 12] = sptr0e[0]; + pp[16 + 13] = sptr1e[0]; + pp[16 + 14] = sptr0f[0]; + pp[16 + 15] = sptr1f[0]; + pp += 32; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int x8 = stride_w * dx8 + dilation_w * v; + int x9 = stride_w * dx9 + dilation_w * v; + int xa = stride_w * dxa + dilation_w * v; + int xb = stride_w * dxb + dilation_w * v; + int xc = stride_w * dxc + dilation_w * v; + int xd = stride_w * dxd + dilation_w * v; + int xe = stride_w * dxe + dilation_w * v; + int xf = stride_w * dxf + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + int y8 = stride_h * dy8 + dilation_h * u; + int y9 = stride_h * dy9 + dilation_h * u; + int ya = stride_h * dya + dilation_h * u; + int yb = stride_h * dyb + dilation_h * u; + int yc = stride_h * dyc + dilation_h * u; + int yd = stride_h * dyd + dilation_h * u; + int ye = stride_h * dye + dilation_h * u; + int yf = stride_h * dyf + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + const signed char* sptr8 = img.row(y8) + x8 * elempack; + const signed char* sptr9 = img.row(y9) + x9 * elempack; + const signed char* sptra = img.row(ya) + xa * elempack; + const signed char* sptrb = img.row(yb) + xb * elempack; + const signed char* sptrc = img.row(yc) + xc * elempack; + const signed char* sptrd = img.row(yd) + xd * elempack; + const signed char* sptre = img.row(ye) + xe * elempack; + const signed char* sptrf = img.row(yf) + xf * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); + __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); + __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); + __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); + __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); + __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); + __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); + __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); + __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); + __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); + __m128i _ref = _mm_unpacklo_epi16(_re, _rf); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi32(_r89, _rab); + _r5 = _mm_unpackhi_epi32(_r89, _rab); + _r6 = _mm_unpacklo_epi32(_rcd, _ref); + _r7 = _mm_unpackhi_epi32(_rcd, _ref); + _r8 = _mm_unpacklo_epi64(_r0, _r2); + _r9 = _mm_unpacklo_epi64(_r4, _r6); + _ra = _mm_unpackhi_epi64(_r0, _r2); + _rb = _mm_unpackhi_epi64(_r4, _r6); + _rc = _mm_unpacklo_epi64(_r1, _r3); + _rd = _mm_unpacklo_epi64(_r5, _r7); + _re = _mm_unpackhi_epi64(_r1, _r3); + _rf = _mm_unpackhi_epi64(_r5, _r7); + _mm_store_si128((__m128i*)pp, _r8); + _mm_store_si128((__m128i*)(pp + 16), _r9); + _mm_store_si128((__m128i*)(pp + 32), _ra); + _mm_store_si128((__m128i*)(pp + 48), _rb); + _mm_store_si128((__m128i*)(pp + 64), _rc); + _mm_store_si128((__m128i*)(pp + 80), _rd); + _mm_store_si128((__m128i*)(pp + 96), _re); + _mm_store_si128((__m128i*)(pp + 112), _rf); + pp += 128; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp[8] = sptr8[0]; + pp[9] = sptr9[0]; + pp[10] = sptra[0]; + pp[11] = sptrb[0]; + pp[12] = sptrc[0]; + pp[13] = sptrd[0]; + pp[14] = sptre[0]; + pp[15] = sptrf[0]; + pp += 16; + } + } + } + } +#endif // __AVX512F__ + for (; jj + 7 < max_jj; jj += 8) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dy4 = (j + jj + 4) / outw; + int dy5 = (j + jj + 5) / outw; + int dy6 = (j + jj + 6) / outw; + int dy7 = (j + jj + 7) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + int dx4 = (j + jj + 4) % outw; + int dx5 = (j + jj + 5) % outw; + int dx6 = (j + jj + 6) % outw; + int dx7 = (j + jj + 7) % outw; + + if (dy0 == dy7) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); + __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); + _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp[8] = sptr0[stride_w * 4]; + pp[9] = sptr1[stride_w * 4]; + pp[10] = sptr0[stride_w * 5]; + pp[11] = sptr1[stride_w * 5]; + pp[12] = sptr0[stride_w * 6]; + pp[13] = sptr1[stride_w * 6]; + pp[14] = sptr0[stride_w * 7]; + pp[15] = sptr1[stride_w * 7]; + pp += 16; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp[4] = sptr[stride_w * 4]; + pp[5] = sptr[stride_w * 5]; + pp[6] = sptr[stride_w * 6]; + pp[7] = sptr[stride_w * 7]; + pp += 8; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int x04 = stride_w * dx4 + dilation_w * v0; + int x05 = stride_w * dx5 + dilation_w * v0; + int x06 = stride_w * dx6 + dilation_w * v0; + int x07 = stride_w * dx7 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + int y04 = stride_h * dy4 + dilation_h * u0; + int y05 = stride_h * dy5 + dilation_h * u0; + int y06 = stride_h * dy6 + dilation_h * u0; + int y07 = stride_h * dy7 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int x14 = stride_w * dx4 + dilation_w * v1; + int x15 = stride_w * dx5 + dilation_w * v1; + int x16 = stride_w * dx6 + dilation_w * v1; + int x17 = stride_w * dx7 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + int y14 = stride_h * dy4 + dilation_h * u1; + int y15 = stride_h * dy5 + dilation_h * u1; + int y16 = stride_h * dy6 + dilation_h * u1; + int y17 = stride_h * dy7 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + const signed char* sptr04 = img0.row(y04) + x04; + const signed char* sptr05 = img0.row(y05) + x05; + const signed char* sptr06 = img0.row(y06) + x06; + const signed char* sptr07 = img0.row(y07) + x07; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + const signed char* sptr14 = img1.row(y14) + x14; + const signed char* sptr15 = img1.row(y15) + x15; + const signed char* sptr16 = img1.row(y16) + x16; + const signed char* sptr17 = img1.row(y17) + x17; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp[8] = sptr04[0]; + pp[9] = sptr14[0]; + pp[10] = sptr05[0]; + pp[11] = sptr15[0]; + pp[12] = sptr06[0]; + pp[13] = sptr16[0]; + pp[14] = sptr07[0]; + pp[15] = sptr17[0]; + pp += 16; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int x4 = stride_w * dx4 + dilation_w * v; + int x5 = stride_w * dx5 + dilation_w * v; + int x6 = stride_w * dx6 + dilation_w * v; + int x7 = stride_w * dx7 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + int y4 = stride_h * dy4 + dilation_h * u; + int y5 = stride_h * dy5 + dilation_h * u; + int y6 = stride_h * dy6 + dilation_h * u; + int y7 = stride_h * dy7 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + const signed char* sptr4 = img.row(y4) + x4 * elempack; + const signed char* sptr5 = img.row(y5) + x5 * elempack; + const signed char* sptr6 = img.row(y6) + x6 * elempack; + const signed char* sptr7 = img.row(y7) + x7 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); + __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); + __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); + __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _r2 = _mm_unpacklo_epi32(_r45, _r67); + _r3 = _mm_unpackhi_epi32(_r45, _r67); + _r4 = _mm_unpacklo_epi64(_r0, _r2); + _r5 = _mm_unpackhi_epi64(_r0, _r2); + _r6 = _mm_unpacklo_epi64(_r1, _r3); + _r7 = _mm_unpackhi_epi64(_r1, _r3); + _mm_storeu_si128((__m128i*)pp, _r4); + _mm_storeu_si128((__m128i*)(pp + 16), _r5); + _mm_storeu_si128((__m128i*)(pp + 32), _r6); + _mm_storeu_si128((__m128i*)(pp + 48), _r7); + pp += 64; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp[4] = sptr4[0]; + pp[5] = sptr5[0]; + pp[6] = sptr6[0]; + pp[7] = sptr7[0]; + pp += 8; + } + } + } + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; jj + 3 < max_jj; jj += 4) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dy2 = (j + jj + 2) / outw; + int dy3 = (j + jj + 3) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + int dx2 = (j + jj + 2) % outw; + int dx3 = (j + jj + 3) % outw; + + if (dy0 == dy3) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + if (stride_w == 1) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else if (stride_w == 2) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); + _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); + _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0)); + _mm_storel_epi64((__m128i*)pp, _r01); + pp += 8; + } + else + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp[4] = sptr0[stride_w * 2]; + pp[5] = sptr1[stride_w * 2]; + pp[6] = sptr0[stride_w * 3]; + pp[7] = sptr1[stride_w * 3]; + pp += 8; + } + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp[2] = sptr[stride_w * 2]; + pp[3] = sptr[stride_w * 3]; + pp += 4; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int x02 = stride_w * dx2 + dilation_w * v0; + int x03 = stride_w * dx3 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int y02 = stride_h * dy2 + dilation_h * u0; + int y03 = stride_h * dy3 + dilation_h * u0; + + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int x12 = stride_w * dx2 + dilation_w * v1; + int x13 = stride_w * dx3 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + int y12 = stride_h * dy2 + dilation_h * u1; + int y13 = stride_h * dy3 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr02 = img0.row(y02) + x02; + const signed char* sptr03 = img0.row(y03) + x03; + + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + const signed char* sptr12 = img1.row(y12) + x12; + const signed char* sptr13 = img1.row(y13) + x13; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp[4] = sptr02[0]; + pp[5] = sptr12[0]; + pp[6] = sptr03[0]; + pp[7] = sptr13[0]; + pp += 8; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int x2 = stride_w * dx2 + dilation_w * v; + int x3 = stride_w * dx3 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + int y2 = stride_h * dy2 + dilation_h * u; + int y3 = stride_h * dy3 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + const signed char* sptr2 = img.row(y2) + x2 * elempack; + const signed char* sptr3 = img.row(y3) + x3 * elempack; + + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); + __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); + _r0 = _mm_unpacklo_epi32(_r01, _r23); + _r1 = _mm_unpackhi_epi32(_r01, _r23); + _mm_storeu_si128((__m128i*)pp, _r0); + _mm_storeu_si128((__m128i*)(pp + 16), _r1); + pp += 32; + } + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr2[0]; + pp[3] = sptr3[0]; + pp += 4; + } + } + } + } +#endif // __SSE2__ + for (; jj + 1 < max_jj; jj += 2) + { + int dy0 = (j + jj) / outw; + int dy1 = (j + jj + 1) / outw; + int dx0 = (j + jj) % outw; + int dx1 = (j + jj + 1) % outw; + + if (dy0 == dy1) + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + + const signed char* sptr0 = img0.row(y00) + x00; + const signed char* sptr1 = img1.row(y10) + x10; + + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp[2] = sptr0[stride_w]; + pp[3] = sptr1[stride_w]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + + const signed char* sptr = img.row(y0) + x0 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr[0]; + pp[1] = sptr[stride_w]; + pp += 2; + } + } + } + else + { + int kk = 0; + if (elempack == 1) + { + for (; kk + 1 < max_kk; kk += 2) + { + int p0 = (k + kk) / maxk; + int p1 = (k + kk + 1) / maxk; + int uv0 = (k + kk) % maxk; + int uv1 = (k + kk + 1) % maxk; + int u0 = uv0 / kernel_w; + int u1 = uv1 / kernel_w; + int v0 = uv0 % kernel_w; + int v1 = uv1 % kernel_w; + + const Mat img0 = bottom_blob.channel(p0); + const Mat img1 = bottom_blob.channel(p1); + + int x00 = stride_w * dx0 + dilation_w * v0; + int x01 = stride_w * dx1 + dilation_w * v0; + int y00 = stride_h * dy0 + dilation_h * u0; + int y01 = stride_h * dy1 + dilation_h * u0; + int x10 = stride_w * dx0 + dilation_w * v1; + int x11 = stride_w * dx1 + dilation_w * v1; + int y10 = stride_h * dy0 + dilation_h * u1; + int y11 = stride_h * dy1 + dilation_h * u1; + + const signed char* sptr00 = img0.row(y00) + x00; + const signed char* sptr01 = img0.row(y01) + x01; + const signed char* sptr10 = img1.row(y10) + x10; + const signed char* sptr11 = img1.row(y11) + x11; + + pp[0] = sptr00[0]; + pp[1] = sptr10[0]; + pp[2] = sptr01[0]; + pp[3] = sptr11[0]; + pp += 4; + } + } + for (; kk < max_kk / elempack; kk++) + { + int p = (k / elempack + kk) / maxk; + int uv = (k / elempack + kk) % maxk; + int u = uv / kernel_w; + int v = uv % kernel_w; + + const Mat img = bottom_blob.channel(p); + + int x0 = stride_w * dx0 + dilation_w * v; + int x1 = stride_w * dx1 + dilation_w * v; + int y0 = stride_h * dy0 + dilation_h * u; + int y1 = stride_h * dy1 + dilation_h * u; + + const signed char* sptr0 = img.row(y0) + x0 * elempack; + const signed char* sptr1 = img.row(y1) + x1 * elempack; + +#if __SSE2__ + if (elempack == 8) + { + __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); + __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); + __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); + _mm_storeu_si128((__m128i*)pp, _r01); + pp += 16; + } +#endif // __SSE2__ + if (elempack == 1) + { + pp[0] = sptr0[0]; + pp[1] = sptr1[0]; + pp += 2; + } } } } diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp index f870a8847462..09008985f121 100644 --- a/src/layer/x86/convolution_x86.cpp +++ b/src/layer/x86/convolution_x86.cpp @@ -46,16 +46,13 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" + +#include "convolution_3x3_winograd_int8.h" #endif // NCNN_INT8 #if __SSE2__ #include "convolution_3x3_pack1to4.h" -#if NCNN_INT8 -#include "convolution_3x3_pack8to4_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#endif // NCNN_INT8 - #if __AVX__ #include "convolution_3x3_pack1to8.h" #include "convolution_3x3_pack8to1.h" @@ -1231,32 +1228,14 @@ int Convolution_x86::create_pipeline_int8_x86(const Option& opt) const int maxk = kernel_w * kernel_h; const int num_input = weight_data_size / maxk / num_output; - int elempack = 1; - int out_elempack_int32 = 1; -#if __SSE2__ - if (opt.use_packing_layout) - { - elempack = num_input % 8 == 0 ? 8 : 1; - out_elempack_int32 = num_output % 4 == 0 ? 4 : 1; - } -#endif // __SSE2__ + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input > 8 || num_output > 8); - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __SSE2__ - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); -#endif // __SSE2__ - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd23_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - conv3x3s1_winograd23_transform_kernel_int8_sse(weight_data, weight_winograd23_data, num_input, num_output, opt); - // conv3x3s1_winograd43_transform_kernel_int8_sse(weight_data, weight_winograd43_data, num_input, num_output, opt); + if (opt.use_winograd43_convolution) + conv3x3s1_winograd43_transform_kernel_int8(weight_data, weight_winograd43_data, num_input, num_output, opt); + else + conv3x3s1_winograd23_transform_kernel_int8(weight_data, weight_winograd23_data, num_input, num_output, opt); } else if (opt.use_sgemm_convolution) { @@ -1352,6 +1331,8 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con if (top_blob_int32.empty()) return -100; + bool prefer_winograd = (opt.use_winograd23_convolution || opt.use_winograd43_convolution) && (num_input > 8 || num_output > 8); + int _nT = nT ? nT : opt.num_threads; if (nT != 0 && opt.num_threads != nT) { @@ -1360,22 +1341,12 @@ int Convolution_x86::forward_int8_x86(const Mat& bottom_blob, Mat& top_blob, con NCNN_LOGE("opt.num_threads %d changed, convolution gemm will use load-time value %d", opt.num_threads, nT); } - if (elempack == 8 && out_elempack_int32 == 4 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __SSE2__ - } - else if (elempack == 8 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd43_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) - { -#if __SSE2__ - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); -#endif // __SSE2__ - } - else if (elempack == 1 && out_elempack_int32 == 1 && opt.use_winograd_convolution && opt.use_winograd23_convolution && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1 && num_input >= 16 && num_output >= 16) + if (opt.use_winograd_convolution && prefer_winograd && kernel_w == 3 && kernel_h == 3 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) { - conv3x3s1_winograd23_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, opt); - // conv3x3s1_winograd43_int8_sse(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, opt); + if (opt.use_winograd43_convolution && !weight_winograd43_data.empty()) + conv3x3s1_winograd43_int8(bottom_blob_bordered, top_blob_int32, weight_winograd43_data, _nT, opt); + else + conv3x3s1_winograd23_int8(bottom_blob_bordered, top_blob_int32, weight_winograd23_data, _nT, opt); } else if (opt.use_sgemm_convolution) { diff --git a/src/layer/x86/convolution_x86_avx2.cpp b/src/layer/x86/convolution_x86_avx2.cpp index 38f107ee0865..49cded702137 100644 --- a/src/layer/x86/convolution_x86_avx2.cpp +++ b/src/layer/x86/convolution_x86_avx2.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_transform_kernel_packed_int8_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) @@ -46,24 +45,24 @@ void convolution_im2col_gemm_int8_avx2(const Mat& bottom_blob, Mat& top_blob, co } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_transform_kernel_int8(kernel, AT, inch, outch, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd23_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx2(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd43_transform_kernel_int8_avx2(const Mat& kernel, Mat& AT, int inch, int outch, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd43_transform_kernel_int8(kernel, AT, inch, outch, opt); } -void conv3x3s1_winograd43_pack8to4_int8_sse_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avx2(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_avx512vnni.cpp b/src/layer/x86/convolution_x86_avx512vnni.cpp index f0ac51bbf856..8e34bb61309f 100644 --- a/src/layer/x86/convolution_x86_avx512vnni.cpp +++ b/src/layer/x86/convolution_x86_avx512vnni.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_avx512vnni(const Mat& bottom_blob, Mat& top_bl } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avx512vnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_avx512vnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_avxvnni.cpp b/src/layer/x86/convolution_x86_avxvnni.cpp index a8ef75bb968b..aa1ba401856c 100644 --- a/src/layer/x86/convolution_x86_avxvnni.cpp +++ b/src/layer/x86/convolution_x86_avxvnni.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_avxvnni(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_avxvnni(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/convolution_x86_xop.cpp b/src/layer/x86/convolution_x86_xop.cpp index d954f5545655..cacba8f07cdd 100644 --- a/src/layer/x86/convolution_x86_xop.cpp +++ b/src/layer/x86/convolution_x86_xop.cpp @@ -20,8 +20,7 @@ namespace ncnn { #include "convolution_packed_int8.h" #include "convolution_im2col_gemm_int8.h" -#include "convolution_3x3_pack8to1_int8.h" -#include "convolution_3x3_pack8to4_int8.h" +#include "convolution_3x3_winograd_int8.h" // packed void convolution_packed_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) @@ -36,24 +35,14 @@ void convolution_im2col_gemm_int8_xop(const Mat& bottom_blob, Mat& top_blob, con } // winograd -void conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse_xop(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) +void conv3x3s1_winograd23_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_transform_kernel_pack8to1_int8_sse(kernel, kernel_tm, inch, outch, opt); + conv3x3s1_winograd23_int8(bottom_blob, top_blob, AT, nT, opt); } -void conv3x3s1_winograd43_pack8to1_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) +void conv3x3s1_winograd43_int8_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& AT, int nT, const Option& opt) { - conv3x3s1_winograd43_pack8to1_int8_sse(bottom_blob, top_blob, kernel, opt); -} - -void conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse_xop(const Mat& kernel, Mat& kernel_tm, int inch, int outch, const Option& opt) -{ - conv3x3s1_winograd43_transform_kernel_pack8to4_int8_sse(kernel, kernel_tm, inch, outch, opt); -} - -void conv3x3s1_winograd43_pack8to4_int8_sse_xop(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Option& opt) -{ - conv3x3s1_winograd43_pack8to4_int8_sse(bottom_blob, top_blob, kernel, opt); + conv3x3s1_winograd43_int8(bottom_blob, top_blob, AT, nT, opt); } } // namespace ncnn diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index c9551330ff67..e75e78c0c255 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -42,6 +42,83 @@ static NCNN_FORCEINLINE signed char float2int8(float v) } #if __SSE2__ +static NCNN_FORCEINLINE void transpose4x8_epi32(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) +{ + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3); + __m128i _tmp4 = _mm_unpacklo_epi32(_r4, _r5); + __m128i _tmp5 = _mm_unpackhi_epi32(_r4, _r5); + __m128i _tmp6 = _mm_unpacklo_epi32(_r6, _r7); + __m128i _tmp7 = _mm_unpackhi_epi32(_r6, _r7); + + _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _r1 = _mm_unpacklo_epi64(_tmp4, _tmp6); + _r2 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _r3 = _mm_unpackhi_epi64(_tmp4, _tmp6); + _r4 = _mm_unpacklo_epi64(_tmp1, _tmp3); + _r5 = _mm_unpacklo_epi64(_tmp5, _tmp7); + _r6 = _mm_unpackhi_epi64(_tmp1, _tmp3); + _r7 = _mm_unpackhi_epi64(_tmp5, _tmp7); +} + +static NCNN_FORCEINLINE void transpose4x4_epi32(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi32(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi32(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi32(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi32(_r2, _r3); + + _r0 = _mm_unpacklo_epi64(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi64(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi64(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi64(_tmp1, _tmp3); +} + +static NCNN_FORCEINLINE void transpose8x8_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) +{ + __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); + __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5); + __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5); + __m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7); + __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7); + + __m128i _tmp8 = _mm_unpacklo_epi32(_tmp0, _tmp2); + __m128i _tmp9 = _mm_unpackhi_epi32(_tmp0, _tmp2); + __m128i _tmpa = _mm_unpacklo_epi32(_tmp1, _tmp3); + __m128i _tmpb = _mm_unpackhi_epi32(_tmp1, _tmp3); + __m128i _tmpc = _mm_unpacklo_epi32(_tmp4, _tmp6); + __m128i _tmpd = _mm_unpackhi_epi32(_tmp4, _tmp6); + __m128i _tmpe = _mm_unpacklo_epi32(_tmp5, _tmp7); + __m128i _tmpf = _mm_unpackhi_epi32(_tmp5, _tmp7); + + _r0 = _mm_unpacklo_epi64(_tmp8, _tmpc); + _r1 = _mm_unpackhi_epi64(_tmp8, _tmpc); + _r2 = _mm_unpacklo_epi64(_tmp9, _tmpd); + _r3 = _mm_unpackhi_epi64(_tmp9, _tmpd); + _r4 = _mm_unpacklo_epi64(_tmpa, _tmpe); + _r5 = _mm_unpackhi_epi64(_tmpa, _tmpe); + _r6 = _mm_unpacklo_epi64(_tmpb, _tmpf); + _r7 = _mm_unpackhi_epi64(_tmpb, _tmpf); +} + +static NCNN_FORCEINLINE void transpose8x4_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3) +{ + __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); + + _r0 = _mm_unpacklo_epi32(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi32(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi32(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi32(_tmp1, _tmp3); +} + static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128) { const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); @@ -341,36 +418,6 @@ static NCNN_FORCEINLINE void transpose8x2_ps(__m256& _r0, __m256& _r1) _r1 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); } -static NCNN_FORCEINLINE void transpose8x8_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7) -{ - __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _tmp1 = _mm_unpackhi_epi16(_r0, _r1); - __m128i _tmp2 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _tmp3 = _mm_unpackhi_epi16(_r2, _r3); - __m128i _tmp4 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _tmp5 = _mm_unpackhi_epi16(_r4, _r5); - __m128i _tmp6 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _tmp7 = _mm_unpackhi_epi16(_r6, _r7); - - __m128i _tmp8 = _mm_unpacklo_epi32(_tmp0, _tmp2); - __m128i _tmp9 = _mm_unpackhi_epi32(_tmp0, _tmp2); - __m128i _tmpa = _mm_unpacklo_epi32(_tmp1, _tmp3); - __m128i _tmpb = _mm_unpackhi_epi32(_tmp1, _tmp3); - __m128i _tmpc = _mm_unpacklo_epi32(_tmp4, _tmp6); - __m128i _tmpd = _mm_unpackhi_epi32(_tmp4, _tmp6); - __m128i _tmpe = _mm_unpacklo_epi32(_tmp5, _tmp7); - __m128i _tmpf = _mm_unpackhi_epi32(_tmp5, _tmp7); - - _r0 = _mm_unpacklo_epi64(_tmp8, _tmpc); - _r1 = _mm_unpackhi_epi64(_tmp8, _tmpc); - _r2 = _mm_unpacklo_epi64(_tmp9, _tmpd); - _r3 = _mm_unpackhi_epi64(_tmp9, _tmpd); - _r4 = _mm_unpacklo_epi64(_tmpa, _tmpe); - _r5 = _mm_unpackhi_epi64(_tmpa, _tmpe); - _r6 = _mm_unpacklo_epi64(_tmpb, _tmpf); - _r7 = _mm_unpackhi_epi64(_tmpb, _tmpf); -} - static NCNN_FORCEINLINE __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7) { const __m256 s01 = _mm256_hadd_ps(v0, v1); @@ -598,6 +645,55 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx(const __m256& v0, const __m256& return _v; } +#if __AVX2__ +static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1) +{ + __m256i _tmp0 = _mm256_unpacklo_epi32(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi32(_r0, _r1); + + _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m256i& _r2, __m256i& _r3, __m256i& _r4, __m256i& _r5, __m256i& _r6, __m256i& _r7) +{ + __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); + __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); + __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); + __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3); + __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5); + __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5); + __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7); + __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7); + + __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2); + __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2); + __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3); + __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3); + __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6); + __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6); + __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7); + __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7); + + _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk); + _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk); + _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl); + _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl); + _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm); + _tmp5 = _mm256_unpackhi_epi64(_tmpi, _tmpm); + _tmp6 = _mm256_unpacklo_epi64(_tmpj, _tmpn); + _tmp7 = _mm256_unpackhi_epi64(_tmpj, _tmpn); + + _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); + _r5 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1)); + _r6 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); +} + #if __AVX512F__ static NCNN_FORCEINLINE void transpose16x16_ps(__m512& _r0, __m512& _r1, __m512& _r2, __m512& _r3, __m512& _r4, __m512& _r5, __m512& _r6, __m512& _r7, __m512& _r8, __m512& _r9, __m512& _ra, __m512& _rb, __m512& _rc, __m512& _rd, __m512& _re, __m512& _rf) @@ -928,45 +1024,6 @@ static NCNN_FORCEINLINE void transpose16x16_epi16(__m256i& _r0, __m256i& _r1, __ _rf = _mm256_permute2x128_si256(_tmp7, _tmpf, _MM_SHUFFLE(0, 3, 0, 1)); } -static NCNN_FORCEINLINE void transpose16x8_epi16(__m256i& _r0, __m256i& _r1, __m256i& _r2, __m256i& _r3, __m256i& _r4, __m256i& _r5, __m256i& _r6, __m256i& _r7) -{ - __m256i _tmp0 = _mm256_unpacklo_epi16(_r0, _r1); - __m256i _tmp1 = _mm256_unpackhi_epi16(_r0, _r1); - __m256i _tmp2 = _mm256_unpacklo_epi16(_r2, _r3); - __m256i _tmp3 = _mm256_unpackhi_epi16(_r2, _r3); - __m256i _tmp4 = _mm256_unpacklo_epi16(_r4, _r5); - __m256i _tmp5 = _mm256_unpackhi_epi16(_r4, _r5); - __m256i _tmp6 = _mm256_unpacklo_epi16(_r6, _r7); - __m256i _tmp7 = _mm256_unpackhi_epi16(_r6, _r7); - - __m256i _tmpg = _mm256_unpacklo_epi32(_tmp0, _tmp2); - __m256i _tmph = _mm256_unpackhi_epi32(_tmp0, _tmp2); - __m256i _tmpi = _mm256_unpacklo_epi32(_tmp1, _tmp3); - __m256i _tmpj = _mm256_unpackhi_epi32(_tmp1, _tmp3); - __m256i _tmpk = _mm256_unpacklo_epi32(_tmp4, _tmp6); - __m256i _tmpl = _mm256_unpackhi_epi32(_tmp4, _tmp6); - __m256i _tmpm = _mm256_unpacklo_epi32(_tmp5, _tmp7); - __m256i _tmpn = _mm256_unpackhi_epi32(_tmp5, _tmp7); - - _tmp0 = _mm256_unpacklo_epi64(_tmpg, _tmpk); - _tmp1 = _mm256_unpackhi_epi64(_tmpg, _tmpk); - _tmp2 = _mm256_unpacklo_epi64(_tmph, _tmpl); - _tmp3 = _mm256_unpackhi_epi64(_tmph, _tmpl); - _tmp4 = _mm256_unpacklo_epi64(_tmpi, _tmpm); - _tmp5 = _mm256_unpackhi_epi64(_tmpi, _tmpm); - _tmp6 = _mm256_unpacklo_epi64(_tmpj, _tmpn); - _tmp7 = _mm256_unpackhi_epi64(_tmpj, _tmpn); - - _r0 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 2, 0, 0)); - _r1 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 2, 0, 0)); - _r2 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 2, 0, 0)); - _r3 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); - _r4 = _mm256_permute2x128_si256(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); - _r5 = _mm256_permute2x128_si256(_tmp2, _tmp3, _MM_SHUFFLE(0, 3, 0, 1)); - _r6 = _mm256_permute2x128_si256(_tmp4, _tmp5, _MM_SHUFFLE(0, 3, 0, 1)); - _r7 = _mm256_permute2x128_si256(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); -} - static NCNN_FORCEINLINE void transpose8x16_epi16(__m128i& _r0, __m128i& _r1, __m128i& _r2, __m128i& _r3, __m128i& _r4, __m128i& _r5, __m128i& _r6, __m128i& _r7, __m128i& _r8, __m128i& _r9, __m128i& _ra, __m128i& _rb, __m128i& _rc, __m128i& _rd, __m128i& _re, __m128i& _rf) { __m128i _tmp0 = _mm_unpacklo_epi16(_r0, _r1); @@ -1088,6 +1145,7 @@ static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m5 } #endif // __AVX512F__ +#endif // __AVX2__ #endif // __AVX__ #endif // __SSE2__ diff --git a/tests/test_convolution_3.cpp b/tests/test_convolution_3.cpp index fa358d0670cc..1d0f8f079b62 100644 --- a/tests/test_convolution_3.cpp +++ b/tests/test_convolution_3.cpp @@ -190,6 +190,30 @@ static int test_convolution_int8(int w, int h, int c, int outch, int kernel, int return ret; } + if (kernel == 3 && dilation == 1 && stride == 1) + { + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = true; + opt.use_fp16_packed = false; + opt.use_fp16_storage = false; + opt.use_fp16_arithmetic = false; + opt.use_bf16_storage = false; + opt.use_shader_pack8 = false; + opt.use_image_storage = false; + opt.use_sgemm_convolution = false; + opt.use_winograd_convolution = true; + opt.use_winograd23_convolution = true; + opt.use_winograd43_convolution = false; + + ret = test_layer_opt("Convolution", pd, weights, opt, a, requant ? 1.0f : 0.001f, 0, flag); + if (ret != 0) + { + fprintf(stderr, "test_convolution_int8 failed w=%d h=%d c=%d outch=%d kernel=%d dilation=%d stride=%d pad=%d bias=%d requant=%d act=%d actparams=[%f,%f]\n", w, h, c, outch, kernel, dilation, stride, pad, bias, requant, activation_type, activation_params[0], activation_params[1]); + return ret; + } + } + { ncnn::Option opt; opt.num_threads = 1; @@ -310,6 +334,7 @@ static int test_convolution_1() || test_convolution_int8(4, 20, 16, 24, 3, 1, 1, 1, 0) || test_convolution_int8(6, 7, 64, 64, 3, 1, 2, 0, 1) || test_convolution_int8(25, 33, 16, 15, 3, 1, 1, 1, 0) + || test_convolution_int8(25, 33, 31, 31, 3, 1, 1, 1, 0) || test_convolution_int8(7, 7, 15, 12, 3, 1, 1, 1, 0) || test_convolution_int8(5, 6, 31, 9, 5, 1, 1, 0, 1) || test_convolution_int8(5, 7, 32, 8, 5, 1, 2, 0, 1) From 97ffd1e661e91dbc74ee032ad052fc1f605627f0 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 11 Oct 2023 20:10:00 +0800 Subject: [PATCH 09/16] add labeler (#5078) --- .github/label.yml | 12 ++++++++++++ .github/labeler.yml | 26 ++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 .github/label.yml create mode 100644 .github/labeler.yml diff --git a/.github/label.yml b/.github/label.yml new file mode 100644 index 000000000000..889c41b54c97 --- /dev/null +++ b/.github/label.yml @@ -0,0 +1,12 @@ +name: labeler +on: [pull_request_target] + +permissions: + contents: read + pull-requests: write + +jobs: + label: + runs-on: ubuntu-latest + steps: + - uses: actions/labeler@v4 diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 000000000000..0d98b80d5455 --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,26 @@ +cmake: +- cmake/** +- **/CMakeLists.txt +- **/*.cmake +- toolchains/** + +doc: docs/** + +python: python/** + +example: examples/** + +test: tests/** + +tool: tools/** +pnnx: tools/pnnx/** + +core: src/* +layer: src/layer/* + +arm: src/layer/arm/** +loongarch: src/layer/loongarch/** +mips: src/layer/mips/** +riscv: src/layer/riscv/** +vulkan: src/layer/vulkan/** +x86: src/layer/x86/** From 2f1f2e97700ce1be6e6b4664e87afcd6baa398a6 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 11 Oct 2023 20:14:21 +0800 Subject: [PATCH 10/16] fix labeler --- .github/{label.yml => workflows/labeler.yml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/{label.yml => workflows/labeler.yml} (100%) diff --git a/.github/label.yml b/.github/workflows/labeler.yml similarity index 100% rename from .github/label.yml rename to .github/workflows/labeler.yml From 3f79c4ff103780d6e6def9224fe2dab2406b04e3 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 11 Oct 2023 20:16:17 +0800 Subject: [PATCH 11/16] Update labeler.yml --- .github/labeler.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index 0d98b80d5455..31d03ede341c 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,7 +1,5 @@ cmake: - cmake/** -- **/CMakeLists.txt -- **/*.cmake - toolchains/** doc: docs/** From bcdc276ffedd0a17a083f129cb1f8a7438b6fddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=B0=E9=98=85?= <43716063+Baiyuetribe@users.noreply.github.com> Date: Fri, 20 Oct 2023 01:44:55 -0500 Subject: [PATCH 12/16] add torch.view_as_real and torch.view_as_complex (#5083) --- tools/pnnx/src/CMakeLists.txt | 2 + .../src/pass_level2/torch_view_as_complex.cpp | 40 ++++++++++++ .../src/pass_level2/torch_view_as_real.cpp | 40 ++++++++++++ tools/pnnx/tests/CMakeLists.txt | 2 + .../pnnx/tests/test_torch_view_as_complex.py | 61 +++++++++++++++++++ tools/pnnx/tests/test_torch_view_as_real.py | 61 +++++++++++++++++++ 6 files changed, 206 insertions(+) create mode 100644 tools/pnnx/src/pass_level2/torch_view_as_complex.cpp create mode 100644 tools/pnnx/src/pass_level2/torch_view_as_real.cpp create mode 100644 tools/pnnx/tests/test_torch_view_as_complex.py create mode 100644 tools/pnnx/tests/test_torch_view_as_real.py diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index c15ba0973d67..58264dfd9754 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -270,6 +270,8 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_unbind.cpp pass_level2/torch_unsqueeze.cpp pass_level2/torch_var.cpp + pass_level2/torch_view_as_complex.cpp + pass_level2/torch_view_as_real.cpp pass_level2/torch_zeros.cpp pass_level2/torch_zeros_like.cpp pass_level2/torch_stft.cpp diff --git a/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp b/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp new file mode 100644 index 000000000000..e00ff1371ca4 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_view_as_complex : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::view_as_complex op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.view_as_complex"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_view_as_complex, 20) + +} // namespace pnnx \ No newline at end of file diff --git a/tools/pnnx/src/pass_level2/torch_view_as_real.cpp b/tools/pnnx/src/pass_level2/torch_view_as_real.cpp new file mode 100644 index 000000000000..83327e01ef96 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_view_as_real.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_view_as_real : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::view_as_real op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.view_as_real"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_view_as_real, 20) + +} // namespace pnnx \ No newline at end of file diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index fab12342b58f..346ee0a955fc 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -237,6 +237,8 @@ pnnx_add_test(torch_topk) pnnx_add_test(torch_transpose) pnnx_add_test(torch_unbind) pnnx_add_test(torch_unsqueeze) +pnnx_add_test(torch_view_as_complex) +pnnx_add_test(torch_view_as_real) pnnx_add_test(torch_zeros) pnnx_add_test(torch_zeros_like) diff --git a/tools/pnnx/tests/test_torch_view_as_complex.py b/tools/pnnx/tests/test_torch_view_as_complex.py new file mode 100644 index 000000000000..c2cedc537d05 --- /dev/null +++ b/tools/pnnx/tests/test_torch_view_as_complex.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.view_as_complex(x) + y = torch.view_as_complex(y) + z = torch.view_as_complex(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 2) + y = torch.rand(1, 5, 9, 2) + z = torch.rand(14, 8, 5, 9, 2) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_view_as_complex.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_view_as_complex.pt inputshape=[1,3,2],[1,5,9,2],[14,8,5,9,2]") + + # pnnx inference + import test_torch_view_as_complex_pnnx + b = test_torch_view_as_complex_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) \ No newline at end of file diff --git a/tools/pnnx/tests/test_torch_view_as_real.py b/tools/pnnx/tests/test_torch_view_as_real.py new file mode 100644 index 000000000000..06bbe7de9b10 --- /dev/null +++ b/tools/pnnx/tests/test_torch_view_as_real.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.view_as_real(x) + y = torch.view_as_real(y) + z = torch.view_as_real(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16,dtype=torch.complex64) + y = torch.rand(1, 5, 9, 11,dtype=torch.complex64) + z = torch.rand(14, 8, 5, 9, 10,dtype=torch.complex64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_view_as_real.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_view_as_real.pt inputshape=[1,3,16]c64,[1,5,9,11]c64,[14,8,5,9,10]c64") + + # pnnx inference + import test_torch_view_as_real_pnnx + b = test_torch_view_as_real_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) \ No newline at end of file From b82d3957533c8c9cf9aa0aa68b1e03ba0f48ec2d Mon Sep 17 00:00:00 2001 From: Xinyu302 Date: Fri, 20 Oct 2023 15:43:37 +0800 Subject: [PATCH 13/16] Add riscv float32 gemm (#4903) Co-authored-by: Xinyu302 --- src/layer/riscv/gemm_riscv.cpp | 4282 +++++++++++++++++++++++++++++ src/layer/riscv/gemm_riscv.h | 43 + src/layer/riscv/riscv_usability.h | 276 ++ 3 files changed, 4601 insertions(+) create mode 100644 src/layer/riscv/gemm_riscv.cpp create mode 100644 src/layer/riscv/gemm_riscv.h diff --git a/src/layer/riscv/gemm_riscv.cpp b/src/layer/riscv/gemm_riscv.cpp new file mode 100644 index 000000000000..ec5a5cdac413 --- /dev/null +++ b/src/layer/riscv/gemm_riscv.cpp @@ -0,0 +1,4282 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "gemm_riscv.h" + +#if __riscv_vector +#include +#endif // __riscv_vector + +#include "riscv_usability.h" + +#include "cpu.h" + +namespace ncnn { + +Gemm_riscv::Gemm_riscv() +{ +#if __riscv_vector + support_packing = true; +#endif // __riscv_vector + one_blob_only = false; + support_inplace = false; + + nT = 0; +#if __riscv_vector + // When processing float data, + // even if the current hardware provides vector registers of more than 128 bits, + // vl=4 is still used, even though this will waste the width of the vector register. + vl = vsetvlmax_e32m1(); + vl = vl >= 4 ? 4 : vl; +#else + vl = 0; +#endif // __riscv_vector +} + +static void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + float* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 4; + const float* p1 = (const float*)A + (i + ii + 4) * A_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p1, vl), vl); + pp += 8; + p0 += 4; + p1 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + const float* p4 = (const float*)A + (i + ii + 4) * A_hstep + k; + const float* p5 = (const float*)A + (i + ii + 5) * A_hstep + k; + const float* p6 = (const float*)A + (i + ii + 6) * A_hstep + k; + const float* p7 = (const float*)A + (i + ii + 7) * A_hstep + k; + + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + vfloat32m1_t _r0l = vle32_v_f32m1(p0, vl); + vfloat32m1_t _r0h = vle32_v_f32m1(p0 + 4, vl); + vfloat32m1_t _r1l = vle32_v_f32m1(p1, vl); + vfloat32m1_t _r1h = vle32_v_f32m1(p1 + 4, vl); + vfloat32m1_t _r2l = vle32_v_f32m1(p2, vl); + vfloat32m1_t _r2h = vle32_v_f32m1(p2 + 4, vl); + vfloat32m1_t _r3l = vle32_v_f32m1(p3, vl); + vfloat32m1_t _r3h = vle32_v_f32m1(p3 + 4, vl); + vfloat32m1_t _r4l = vle32_v_f32m1(p4, vl); + vfloat32m1_t _r4h = vle32_v_f32m1(p4 + 4, vl); + vfloat32m1_t _r5l = vle32_v_f32m1(p5, vl); + vfloat32m1_t _r5h = vle32_v_f32m1(p5 + 4, vl); + vfloat32m1_t _r6l = vle32_v_f32m1(p6, vl); + vfloat32m1_t _r6h = vle32_v_f32m1(p6 + 4, vl); + vfloat32m1_t _r7l = vle32_v_f32m1(p7, vl); + vfloat32m1_t _r7h = vle32_v_f32m1(p7 + 4, vl); + transpose8x8_ps(_r0l, _r0h, _r1l, _r1h, _r2l, _r2h, _r3l, _r3h, _r4l, _r4h, _r5l, _r5h, _r6l, _r6h, _r7l, _r7h, vl); + vse32_v_f32m1(pp, _r0l, vl); + vse32_v_f32m1(pp + 4, _r0h, vl); + vse32_v_f32m1(pp + 8, _r1l, vl); + vse32_v_f32m1(pp + 12, _r1h, vl); + vse32_v_f32m1(pp + 8 * 2, _r2l, vl); + vse32_v_f32m1(pp + 8 * 2 + 4, _r2h, vl); + vse32_v_f32m1(pp + 8 * 3, _r3l, vl); + vse32_v_f32m1(pp + 8 * 3 + 4, _r3h, vl); + vse32_v_f32m1(pp + 8 * 4, _r4l, vl); + vse32_v_f32m1(pp + 8 * 4 + 4, _r4h, vl); + vse32_v_f32m1(pp + 8 * 5, _r5l, vl); + vse32_v_f32m1(pp + 8 * 5 + 4, _r5h, vl); + vse32_v_f32m1(pp + 8 * 6, _r6l, vl); + vse32_v_f32m1(pp + 8 * 6 + 4, _r6h, vl); + vse32_v_f32m1(pp + 8 * 7, _r7l, vl); + vse32_v_f32m1(pp + 8 * 7 + 4, _r7h, vl); + pp += 64; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + p4 += 8; + p5 += 8; + p6 += 8; + p7 += 8; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + const float* p2 = (const float*)A + (i + ii + 2) * A_hstep + k; + const float* p3 = (const float*)A + (i + ii + 3) * A_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t v2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t v3 = vle32_v_f32m1(p3, vl); + store_float_v4(v0, v1, v2, v3, pp, vl); + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { + // if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + const float* p1 = (const float*)A + (i + ii + 1) * A_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += 4; + p1 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + } + for (; ii < max_ii; ii += 1) + { + // if (elempack == 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, size_t vl) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + float* pp = AT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vfloat32m1_t _r4; + vfloat32m1_t _r5; + vfloat32m1_t _r6; + vfloat32m1_t _r7; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vlseg4e32_v_f32m1(&_r4, &_r5, &_r6, &_r7, p0 + 16, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r1, vl); + vse32_v_f32m1(pp + 4 * 3, _r5, vl); + vse32_v_f32m1(pp + 4 * 4, _r2, vl); + vse32_v_f32m1(pp + 4 * 5, _r6, vl); + vse32_v_f32m1(pp + 4 * 6, _r3, vl); + vse32_v_f32m1(pp + 4 * 7, _r7, vl); + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p0 + 4, vl), vl); + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r1, vl); + vse32_v_f32m1(pp + 4 * 2, _r2, vl); + vse32_v_f32m1(pp + 4 * 3, _r3, vl); + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p0 + 4, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, size_t vl) +{ + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + float* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; + const float* p1 = (const float*)B + (j + jj + 4) * B_hstep + k * 4; + const float* p2 = (const float*)B + (j + jj + 8) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p1, vl), vl); + vse32_v_f32m1(pp + 8, vle32_v_f32m1(p2, vl), vl); + pp += 12; + p0 += 4; + p1 += 4; + p2 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; + const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; + const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; + const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; + const float* p8 = (const float*)B + (j + jj + 8) * B_hstep + k; + const float* p9 = (const float*)B + (j + jj + 9) * B_hstep + k; + const float* pa = (const float*)B + (j + jj + 10) * B_hstep + k; + const float* pb = (const float*)B + (j + jj + 11) * B_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t _r1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t _r2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t _r3 = vle32_v_f32m1(p3, vl); + vfloat32m1_t _r4 = vle32_v_f32m1(p4, vl); + vfloat32m1_t _r5 = vle32_v_f32m1(p5, vl); + vfloat32m1_t _r6 = vle32_v_f32m1(p6, vl); + vfloat32m1_t _r7 = vle32_v_f32m1(p7, vl); + vfloat32m1_t _r8 = vle32_v_f32m1(p8, vl); + vfloat32m1_t _r9 = vle32_v_f32m1(p9, vl); + vfloat32m1_t _ra = vle32_v_f32m1(pa, vl); + vfloat32m1_t _rb = vle32_v_f32m1(pb, vl); + + transpose4x4_ps(_r0, _r1, _r2, _r3, vl); + transpose4x4_ps(_r4, _r5, _r6, _r7, vl); + transpose4x4_ps(_r8, _r9, _ra, _rb, vl); + + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r8, vl); + vse32_v_f32m1(pp + 4 * 3, _r1, vl); + vse32_v_f32m1(pp + 4 * 4, _r5, vl); + vse32_v_f32m1(pp + 4 * 5, _r9, vl); + vse32_v_f32m1(pp + 4 * 6, _r2, vl); + vse32_v_f32m1(pp + 4 * 7, _r6, vl); + vse32_v_f32m1(pp + 4 * 8, _ra, vl); + vse32_v_f32m1(pp + 4 * 9, _r3, vl); + vse32_v_f32m1(pp + 4 * 10, _r7, vl); + vse32_v_f32m1(pp + 4 * 11, _rb, vl); + pp += 48; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + p8 += 4; + p9 += 4; + pa += 4; + pb += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp[8] = p8[0]; + pp[9] = p9[0]; + pp[10] = pa[0]; + pp[11] = pb[0]; + pp += 12; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + p8++; + p9++; + pa++; + pb++; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; + const float* p1 = (const float*)B + (j + jj + 4) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p1, vl), vl); + pp += 8; + p0 += 4; + p1 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + const float* p4 = (const float*)B + (j + jj + 4) * B_hstep + k; + const float* p5 = (const float*)B + (j + jj + 5) * B_hstep + k; + const float* p6 = (const float*)B + (j + jj + 6) * B_hstep + k; + const float* p7 = (const float*)B + (j + jj + 7) * B_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t _r1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t _r2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t _r3 = vle32_v_f32m1(p3, vl); + vfloat32m1_t _r4 = vle32_v_f32m1(p4, vl); + vfloat32m1_t _r5 = vle32_v_f32m1(p5, vl); + vfloat32m1_t _r6 = vle32_v_f32m1(p6, vl); + vfloat32m1_t _r7 = vle32_v_f32m1(p7, vl); + + transpose4x4_ps(_r0, _r1, _r2, _r3, vl); + transpose4x4_ps(_r4, _r5, _r6, _r7, vl); + + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r1, vl); + vse32_v_f32m1(pp + 4 * 3, _r5, vl); + vse32_v_f32m1(pp + 4 * 4, _r2, vl); + vse32_v_f32m1(pp + 4 * 5, _r6, vl); + vse32_v_f32m1(pp + 4 * 6, _r3, vl); + vse32_v_f32m1(pp + 4 * 7, _r7, vl); + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * 4; + + for (int kk = 0; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + const float* p2 = (const float*)B + (j + jj + 2) * B_hstep + k; + const float* p3 = (const float*)B + (j + jj + 3) * B_hstep + k; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + vfloat32m1_t v2 = vle32_v_f32m1(p2, vl); + vfloat32m1_t v3 = vle32_v_f32m1(p3, vl); + store_float_v4(v0, v1, v2, v3, pp, vl); + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + // if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + const float* p1 = (const float*)B + (j + jj + 1) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p1, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += 4; + p1 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + } + for (; jj < max_jj; jj += 1) + { + // if (elempack == 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + int kk = 0; +#if __riscv_vector + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += 4; + } +#endif // __riscv_vector + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_B_tile(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, size_t vl) +{ + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + float* pp = BT; + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vfloat32m1_t _r4; + vfloat32m1_t _r5; + vfloat32m1_t _r6; + vfloat32m1_t _r7; + vfloat32m1_t _r8; + vfloat32m1_t _r9; + vfloat32m1_t _ra; + vfloat32m1_t _rb; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vlseg4e32_v_f32m1(&_r4, &_r5, &_r6, &_r7, p0 + 16, vl); + vlseg4e32_v_f32m1(&_r8, &_r9, &_ra, &_rb, p0 + 32, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r8, vl); + vse32_v_f32m1(pp + 4 * 3, _r1, vl); + vse32_v_f32m1(pp + 4 * 4, _r5, vl); + vse32_v_f32m1(pp + 4 * 5, _r9, vl); + vse32_v_f32m1(pp + 4 * 6, _r2, vl); + vse32_v_f32m1(pp + 4 * 7, _r6, vl); + vse32_v_f32m1(pp + 4 * 8, _ra, vl); + vse32_v_f32m1(pp + 4 * 9, _r3, vl); + vse32_v_f32m1(pp + 4 * 10, _r7, vl); + vse32_v_f32m1(pp + 4 * 11, _rb, vl); + pp += 48; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p0 + 4, vl), vl); + vse32_v_f32m1(pp + 8, vle32_v_f32m1(p0 + 8, vl), vl); + pp += 12; + p0 += B_hstep; + } + } + } + for (; jj + 7 < max_jj; jj += 8) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vfloat32m1_t _r4; + vfloat32m1_t _r5; + vfloat32m1_t _r6; + vfloat32m1_t _r7; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vlseg4e32_v_f32m1(&_r4, &_r5, &_r6, &_r7, p0 + 16, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r4, vl); + vse32_v_f32m1(pp + 4 * 2, _r1, vl); + vse32_v_f32m1(pp + 4 * 3, _r5, vl); + vse32_v_f32m1(pp + 4 * 4, _r2, vl); + vse32_v_f32m1(pp + 4 * 5, _r6, vl); + vse32_v_f32m1(pp + 4 * 6, _r3, vl); + vse32_v_f32m1(pp + 4 * 7, _r7, vl); + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + vse32_v_f32m1(pp + 4, vle32_v_f32m1(p0 + 4, vl), vl); + pp += 8; + p0 += B_hstep; + } + } + } + for (; jj + 3 < max_jj; jj += 4) + { + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _r0; + vfloat32m1_t _r1; + vfloat32m1_t _r2; + vfloat32m1_t _r3; + vlseg4e32_v_f32m1(&_r0, &_r1, &_r2, &_r3, p0, vl); + vse32_v_f32m1(pp, _r0, vl); + vse32_v_f32m1(pp + 4, _r1, vl); + vse32_v_f32m1(pp + 4 * 2, _r2, vl); + vse32_v_f32m1(pp + 4 * 3, _r3, vl); + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(p0, vl); + vfloat32m1_t v1 = vle32_v_f32m1(p0 + 4, vl); + store_float_v2(v0, v1, pp, vl); + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { +#if __riscv_vector + if (elempack == 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * 4; + + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vse32_v_f32m1(pp, vle32_v_f32m1(p0, vl), vl); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __riscv_vector + if (elempack == 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj); + + int kk = 0; + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void transpose_unpack_output_tile(const Mat& topT, Mat& top_blob, int i, int max_ii, int j, int max_jj, size_t vl) +{ + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const float* pp = topT; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(pp, vl); + vfloat32m1_t v1 = vle32_v_f32m1(pp + 8, vl); + vfloat32m1_t v2 = vle32_v_f32m1(pp + 16, vl); + vfloat32m1_t v3 = vle32_v_f32m1(pp + 24, vl); + store_float_v4(v0, v1, v2, v3, p0, vl); + v0 = vle32_v_f32m1(pp + 4, vl); + v1 = vle32_v_f32m1(pp + 12, vl); + v2 = vle32_v_f32m1(pp + 20, vl); + v3 = vle32_v_f32m1(pp + 28, vl); + store_float_v4(v0, v1, v2, v3, p0 + 16, vl); + pp += 32; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vfloat32m1_t _r0 = vle32_v_f32m1(pp, vl); + vfloat32m1_t _r1 = vle32_v_f32m1(pp + 4, vl); + vse32_v_f32m1(p0, _r0, vl); + vse32_v_f32m1(p0 + 4, _r1, vl); + pp += 8; + p0 += out_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t v0 = vle32_v_f32m1(pp, vl); + vfloat32m1_t v1 = vle32_v_f32m1(pp + 4, vl); + vfloat32m1_t v2 = vle32_v_f32m1(pp + 8, vl); + vfloat32m1_t v3 = vle32_v_f32m1(pp + 12, vl); + store_float_v4(v0, v1, v2, v3, p0, vl); + pp += 16; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + vfloat32m1_t _r0 = vle32_v_f32m1(pp, vl); + vse32_v_f32m1(p0, _r0, vl); + pp += 4; + p0 += out_hstep; + } + } + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { +#if __riscv_vector + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + p0[0] = pp[0]; + p0[1] = pp[2]; + p0[2] = pp[4]; + p0[3] = pp[6]; + p0[4] = pp[1]; + p0[5] = pp[3]; + p0[6] = pp[5]; + p0[7] = pp[7]; + pp += 8; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + p0[1] = pp[1]; + pp += 2; + p0 += out_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { +#if __riscv_vector + if (out_elempack == 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * 4; + + for (int jj = 0; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _r0 = vle32_v_f32m1(pp, vl); + vse32_v_f32m1(p0, _r0, vl); + pp += 4; + p0 += out_hstep * 4; + } + } +#endif // __riscv_vector + if (out_elempack == 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii); + + for (int jj = 0; jj < max_jj; jj += 1) + { + p0[0] = pp[0]; + pp += 1; + p0 += out_hstep; + } + } + } +} + +static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end, size_t vl) +{ + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const float* pAT = AT_tile; + const float* pBT = BT_tile; + const float* pC = CT_tile; + + float* outptr = topT_tile; + + int ii = 0; +#if __riscv_vector + for (; ii + 7 < max_ii; ii += 8) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum20; + vfloat32m1_t _sum21; + vfloat32m1_t _sum30; + vfloat32m1_t _sum31; + vfloat32m1_t _sum40; + vfloat32m1_t _sum41; + vfloat32m1_t _sum50; + vfloat32m1_t _sum51; + vfloat32m1_t _sum60; + vfloat32m1_t _sum61; + vfloat32m1_t _sum70; + vfloat32m1_t _sum71; + vfloat32m1_t _sum80; + vfloat32m1_t _sum81; + vfloat32m1_t _sum90; + vfloat32m1_t _sum91; + vfloat32m1_t _suma0; + vfloat32m1_t _suma1; + vfloat32m1_t _sumb0; + vfloat32m1_t _sumb1; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum20 = vfmv_v_f_f32m1(0.f, vl); + _sum21 = vfmv_v_f_f32m1(0.f, vl); + _sum30 = vfmv_v_f_f32m1(0.f, vl); + _sum31 = vfmv_v_f_f32m1(0.f, vl); + _sum40 = vfmv_v_f_f32m1(0.f, vl); + _sum41 = vfmv_v_f_f32m1(0.f, vl); + _sum50 = vfmv_v_f_f32m1(0.f, vl); + _sum51 = vfmv_v_f_f32m1(0.f, vl); + _sum60 = vfmv_v_f_f32m1(0.f, vl); + _sum61 = vfmv_v_f_f32m1(0.f, vl); + _sum70 = vfmv_v_f_f32m1(0.f, vl); + _sum71 = vfmv_v_f_f32m1(0.f, vl); + _sum80 = vfmv_v_f_f32m1(0.f, vl); + _sum81 = vfmv_v_f_f32m1(0.f, vl); + _sum90 = vfmv_v_f_f32m1(0.f, vl); + _sum91 = vfmv_v_f_f32m1(0.f, vl); + _suma0 = vfmv_v_f_f32m1(0.f, vl); + _suma1 = vfmv_v_f_f32m1(0.f, vl); + _sumb0 = vfmv_v_f_f32m1(0.f, vl); + _sumb1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum20 = vfmv_v_f_f32m1(pC[0], vl); + _sum21 = vfmv_v_f_f32m1(pC[0], vl); + _sum30 = vfmv_v_f_f32m1(pC[0], vl); + _sum31 = vfmv_v_f_f32m1(pC[0], vl); + _sum40 = vfmv_v_f_f32m1(pC[0], vl); + _sum41 = vfmv_v_f_f32m1(pC[0], vl); + _sum50 = vfmv_v_f_f32m1(pC[0], vl); + _sum51 = vfmv_v_f_f32m1(pC[0], vl); + _sum60 = vfmv_v_f_f32m1(pC[0], vl); + _sum61 = vfmv_v_f_f32m1(pC[0], vl); + _sum70 = vfmv_v_f_f32m1(pC[0], vl); + _sum71 = vfmv_v_f_f32m1(pC[0], vl); + _sum80 = vfmv_v_f_f32m1(pC[0], vl); + _sum81 = vfmv_v_f_f32m1(pC[0], vl); + _sum90 = vfmv_v_f_f32m1(pC[0], vl); + _sum91 = vfmv_v_f_f32m1(pC[0], vl); + _suma0 = vfmv_v_f_f32m1(pC[0], vl); + _suma1 = vfmv_v_f_f32m1(pC[0], vl); + _sumb0 = vfmv_v_f_f32m1(pC[0], vl); + _sumb1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum20 = _sum00; + _sum21 = _sum01; + _sum30 = _sum00; + _sum31 = _sum01; + _sum40 = _sum00; + _sum41 = _sum01; + _sum50 = _sum00; + _sum51 = _sum01; + _sum60 = _sum00; + _sum61 = _sum01; + _sum70 = _sum00; + _sum71 = _sum01; + _sum80 = _sum00; + _sum81 = _sum01; + _sum90 = _sum00; + _sum91 = _sum01; + _suma0 = _sum00; + _suma1 = _sum01; + _sumb0 = _sum00; + _sumb1 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + _sum20 = vle32_v_f32m1(pC + 4 * 4, vl); + _sum21 = vle32_v_f32m1(pC + 4 * 5, vl); + _sum30 = vle32_v_f32m1(pC + 4 * 6, vl); + _sum31 = vle32_v_f32m1(pC + 4 * 7, vl); + _sum40 = vle32_v_f32m1(pC + 4 * 8, vl); + _sum41 = vle32_v_f32m1(pC + 4 * 9, vl); + _sum50 = vle32_v_f32m1(pC + 4 * 10, vl); + _sum51 = vle32_v_f32m1(pC + 4 * 11, vl); + _sum60 = vle32_v_f32m1(pC + 4 * 12, vl); + _sum61 = vle32_v_f32m1(pC + 4 * 13, vl); + _sum70 = vle32_v_f32m1(pC + 4 * 14, vl); + _sum71 = vle32_v_f32m1(pC + 4 * 15, vl); + _sum80 = vle32_v_f32m1(pC + 4 * 16, vl); + _sum81 = vle32_v_f32m1(pC + 4 * 17, vl); + _sum90 = vle32_v_f32m1(pC + 4 * 18, vl); + _sum91 = vle32_v_f32m1(pC + 4 * 19, vl); + _suma0 = vle32_v_f32m1(pC + 4 * 20, vl); + _suma1 = vle32_v_f32m1(pC + 4 * 21, vl); + _sumb0 = vle32_v_f32m1(pC + 4 * 22, vl); + _sumb1 = vle32_v_f32m1(pC + 4 * 23, vl); + pC += 96; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum20 = vfmv_v_f_f32m1(pC[2], vl); + _sum30 = vfmv_v_f_f32m1(pC[3], vl); + _sum40 = vfmv_v_f_f32m1(pC[4], vl); + _sum50 = vfmv_v_f_f32m1(pC[5], vl); + _sum60 = vfmv_v_f_f32m1(pC[6], vl); + _sum70 = vfmv_v_f_f32m1(pC[7], vl); + _sum80 = vfmv_v_f_f32m1(pC[8], vl); + _sum90 = vfmv_v_f_f32m1(pC[9], vl); + _suma0 = vfmv_v_f_f32m1(pC[10], vl); + _sumb0 = vfmv_v_f_f32m1(pC[11], vl); + _sum01 = _sum00; + _sum11 = _sum10; + _sum21 = _sum20; + _sum31 = _sum30; + _sum41 = _sum40; + _sum51 = _sum50; + _sum61 = _sum60; + _sum71 = _sum70; + _sum81 = _sum80; + _sum91 = _sum90; + _suma1 = _suma0; + _sumb1 = _sumb0; + pC += 12; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum20 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum21 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum30 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum31 = vle32_v_f32m1(outptr + 4 * 7, vl); + _sum40 = vle32_v_f32m1(outptr + 4 * 8, vl); + _sum41 = vle32_v_f32m1(outptr + 4 * 9, vl); + _sum50 = vle32_v_f32m1(outptr + 4 * 10, vl); + _sum51 = vle32_v_f32m1(outptr + 4 * 11, vl); + _sum60 = vle32_v_f32m1(outptr + 4 * 12, vl); + _sum61 = vle32_v_f32m1(outptr + 4 * 13, vl); + _sum70 = vle32_v_f32m1(outptr + 4 * 14, vl); + _sum71 = vle32_v_f32m1(outptr + 4 * 15, vl); + _sum80 = vle32_v_f32m1(outptr + 4 * 16, vl); + _sum81 = vle32_v_f32m1(outptr + 4 * 17, vl); + _sum90 = vle32_v_f32m1(outptr + 4 * 18, vl); + _sum91 = vle32_v_f32m1(outptr + 4 * 19, vl); + _suma0 = vle32_v_f32m1(outptr + 4 * 20, vl); + _suma1 = vle32_v_f32m1(outptr + 4 * 21, vl); + _sumb0 = vle32_v_f32m1(outptr + 4 * 22, vl); + _sumb1 = vle32_v_f32m1(outptr + 4 * 23, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk + 3 < max_kk; kk += 4) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + + _pA0 = vle32_v_f32m1(pA, vl); + _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + + _pA0 = vle32_v_f32m1(pA, vl); + _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + + _pA0 = vle32_v_f32m1(pA, vl); + _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + } + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + _sum80 = vfmadd_vf_f32m1(_pA0, pB[8], _sum80, vl); + _sum81 = vfmadd_vf_f32m1(_pA1, pB[8], _sum81, vl); + _sum90 = vfmadd_vf_f32m1(_pA0, pB[9], _sum90, vl); + _sum91 = vfmadd_vf_f32m1(_pA1, pB[9], _sum91, vl); + _suma0 = vfmadd_vf_f32m1(_pA0, pB[10], _suma0, vl); + _suma1 = vfmadd_vf_f32m1(_pA1, pB[10], _suma1, vl); + _sumb0 = vfmadd_vf_f32m1(_pA0, pB[11], _sumb0, vl); + _sumb1 = vfmadd_vf_f32m1(_pA1, pB[11], _sumb1, vl); + + pA += 8; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum40, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum50, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum60, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum70, vl); + vse32_v_f32m1(outptr0 + 4 * 8, _sum80, vl); + vse32_v_f32m1(outptr0 + 4 * 9, _sum90, vl); + vse32_v_f32m1(outptr0 + 4 * 10, _suma0, vl); + vse32_v_f32m1(outptr0 + 4 * 11, _sumb0, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 4, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 5, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 6, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 7, _sum71, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 8, _sum81, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 9, _sum91, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 10, _suma1, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 11, _sumb1, vl); + + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose8x12_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, _sum40, _sum41, _sum50, _sum51, _sum60, _sum61, _sum70, _sum71, _sum80, _sum81, _sum90, _sum91, _suma0, _suma1, _sumb0, _sumb1, vl); + + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + 8, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum20, vl); + vse32_v_f32m1(outptr0 + out_hstep + 8, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum30, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 8, _sum40, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum50, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 8, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum60, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 8, _sum70, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5, _sum71, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5 + 4, _sum80, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5 + 8, _sum81, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6, _sum90, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6 + 4, _sum91, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6 + 8, _suma0, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7, _suma1, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7 + 4, _sumb0, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7 + 8, _sumb1, vl); + + outptr0 += 12; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); + vse32_v_f32m1(outptr + 4 * 8, _sum40, vl); + vse32_v_f32m1(outptr + 4 * 9, _sum41, vl); + vse32_v_f32m1(outptr + 4 * 10, _sum50, vl); + vse32_v_f32m1(outptr + 4 * 11, _sum51, vl); + vse32_v_f32m1(outptr + 4 * 12, _sum60, vl); + vse32_v_f32m1(outptr + 4 * 13, _sum61, vl); + vse32_v_f32m1(outptr + 4 * 14, _sum70, vl); + vse32_v_f32m1(outptr + 4 * 15, _sum71, vl); + vse32_v_f32m1(outptr + 4 * 16, _sum80, vl); + vse32_v_f32m1(outptr + 4 * 17, _sum81, vl); + vse32_v_f32m1(outptr + 4 * 18, _sum90, vl); + vse32_v_f32m1(outptr + 4 * 19, _sum91, vl); + vse32_v_f32m1(outptr + 4 * 20, _suma0, vl); + vse32_v_f32m1(outptr + 4 * 21, _suma1, vl); + vse32_v_f32m1(outptr + 4 * 22, _sumb0, vl); + vse32_v_f32m1(outptr + 4 * 23, _sumb1, vl); + } + + outptr += 96; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum20; + vfloat32m1_t _sum21; + vfloat32m1_t _sum30; + vfloat32m1_t _sum31; + vfloat32m1_t _sum40; + vfloat32m1_t _sum41; + vfloat32m1_t _sum50; + vfloat32m1_t _sum51; + vfloat32m1_t _sum60; + vfloat32m1_t _sum61; + vfloat32m1_t _sum70; + vfloat32m1_t _sum71; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum20 = vfmv_v_f_f32m1(0.f, vl); + _sum21 = vfmv_v_f_f32m1(0.f, vl); + _sum30 = vfmv_v_f_f32m1(0.f, vl); + _sum31 = vfmv_v_f_f32m1(0.f, vl); + _sum40 = vfmv_v_f_f32m1(0.f, vl); + _sum41 = vfmv_v_f_f32m1(0.f, vl); + _sum50 = vfmv_v_f_f32m1(0.f, vl); + _sum51 = vfmv_v_f_f32m1(0.f, vl); + _sum60 = vfmv_v_f_f32m1(0.f, vl); + _sum61 = vfmv_v_f_f32m1(0.f, vl); + _sum70 = vfmv_v_f_f32m1(0.f, vl); + _sum71 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum20 = vfmv_v_f_f32m1(pC[0], vl); + _sum21 = vfmv_v_f_f32m1(pC[0], vl); + _sum30 = vfmv_v_f_f32m1(pC[0], vl); + _sum31 = vfmv_v_f_f32m1(pC[0], vl); + _sum40 = vfmv_v_f_f32m1(pC[0], vl); + _sum41 = vfmv_v_f_f32m1(pC[0], vl); + _sum50 = vfmv_v_f_f32m1(pC[0], vl); + _sum51 = vfmv_v_f_f32m1(pC[0], vl); + _sum60 = vfmv_v_f_f32m1(pC[0], vl); + _sum61 = vfmv_v_f_f32m1(pC[0], vl); + _sum70 = vfmv_v_f_f32m1(pC[0], vl); + _sum71 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum20 = _sum00; + _sum21 = _sum01; + _sum30 = _sum00; + _sum31 = _sum01; + _sum40 = _sum00; + _sum41 = _sum01; + _sum50 = _sum00; + _sum51 = _sum01; + _sum60 = _sum00; + _sum61 = _sum01; + _sum70 = _sum00; + _sum71 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + _sum20 = vle32_v_f32m1(pC + 4 * 4, vl); + _sum21 = vle32_v_f32m1(pC + 4 * 5, vl); + _sum30 = vle32_v_f32m1(pC + 4 * 6, vl); + _sum31 = vle32_v_f32m1(pC + 4 * 7, vl); + _sum40 = vle32_v_f32m1(pC + 4 * 8, vl); + _sum41 = vle32_v_f32m1(pC + 4 * 9, vl); + _sum50 = vle32_v_f32m1(pC + 4 * 10, vl); + _sum51 = vle32_v_f32m1(pC + 4 * 11, vl); + _sum60 = vle32_v_f32m1(pC + 4 * 12, vl); + _sum61 = vle32_v_f32m1(pC + 4 * 13, vl); + _sum70 = vle32_v_f32m1(pC + 4 * 14, vl); + _sum71 = vle32_v_f32m1(pC + 4 * 15, vl); + pC += 64; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum20 = vfmv_v_f_f32m1(pC[2], vl); + _sum30 = vfmv_v_f_f32m1(pC[3], vl); + _sum40 = vfmv_v_f_f32m1(pC[4], vl); + _sum50 = vfmv_v_f_f32m1(pC[5], vl); + _sum60 = vfmv_v_f_f32m1(pC[6], vl); + _sum70 = vfmv_v_f_f32m1(pC[7], vl); + _sum01 = _sum00; + _sum11 = _sum10; + _sum21 = _sum20; + _sum31 = _sum30; + _sum41 = _sum40; + _sum51 = _sum50; + _sum61 = _sum60; + _sum71 = _sum70; + pC += 8; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum20 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum21 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum30 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum31 = vle32_v_f32m1(outptr + 4 * 7, vl); + _sum40 = vle32_v_f32m1(outptr + 4 * 8, vl); + _sum41 = vle32_v_f32m1(outptr + 4 * 9, vl); + _sum50 = vle32_v_f32m1(outptr + 4 * 10, vl); + _sum51 = vle32_v_f32m1(outptr + 4 * 11, vl); + _sum60 = vle32_v_f32m1(outptr + 4 * 12, vl); + _sum61 = vle32_v_f32m1(outptr + 4 * 13, vl); + _sum70 = vle32_v_f32m1(outptr + 4 * 14, vl); + _sum71 = vle32_v_f32m1(outptr + 4 * 15, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + _sum40 = vfmadd_vf_f32m1(_pA0, pB[4], _sum40, vl); + _sum41 = vfmadd_vf_f32m1(_pA1, pB[4], _sum41, vl); + _sum50 = vfmadd_vf_f32m1(_pA0, pB[5], _sum50, vl); + _sum51 = vfmadd_vf_f32m1(_pA1, pB[5], _sum51, vl); + _sum60 = vfmadd_vf_f32m1(_pA0, pB[6], _sum60, vl); + _sum61 = vfmadd_vf_f32m1(_pA1, pB[6], _sum61, vl); + _sum70 = vfmadd_vf_f32m1(_pA0, pB[7], _sum70, vl); + _sum71 = vfmadd_vf_f32m1(_pA1, pB[7], _sum71, vl); + + pA += 8; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum40, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum50, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum60, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum70, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 4, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 5, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 6, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 7, _sum71, vl); + + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, _sum40, _sum41, _sum50, _sum51, _sum60, _sum61, _sum70, _sum71, vl); + + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum30, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum31, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum40, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum41, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5, _sum50, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5 + 4, _sum51, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6, _sum60, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6 + 4, _sum61, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7, _sum70, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7 + 4, _sum71, vl); + + outptr0 += 8; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); + vse32_v_f32m1(outptr + 4 * 8, _sum40, vl); + vse32_v_f32m1(outptr + 4 * 9, _sum41, vl); + vse32_v_f32m1(outptr + 4 * 10, _sum50, vl); + vse32_v_f32m1(outptr + 4 * 11, _sum51, vl); + vse32_v_f32m1(outptr + 4 * 12, _sum60, vl); + vse32_v_f32m1(outptr + 4 * 13, _sum61, vl); + vse32_v_f32m1(outptr + 4 * 14, _sum70, vl); + vse32_v_f32m1(outptr + 4 * 15, _sum71, vl); + } + + outptr += 64; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum20; + vfloat32m1_t _sum21; + vfloat32m1_t _sum30; + vfloat32m1_t _sum31; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum20 = vfmv_v_f_f32m1(0.f, vl); + _sum21 = vfmv_v_f_f32m1(0.f, vl); + _sum30 = vfmv_v_f_f32m1(0.f, vl); + _sum31 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum20 = vfmv_v_f_f32m1(pC[0], vl); + _sum21 = vfmv_v_f_f32m1(pC[0], vl); + _sum30 = vfmv_v_f_f32m1(pC[0], vl); + _sum31 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum20 = _sum00; + _sum21 = _sum01; + _sum30 = _sum00; + _sum31 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + _sum20 = vle32_v_f32m1(pC + 4 * 4, vl); + _sum21 = vle32_v_f32m1(pC + 4 * 5, vl); + _sum30 = vle32_v_f32m1(pC + 4 * 6, vl); + _sum31 = vle32_v_f32m1(pC + 4 * 7, vl); + pC += 32; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum20 = vfmv_v_f_f32m1(pC[2], vl); + _sum30 = vfmv_v_f_f32m1(pC[3], vl); + _sum01 = _sum00; + _sum11 = _sum10; + _sum21 = _sum20; + _sum31 = _sum30; + pC += 4; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum20 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum21 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum30 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum31 = vle32_v_f32m1(outptr + 4 * 7, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + _sum20 = vfmadd_vf_f32m1(_pA0, pB[2], _sum20, vl); + _sum21 = vfmadd_vf_f32m1(_pA1, pB[2], _sum21, vl); + _sum30 = vfmadd_vf_f32m1(_pA0, pB[3], _sum30, vl); + _sum31 = vfmadd_vf_f32m1(_pA1, pB[3], _sum31, vl); + + pA += 8; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum20, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum30, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 2, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4 * 3, _sum31, vl); + + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose8x4_ps(_sum00, _sum01, _sum10, _sum11, _sum20, _sum21, _sum30, _sum31, vl); + + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + out_hstep * 1, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum20, vl); + vse32_v_f32m1(outptr0 + out_hstep * 5, _sum21, vl); + vse32_v_f32m1(outptr0 + out_hstep * 6, _sum30, vl); + vse32_v_f32m1(outptr0 + out_hstep * 7, _sum31, vl); + + outptr0 += 4; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum20, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum21, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum30, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum31, vl); + } + + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4 * 1, vl); + _sum10 = vle32_v_f32m1(pC + 4 * 2, vl); + _sum11 = vle32_v_f32m1(pC + 4 * 3, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum01 = _sum00; + _sum11 = _sum10; + pC += 2; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum10 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum11 = vle32_v_f32m1(outptr + 4 * 3, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pA0, pB[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pA1, pB[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pA0, pB[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pA1, pB[1], _sum11, vl); + + pA += 8; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum10, vl); + + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4 + 4, _sum11, vl); + outptr0 += 8; + } + if (out_elempack == 1) + { + float sum0[8]; + float sum1[8]; + vse32_v_f32m1(sum0, _sum00, vl); + vse32_v_f32m1(sum0 + 4, _sum01, vl); + vse32_v_f32m1(sum1, _sum10, vl); + vse32_v_f32m1(sum1 + 4, _sum11, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0[out_hstep * 4 + 1] = sum1[4]; + outptr0[out_hstep * 5 + 1] = sum1[5]; + outptr0[out_hstep * 6 + 1] = sum1[6]; + outptr0[out_hstep * 7 + 1] = sum1[7]; + outptr0 += 2; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum10, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum11, vl); + } + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + } + if (broadcast_type_C == 3) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = _sum00; + pC += 1; + } + } + } + else + { + _sum00 = vle32_v_f32m1(outptr, vl); + _sum01 = vle32_v_f32m1(outptr + 4 * 1, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA0 = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pA1 = vle32_v_f32m1(pA + 4, vl); + + vfloat32m1_t _pB = vfmv_v_f_f32m1(pB[0], vl); + + _sum00 = vfmadd_vv_f32m1(_pA0, _pB, _sum00, vl); + _sum01 = vfmadd_vv_f32m1(_pA1, _pB, _sum01, vl); + + pA += 8; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + out_hstep * 4, _sum01, vl); + outptr0 += 4; + } + if (out_elempack == 1) + { + float sum0[8]; + vse32_v_f32m1(sum0, _sum00, vl); + vse32_v_f32m1(sum0 + 4, _sum01, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep * 1] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + outptr0++; + } + } + else + { + vse32_v_f32m1(outptr, _sum00, vl); + vse32_v_f32m1(outptr + 4, _sum01, vl); + } + + outptr += 8; + } + + pAT += max_kk * 8; + } + for (; ii + 3 < max_ii; ii += 4) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + vfloat32m1_t _sum4; + vfloat32m1_t _sum5; + vfloat32m1_t _sum6; + vfloat32m1_t _sum7; + vfloat32m1_t _sum8; + vfloat32m1_t _sum9; + vfloat32m1_t _suma; + vfloat32m1_t _sumb; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + _sum3 = vfmv_v_f_f32m1(0.f, vl); + _sum4 = vfmv_v_f_f32m1(0.f, vl); + _sum5 = vfmv_v_f_f32m1(0.f, vl); + _sum6 = vfmv_v_f_f32m1(0.f, vl); + _sum7 = vfmv_v_f_f32m1(0.f, vl); + _sum8 = vfmv_v_f_f32m1(0.f, vl); + _sum9 = vfmv_v_f_f32m1(0.f, vl); + _suma = vfmv_v_f_f32m1(0.f, vl); + _sumb = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + _sum3 = vfmv_v_f_f32m1(pC[0], vl); + _sum4 = vfmv_v_f_f32m1(pC[0], vl); + _sum5 = vfmv_v_f_f32m1(pC[0], vl); + _sum6 = vfmv_v_f_f32m1(pC[0], vl); + _sum7 = vfmv_v_f_f32m1(pC[0], vl); + _sum8 = vfmv_v_f_f32m1(pC[0], vl); + _sum9 = vfmv_v_f_f32m1(pC[0], vl); + _suma = vfmv_v_f_f32m1(pC[0], vl); + _sumb = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + _sum3 = vle32_v_f32m1(pC + 12, vl); + _sum4 = vle32_v_f32m1(pC + 16, vl); + _sum5 = vle32_v_f32m1(pC + 20, vl); + _sum6 = vle32_v_f32m1(pC + 24, vl); + _sum7 = vle32_v_f32m1(pC + 28, vl); + _sum8 = vle32_v_f32m1(pC + 32, vl); + _sum9 = vle32_v_f32m1(pC + 36, vl); + _suma = vle32_v_f32m1(pC + 40, vl); + _sumb = vle32_v_f32m1(pC + 44, vl); + pC += 48; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + _sum2 = vfmv_v_f_f32m1(pC[2], vl); + _sum3 = vfmv_v_f_f32m1(pC[3], vl); + _sum4 = vfmv_v_f_f32m1(pC[4], vl); + _sum5 = vfmv_v_f_f32m1(pC[5], vl); + _sum6 = vfmv_v_f_f32m1(pC[6], vl); + _sum7 = vfmv_v_f_f32m1(pC[7], vl); + _sum8 = vfmv_v_f_f32m1(pC[8], vl); + _sum9 = vfmv_v_f_f32m1(pC[9], vl); + _suma = vfmv_v_f_f32m1(pC[10], vl); + _sumb = vfmv_v_f_f32m1(pC[11], vl); + pC += 12; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum4 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum5 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum6 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum7 = vle32_v_f32m1(outptr + 4 * 7, vl); + _sum8 = vle32_v_f32m1(outptr + 4 * 8, vl); + _sum9 = vle32_v_f32m1(outptr + 4 * 9, vl); + _suma = vle32_v_f32m1(outptr + 4 * 10, vl); + _sumb = vle32_v_f32m1(outptr + 4 * 11, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + _sum4 = vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); + _sum5 = vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); + _sum6 = vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); + _sum7 = vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); + _sum8 = vfmadd_vf_f32m1(_pA, pB[8], _sum8, vl); + _sum9 = vfmadd_vf_f32m1(_pA, pB[9], _sum9, vl); + _suma = vfmadd_vf_f32m1(_pA, pB[10], _suma, vl); + _sumb = vfmadd_vf_f32m1(_pA, pB[11], _sumb, vl); + + pA += 4; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum7, vl); + vse32_v_f32m1(outptr0 + 4 * 8, _sum8, vl); + vse32_v_f32m1(outptr0 + 4 * 9, _sum9, vl); + vse32_v_f32m1(outptr0 + 4 * 10, _suma, vl); + vse32_v_f32m1(outptr0 + 4 * 11, _sumb, vl); + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose4x12_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb, vl); + + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 8, _sum2, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum3, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum4, vl); + vse32_v_f32m1(outptr0 + out_hstep + 8, _sum5, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum6, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum7, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 8, _sum8, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum9, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _suma, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 8, _sumb, vl); + outptr0 += 12; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum7, vl); + vse32_v_f32m1(outptr + 4 * 8, _sum8, vl); + vse32_v_f32m1(outptr + 4 * 9, _sum9, vl); + vse32_v_f32m1(outptr + 4 * 10, _suma, vl); + vse32_v_f32m1(outptr + 4 * 11, _sumb, vl); + } + + outptr += 48; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + vfloat32m1_t _sum4; + vfloat32m1_t _sum5; + vfloat32m1_t _sum6; + vfloat32m1_t _sum7; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + _sum3 = vfmv_v_f_f32m1(0.f, vl); + _sum4 = vfmv_v_f_f32m1(0.f, vl); + _sum5 = vfmv_v_f_f32m1(0.f, vl); + _sum6 = vfmv_v_f_f32m1(0.f, vl); + _sum7 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + _sum3 = vfmv_v_f_f32m1(pC[0], vl); + _sum4 = vfmv_v_f_f32m1(pC[0], vl); + _sum5 = vfmv_v_f_f32m1(pC[0], vl); + _sum6 = vfmv_v_f_f32m1(pC[0], vl); + _sum7 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + _sum3 = vle32_v_f32m1(pC + 12, vl); + _sum4 = vle32_v_f32m1(pC + 16, vl); + _sum5 = vle32_v_f32m1(pC + 20, vl); + _sum6 = vle32_v_f32m1(pC + 24, vl); + _sum7 = vle32_v_f32m1(pC + 28, vl); + pC += 32; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + _sum2 = vfmv_v_f_f32m1(pC[2], vl); + _sum3 = vfmv_v_f_f32m1(pC[3], vl); + _sum4 = vfmv_v_f_f32m1(pC[4], vl); + _sum5 = vfmv_v_f_f32m1(pC[5], vl); + _sum6 = vfmv_v_f_f32m1(pC[6], vl); + _sum7 = vfmv_v_f_f32m1(pC[7], vl); + pC += 8; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m1(outptr + 4 * 3, vl); + _sum4 = vle32_v_f32m1(outptr + 4 * 4, vl); + _sum5 = vle32_v_f32m1(outptr + 4 * 5, vl); + _sum6 = vle32_v_f32m1(outptr + 4 * 6, vl); + _sum7 = vle32_v_f32m1(outptr + 4 * 7, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + _sum4 = vfmadd_vf_f32m1(_pA, pB[4], _sum4, vl); + _sum5 = vfmadd_vf_f32m1(_pA, pB[5], _sum5, vl); + _sum6 = vfmadd_vf_f32m1(_pA, pB[6], _sum6, vl); + _sum7 = vfmadd_vf_f32m1(_pA, pB[7], _sum7, vl); + + pA += 4; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr0 + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr0 + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr0 + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr0 + 4 * 7, _sum7, vl); + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose4x8_ps(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, vl); + + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum2, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum3, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum4, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2 + 4, _sum5, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum6, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3 + 4, _sum7, vl); + outptr0 += 8; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); + vse32_v_f32m1(outptr + 4 * 4, _sum4, vl); + vse32_v_f32m1(outptr + 4 * 5, _sum5, vl); + vse32_v_f32m1(outptr + 4 * 6, _sum6, vl); + vse32_v_f32m1(outptr + 4 * 7, _sum7, vl); + } + + outptr += 32; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + vfloat32m1_t _sum3; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + _sum3 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + _sum3 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + _sum3 = vle32_v_f32m1(pC + 12, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + _sum2 = vfmv_v_f_f32m1(pC[2], vl); + _sum3 = vfmv_v_f_f32m1(pC[3], vl); + pC += 4; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4 * 1, vl); + _sum2 = vle32_v_f32m1(outptr + 4 * 2, vl); + _sum3 = vle32_v_f32m1(outptr + 4 * 3, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + _sum2 = vfmadd_vf_f32m1(_pA, pB[2], _sum2, vl); + _sum3 = vfmadd_vf_f32m1(_pA, pB[3], _sum3, vl); + pA += 4; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + 4 * 3, _sum3, vl); + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_ps(_sum0, _sum1, _sum2, _sum3, vl); + + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + out_hstep * 1, _sum1, vl); + vse32_v_f32m1(outptr0 + out_hstep * 2, _sum2, vl); + vse32_v_f32m1(outptr0 + out_hstep * 3, _sum3, vl); + outptr0 += 4; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 4 * 2, _sum2, vl); + vse32_v_f32m1(outptr + 4 * 3, _sum3, vl); + } + + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + pC += 2; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + + _sum0 = vfmadd_vf_f32m1(_pA, pB[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pA, pB[1], _sum1, vl); + + pA += 4; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + outptr0 += 8; + } + if (out_elempack == 1) + { + float sum0[4]; + float sum1[4]; + vse32_v_f32m1(sum0, _sum0, vl); + vse32_v_f32m1(sum1, _sum1, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0 += 2; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + } + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + vfloat32m1_t _sum0; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vle32_v_f32m1(pC, vl); + } + if (broadcast_type_C == 3) + { + _sum0 = vle32_v_f32m1(pC, vl); + pC += 4; + } + if (broadcast_type_C == 4) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + pC += 1; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pA = vle32_v_f32m1(pA, vl); + vfloat32m1_t _pB = vfmv_v_f_f32m1(pB[0], vl); + + _sum0 = vfmadd_vv_f32m1(_pA, _pB, _sum0, vl); + + pA += 4; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 4) + { + vse32_v_f32m1(outptr0, _sum0, vl); + outptr0 += 4; + } + if (out_elempack == 1) + { + float sum0[4]; + vse32_v_f32m1(sum0, _sum0, vl); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0++; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + } + + outptr += 4; + } + + pAT += max_kk * 4; + } +#endif // __riscv_vector + for (; ii + 1 < max_ii; ii += 2) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum02; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + vfloat32m1_t _sum12; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum02 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + _sum12 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum02 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + _sum12 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum02 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum11 = vfmv_v_f_f32m1(pC[1], vl); + _sum12 = vfmv_v_f_f32m1(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, pC, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, pC + 8, vl); + vlseg2e32_v_f32m1(&_sum02, &_sum12, pC + 16, vl); + pC += 24; + } + if (broadcast_type_C == 4) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum02 = vle32_v_f32m1(pC + 8, vl); + _sum10 = _sum00; + _sum11 = _sum01; + _sum12 = _sum02; + pC += 12; + } + } + } + else + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, outptr, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, outptr + 8, vl); + vlseg2e32_v_f32m1(&_sum02, &_sum12, outptr + 16, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + vfloat32m1_t _pB2 = vle32_v_f32m1(pB + 8, vl); + + _sum00 = vfmadd_vf_f32m1(_pB0, pA[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pB1, pA[0], _sum01, vl); + _sum02 = vfmadd_vf_f32m1(_pB2, pA[0], _sum02, vl); + _sum10 = vfmadd_vf_f32m1(_pB0, pA[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pB1, pA[1], _sum11, vl); + _sum12 = vfmadd_vf_f32m1(_pB2, pA[1], _sum12, vl); + + pA += 2; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + 8, _sum02, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); + vse32_v_f32m1(outptr0 + out_hstep + 8, _sum12, vl); + outptr0 += 12; + } + } + else + { + store_float_v2(_sum00, _sum10, outptr, vl); + store_float_v2(_sum01, _sum11, outptr + 8, vl); + store_float_v2(_sum02, _sum12, outptr + 16, vl); + } + + outptr += 24; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum00; + vfloat32m1_t _sum01; + vfloat32m1_t _sum10; + vfloat32m1_t _sum11; + + if (k == 0) + { + _sum00 = vfmv_v_f_f32m1(0.f, vl); + _sum01 = vfmv_v_f_f32m1(0.f, vl); + _sum10 = vfmv_v_f_f32m1(0.f, vl); + _sum11 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[0], vl); + _sum11 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum00 = vfmv_v_f_f32m1(pC[0], vl); + _sum01 = vfmv_v_f_f32m1(pC[0], vl); + _sum10 = vfmv_v_f_f32m1(pC[1], vl); + _sum11 = vfmv_v_f_f32m1(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, pC, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, pC + 8, vl); + pC += 16; + } + if (broadcast_type_C == 4) + { + _sum00 = vle32_v_f32m1(pC, vl); + _sum01 = vle32_v_f32m1(pC + 4, vl); + _sum10 = _sum00; + _sum11 = _sum01; + pC += 8; + } + } + } + else + { + vlseg2e32_v_f32m1(&_sum00, &_sum10, outptr, vl); + vlseg2e32_v_f32m1(&_sum01, &_sum11, outptr + 8, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + + _sum00 = vfmadd_vf_f32m1(_pB0, pA[0], _sum00, vl); + _sum01 = vfmadd_vf_f32m1(_pB1, pA[0], _sum01, vl); + _sum10 = vfmadd_vf_f32m1(_pB0, pA[1], _sum10, vl); + _sum11 = vfmadd_vf_f32m1(_pB1, pA[1], _sum11, vl); + pA += 2; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum00, vl); + vse32_v_f32m1(outptr0 + 4, _sum01, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum10, vl); + vse32_v_f32m1(outptr0 + out_hstep + 4, _sum11, vl); + outptr0 += 8; + } + } + else + { + store_float_v2(_sum00, _sum10, outptr, vl); + store_float_v2(_sum01, _sum11, outptr + 8, vl); + } + + outptr += 16; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[1], vl); + } + if (broadcast_type_C == 3) + { + vlseg2e32_v_f32m1(&_sum0, &_sum1, pC, vl); + pC += 8; + } + if (broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = _sum0; + pC += 4; + } + } + } + else + { + vfloat32m1_t _tmp0; + vfloat32m1_t _tmp1; + vlseg2e32_v_f32m1(&_tmp0, &_tmp1, outptr, vl); + _sum0 = _tmp0; + _sum1 = _tmp1; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB = vle32_v_f32m1(pB, vl); + + _sum0 = vfmadd_vf_f32m1(_pB, pA[0], _sum0, vl); + _sum1 = vfmadd_vf_f32m1(_pB, pA[1], _sum1, vl); + + pA += 2; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + out_hstep, _sum1, vl); + outptr0 += 4; + } + } + else + { + store_float_v2(_sum0, _sum1, outptr, vl); + } + + outptr += 8; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + float sum00; + float sum01; + float sum10; + float sum11; + + if (k == 0) + { + sum00 = 0.f; + sum01 = 0.f; + sum10 = 0.f; + sum11 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[0]; + sum11 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[0]; + sum11 = pC[1]; + } + if (broadcast_type_C == 3) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[2]; + sum11 = pC[3]; + pC += 4; + } + if (broadcast_type_C == 4) + { + sum00 = pC[0]; + sum01 = pC[0]; + sum10 = pC[1]; + sum11 = pC[1]; + pC += 2; + } + } + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum00; + outptr0[1] = sum10; + outptr0[out_hstep] = sum01; + outptr0[out_hstep + 1] = sum11; + outptr0 += 2; + } + } + else + { + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + } + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + float sum0; + float sum1; + + if (k == 0) + { + sum0 = 0.f; + sum1 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum0 = pC[0]; + sum1 = pC[1]; + } + if (broadcast_type_C == 3) + { + sum0 = pC[0]; + sum1 = pC[1]; + pC += 2; + } + if (broadcast_type_C == 4) + { + sum0 = pC[0]; + sum1 = pC[0]; + pC += 1; + } + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[out_hstep] = sum1; + outptr0++; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + float* outptr0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float* pB = pBT; + + if (pC) + { + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)CT_tile + i + ii; + } + if (broadcast_type_C == 4) + { + pC = (const float*)CT_tile + j; + } + } + + int jj = 0; +#if __riscv_vector + for (; jj + 11 < max_jj; jj += 12) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + vfloat32m1_t _sum2; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + _sum2 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + _sum2 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + _sum2 = vle32_v_f32m1(pC + 8, vl); + pC += 12; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4, vl); + _sum2 = vle32_v_f32m1(outptr + 8, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + vfloat32m1_t _pB2 = vle32_v_f32m1(pB + 8, vl); + + vfloat32m1_t _pA0 = vfmv_v_f_f32m1(pA[0], vl); + + _sum0 = vfmadd_vv_f32m1(_pA0, _pB0, _sum0, vl); + _sum1 = vfmadd_vv_f32m1(_pA0, _pB1, _sum1, vl); + _sum2 = vfmadd_vv_f32m1(_pA0, _pB2, _sum2, vl); + + pA += 1; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + vse32_v_f32m1(outptr0 + 8, _sum2, vl); + outptr0 += 12; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + vse32_v_f32m1(outptr + 8, _sum2, vl); + } + + outptr += 12; + } + for (; jj + 7 < max_jj; jj += 8) + { + vfloat32m1_t _sum0; + vfloat32m1_t _sum1; + + if (k == 0) + { + _sum0 = vfmv_v_f_f32m1(0.f, vl); + _sum1 = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum0 = vfmv_v_f_f32m1(pC[0], vl); + _sum1 = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum0 = vle32_v_f32m1(pC, vl); + _sum1 = vle32_v_f32m1(pC + 4, vl); + pC += 8; + } + } + } + else + { + _sum0 = vle32_v_f32m1(outptr, vl); + _sum1 = vle32_v_f32m1(outptr + 4, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB0 = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pB1 = vle32_v_f32m1(pB + 4, vl); + + vfloat32m1_t _pA0 = vfmv_v_f_f32m1(pA[0], vl); + _sum0 = vfmadd_vv_f32m1(_pA0, _pB0, _sum0, vl); + _sum1 = vfmadd_vv_f32m1(_pA0, _pB1, _sum1, vl); + + pA += 1; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum0, vl); + vse32_v_f32m1(outptr0 + 4, _sum1, vl); + outptr0 += 8; + } + } + else + { + vse32_v_f32m1(outptr, _sum0, vl); + vse32_v_f32m1(outptr + 4, _sum1, vl); + } + + outptr += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + vfloat32m1_t _sum; + + if (k == 0) + { + _sum = vfmv_v_f_f32m1(0.f, vl); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _sum = vfmv_v_f_f32m1(pC[0], vl); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + _sum = vle32_v_f32m1(pC, vl); + pC += 4; + } + } + } + else + { + _sum = vle32_v_f32m1(outptr, vl); + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + vfloat32m1_t _pB = vle32_v_f32m1(pB, vl); + vfloat32m1_t _pA = vfmv_v_f_f32m1(pA[0], vl); + + _sum = vfmadd_vv_f32m1(_pA, _pB, _sum, vl); + + pA += 1; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vse32_v_f32m1(outptr0, _sum, vl); + outptr0 += 4; + } + } + else + { + vse32_v_f32m1(outptr, _sum, vl); + } + + outptr += 4; + } +#endif // __riscv_vector + for (; jj + 1 < max_jj; jj += 2) + { + float sum0; + float sum1; + + if (k == 0) + { + sum0 = 0.f; + sum1 = 0.f; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + sum0 = pC[0]; + sum1 = pC[1]; + pC += 2; + } + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + + pA += 1; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[1] = sum1; + outptr0 += 2; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + float sum; + + if (k == 0) + { + sum = 0.f; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + sum = pC[0]; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + sum = pC[0]; + pC += 1; + } + } + } + else + { + sum = outptr[0]; + } + + const float* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum; + outptr0++; + } + } + else + { + outptr[0] = sum; + } + + outptr += 1; + } + + pAT += max_kk; + } +} + +static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size = get_cpu_level2_cache_size(); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + int tile_size = (int)sqrtf((float)l2_cache_size / 3 / sizeof(float)); + + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + TILE_K = std::max(8, tile_size / 8 * 8); + + if (K > 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); + + if (nn_K == 1) + { + tile_size = (int)((float)l2_cache_size / 2 / sizeof(float) / TILE_K); + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + } + + if (nT > 1) + { + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { + TILE_M = (constant_TILE_M + 7) / 8 * 8; + } + + if (constant_TILE_N > 0) + { + TILE_N = (constant_TILE_N + 3) / 4 * 4; + } + + if (constant_TILE_K > 0) + { + TILE_K = (constant_TILE_K + 7) / 8 * 8; + } +} + +static int gemm_riscv(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 4u, opt.workspace_allocator); + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + // pack B +#if TIME_TEST + gettimeofday(&start_time, NULL); +#endif + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + else + { + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +static int gemm_AT_riscv(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + int nn_N = (N + TILE_N - 1) / TILE_N; + int nn_K = (K + TILE_K - 1) / TILE_K; + + Mat BT(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, opt.workspace_allocator); + + const int nn_NK = nn_N * nn_K; + + // pack B + #pragma omp parallel for num_threads(nT) + for (int ppjk = 0; ppjk < nn_NK; ppjk++) + { + const int ppj = ppjk / nn_K; + const int ppk = ppjk % nn_K; + + const int j = ppj * TILE_N; + const int k = ppk * TILE_K; + + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + else + { + transpose_pack_B_tile(B, BT_tile, j, max_jj, k, max_kk, vl); + } + } + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +static int gemm_BT_riscv(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat ATX(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, nT, 4u, opt.workspace_allocator); + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + // shadowed variable for less openmp task args + const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = ATX.channel(get_omp_thread_num()).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (j == 0) + { + if (transA) + { + transpose_pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + else + { + pack_A_tile(A, AT_tile, i, max_ii, k, max_kk, vl); + } + } + + bool k_end = !output_transpose && k + TILE_K >= K; + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +static int gemm_AT_BT_riscv(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int N, int K, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, size_t vl, const Option& opt) +{ + // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT); + + // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K); + + int nn_M = (M + TILE_M - 1) / TILE_M; + // int nn_N = (N + TILE_N - 1) / TILE_N; + + Mat topT; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT.create(TILE_N * TILE_M, 1, nT, 4u, opt.workspace_allocator); + + #pragma omp parallel for num_threads(nT) + for (int ppi = 0; ppi < nn_M; ppi++) + { + const int i = ppi * TILE_M; + + const int max_ii = std::min((M - i), TILE_M); + + Mat topT_tile; + if (K > TILE_K || broadcast_type_C == 3 || output_transpose) + topT_tile = topT.channel(get_omp_thread_num()); + + for (int j = 0; j < N; j += TILE_N) + { + const int max_jj = std::min((N - j), TILE_N); + + if (broadcast_type_C == 3) + { + pack_A_tile(C, topT_tile, i, max_ii, j, max_jj, vl); + } + + const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_kk = std::min((K - k), TILE_K); + + // NCNN_LOGE("max_ii/jj/kk = %d %d %d", max_ii, max_jj, max_kk); + + Mat AT_tile = AT.channel(i / TILE_M).row_range(k / TILE_K, 1); + + Mat BT_tile = BT.channel(j / TILE_N).row_range(k / TILE_K, 1); + + bool k_end = !output_transpose && k + TILE_K >= K; + + gemm_transB_packed_tile(AT_tile, BT_tile, CT_tile, topT_tile, top_blob, broadcast_type_C, i, max_ii, j, max_jj, k, max_kk, k_end, vl); + } + + if (output_transpose) + { + transpose_unpack_output_tile(topT_tile, top_blob, i, max_ii, j, max_jj, vl); + } + } + } + + return 0; +} + +int Gemm_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + std::vector bottom_blobs(1, bottom_blob); + std::vector top_blobs(1, top_blob); + int ret = forward(bottom_blobs, top_blobs, opt); + top_blob = top_blobs[0]; + return ret; +} + +int Gemm_riscv::create_pipeline(const Option& opt) +{ + if (constantA) + { + const int M = constantM; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_M = (M + TILE_M - 1) / TILE_M; + + AT_data.create(TILE_K * TILE_M, (K + TILE_K - 1) / TILE_K, (M + TILE_M - 1) / TILE_M, 4u, (Allocator*)0); + if (AT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_M; ppj++) + { + const int i = ppj * TILE_M; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_ii = std::min((M - i), TILE_M); + const int max_kk = std::min((K - k), TILE_K); + + Mat AT_tile = AT_data.channel(i / TILE_M).row_range(k / TILE_K, 1); + + if (transA) + { + transpose_pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk, vl); + } + else + { + pack_A_tile(A_data, AT_tile, i, max_ii, k, max_kk, vl); + } + } + } + + if (opt.lightmode) + { + A_data.release(); + } + } + + if (constantB) + { + const int N = constantN; + const int K = constantK; + + int TILE_M, TILE_N, TILE_K; + get_optimal_tile_mnk(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads); + + const int nn_N = (N + TILE_N - 1) / TILE_N; + + BT_data.create(TILE_K * TILE_N, (K + TILE_K - 1) / TILE_K, (N + TILE_N - 1) / TILE_N, 4u, (Allocator*)0); + if (BT_data.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ppj = 0; ppj < nn_N; ppj++) + { + const int j = ppj * TILE_N; + + for (int k = 0; k < K; k += TILE_K) + { + const int max_jj = std::min((N - j), TILE_N); + const int max_kk = std::min((K - k), TILE_K); + + Mat BT_tile = BT_data.channel(j / TILE_N).row_range(k / TILE_K, 1); + + if (transB) + { + pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk, vl); + } + else + { + transpose_pack_B_tile(B_data, BT_tile, j, max_jj, k, max_kk, vl); + } + } + } + + if (opt.lightmode) + { + B_data.release(); + } + } + + if (constantC && constant_broadcast_type_C != -1) + { + CT_data = C_data; + +#if __riscv_vector + if (constant_broadcast_type_C == 3 && opt.use_packing_layout) + { + int C_elempack = constantM % 4 == 0 ? 4 : 1; + convert_packing(C_data, CT_data, C_elempack, opt); + } +#endif // __riscv_vector + + // pre-multiply C with beta + if (beta != 1.f) + { + Mat C2; + C2.create_like(CT_data); + + const int size = CT_data.total() * CT_data.elempack; + for (int i = 0; i < size; i++) + { + C2[i] = CT_data[i] * beta; + } + + CT_data = C2; + } + + if (opt.lightmode) + { + C_data.release(); + } + } + + if (constantA || constantB || constantC) + { + nT = opt.num_threads; + } + + return 0; +} + +int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + // 1xN + broadcast_type_C = 4; + } + + // pre-multiply C with beta + if (beta != 1.f) + { + Mat CT_data; + CT_data.create_like(C, opt.workspace_allocator); + + const int size = C.total() * C.elempack; + for (int i = 0; i < size; i++) + { + CT_data[i] = C[i] * beta; + } + + C = CT_data; + } + } + } + + int out_elempack = 1; +#if __riscv_vector + if (opt.use_packing_layout) + { + int outh = output_transpose ? N : M; + out_elempack = outh % 4 == 0 ? 4 : 1; + } +#endif // __riscv_vector + if (output_elempack) + out_elempack = output_elempack; + size_t out_elemsize = 4u * out_elempack; + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT); + } + + int ret = 0; + if (constantA && constantB) + { + ret = gemm_AT_BT_riscv(AT_data, BT_data, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + ret = gemm_AT_riscv(AT_data, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + ret = gemm_BT_riscv(A, BT_data, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + ret = gemm_riscv(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, vl, opt); + } + if (ret != 0) + return ret; + + // multiply top_blob with alpha + if (alpha != 1.f) + { + const int size = top_blob.total() * out_elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < size; i++) + { + top_blob[i] *= alpha; + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/riscv/gemm_riscv.h b/src/layer/riscv/gemm_riscv.h new file mode 100644 index 000000000000..b92add638915 --- /dev/null +++ b/src/layer/riscv/gemm_riscv.h @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_GEMM_RISCV_H +#define LAYER_GEMM_RISCV_H + +#include "gemm.h" + +namespace ncnn { + +class Gemm_riscv : virtual public Gemm +{ +public: + Gemm_riscv(); + + virtual int create_pipeline(const Option& opt); + + virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + + // public: + int nT; + size_t vl; + Mat AT_data; + Mat BT_data; + Mat CT_data; +}; + +} // namespace ncnn + +#endif // LAYER_GEMM_RISCV_H diff --git a/src/layer/riscv/riscv_usability.h b/src/layer/riscv/riscv_usability.h index 596bf4435c64..938d3ce39985 100644 --- a/src/layer/riscv/riscv_usability.h +++ b/src/layer/riscv/riscv_usability.h @@ -86,6 +86,282 @@ static inline vfloat32m8_t vle32_v_f32m8_f32m1(const float* ptr) return vloxei32_v_f32m8(ptr, bindex, vl); } +static inline void transpose8x8_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0h, + vfloat32m1_t& _r1l, vfloat32m1_t& _r1h, + vfloat32m1_t& _r2l, vfloat32m1_t& _r2h, + vfloat32m1_t& _r3l, vfloat32m1_t& _r3h, + vfloat32m1_t& _r4l, vfloat32m1_t& _r4h, + vfloat32m1_t& _r5l, vfloat32m1_t& _r5h, + vfloat32m1_t& _r6l, vfloat32m1_t& _r6h, + vfloat32m1_t& _r7l, vfloat32m1_t& _r7h, size_t vl) +{ + float tmp[8][8]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 8, _r0l, vl); + vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 8, _r0h, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 8, _r1l, vl); + vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 8, _r1h, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 8, _r2l, vl); + vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 8, _r2h, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 8, _r3l, vl); + vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 8, _r3h, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 8, _r4l, vl); + vsse32_v_f32m1(&tmp[4][4], sizeof(float) * 8, _r4h, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 8, _r5l, vl); + vsse32_v_f32m1(&tmp[4][5], sizeof(float) * 8, _r5h, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 8, _r6l, vl); + vsse32_v_f32m1(&tmp[4][6], sizeof(float) * 8, _r6h, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 8, _r7l, vl); + vsse32_v_f32m1(&tmp[4][7], sizeof(float) * 8, _r7h, vl); + float* ptr = (float*)tmp; + _r0l = vle32_v_f32m1(ptr + 0 * 4, vl); + _r0h = vle32_v_f32m1(ptr + 1 * 4, vl); + _r1l = vle32_v_f32m1(ptr + 2 * 4, vl); + _r1h = vle32_v_f32m1(ptr + 3 * 4, vl); + _r2l = vle32_v_f32m1(ptr + 4 * 4, vl); + _r2h = vle32_v_f32m1(ptr + 5 * 4, vl); + _r3l = vle32_v_f32m1(ptr + 6 * 4, vl); + _r3h = vle32_v_f32m1(ptr + 7 * 4, vl); + _r4l = vle32_v_f32m1(ptr + 8 * 4, vl); + _r4h = vle32_v_f32m1(ptr + 9 * 4, vl); + _r5l = vle32_v_f32m1(ptr + 10 * 4, vl); + _r5h = vle32_v_f32m1(ptr + 11 * 4, vl); + _r6l = vle32_v_f32m1(ptr + 12 * 4, vl); + _r6h = vle32_v_f32m1(ptr + 13 * 4, vl); + _r7l = vle32_v_f32m1(ptr + 14 * 4, vl); + _r7h = vle32_v_f32m1(ptr + 15 * 4, vl); +} + +static inline void transpose4x4_ps(vfloat32m1_t& _r0, vfloat32m1_t& _r1, vfloat32m1_t& _r2, vfloat32m1_t& _r3, size_t vl) +{ + float tmp[4][4]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 4, _r0, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 4, _r1, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 4, _r2, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 4, _r3, vl); + float* ptr = (float*)tmp; + _r0 = vle32_v_f32m1(ptr + 0 * 4, vl); + _r1 = vle32_v_f32m1(ptr + 1 * 4, vl); + _r2 = vle32_v_f32m1(ptr + 2 * 4, vl); + _r3 = vle32_v_f32m1(ptr + 3 * 4, vl); +} + +static inline void transpose8x12_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0h, + vfloat32m1_t& _r1l, vfloat32m1_t& _r1h, + vfloat32m1_t& _r2l, vfloat32m1_t& _r2h, + vfloat32m1_t& _r3l, vfloat32m1_t& _r3h, + vfloat32m1_t& _r4l, vfloat32m1_t& _r4h, + vfloat32m1_t& _r5l, vfloat32m1_t& _r5h, + vfloat32m1_t& _r6l, vfloat32m1_t& _r6h, + vfloat32m1_t& _r7l, vfloat32m1_t& _r7h, + vfloat32m1_t& _r8l, vfloat32m1_t& _r8h, + vfloat32m1_t& _r9l, vfloat32m1_t& _r9h, + vfloat32m1_t& _ral, vfloat32m1_t& _rah, + vfloat32m1_t& _rbl, vfloat32m1_t& _rbh, size_t vl) +{ + float tmp[8][12]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 12, _r0l, vl); + vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 12, _r0h, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 12, _r1l, vl); + vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 12, _r1h, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 12, _r2l, vl); + vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 12, _r2h, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 12, _r3l, vl); + vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 12, _r3h, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 12, _r4l, vl); + vsse32_v_f32m1(&tmp[4][4], sizeof(float) * 12, _r4h, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 12, _r5l, vl); + vsse32_v_f32m1(&tmp[4][5], sizeof(float) * 12, _r5h, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 12, _r6l, vl); + vsse32_v_f32m1(&tmp[4][6], sizeof(float) * 12, _r6h, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 12, _r7l, vl); + vsse32_v_f32m1(&tmp[4][7], sizeof(float) * 12, _r7h, vl); + vsse32_v_f32m1(&tmp[0][8], sizeof(float) * 12, _r8l, vl); + vsse32_v_f32m1(&tmp[4][8], sizeof(float) * 12, _r8h, vl); + vsse32_v_f32m1(&tmp[0][9], sizeof(float) * 12, _r9l, vl); + vsse32_v_f32m1(&tmp[4][9], sizeof(float) * 12, _r9h, vl); + vsse32_v_f32m1(&tmp[0][10], sizeof(float) * 12, _ral, vl); + vsse32_v_f32m1(&tmp[4][10], sizeof(float) * 12, _rah, vl); + vsse32_v_f32m1(&tmp[0][11], sizeof(float) * 12, _rbl, vl); + vsse32_v_f32m1(&tmp[4][11], sizeof(float) * 12, _rbh, vl); + float* ptr = (float*)tmp; + _r0l = vle32_v_f32m1(ptr + 0 * 4, vl); + _r0h = vle32_v_f32m1(ptr + 1 * 4, vl); + _r1l = vle32_v_f32m1(ptr + 2 * 4, vl); + _r1h = vle32_v_f32m1(ptr + 3 * 4, vl); + _r2l = vle32_v_f32m1(ptr + 4 * 4, vl); + _r2h = vle32_v_f32m1(ptr + 5 * 4, vl); + _r3l = vle32_v_f32m1(ptr + 6 * 4, vl); + _r3h = vle32_v_f32m1(ptr + 7 * 4, vl); + _r4l = vle32_v_f32m1(ptr + 8 * 4, vl); + _r4h = vle32_v_f32m1(ptr + 9 * 4, vl); + _r5l = vle32_v_f32m1(ptr + 10 * 4, vl); + _r5h = vle32_v_f32m1(ptr + 11 * 4, vl); + _r6l = vle32_v_f32m1(ptr + 12 * 4, vl); + _r6h = vle32_v_f32m1(ptr + 13 * 4, vl); + _r7l = vle32_v_f32m1(ptr + 14 * 4, vl); + _r7h = vle32_v_f32m1(ptr + 15 * 4, vl); + _r8l = vle32_v_f32m1(ptr + 16 * 4, vl); + _r8h = vle32_v_f32m1(ptr + 17 * 4, vl); + _r9l = vle32_v_f32m1(ptr + 18 * 4, vl); + _r9h = vle32_v_f32m1(ptr + 19 * 4, vl); + _ral = vle32_v_f32m1(ptr + 20 * 4, vl); + _rah = vle32_v_f32m1(ptr + 21 * 4, vl); + _rbl = vle32_v_f32m1(ptr + 22 * 4, vl); + _rbh = vle32_v_f32m1(ptr + 23 * 4, vl); +} + +static inline void transpose12x8_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0m, vfloat32m1_t& _r0h, + vfloat32m1_t& _r1l, vfloat32m1_t& _r1m, vfloat32m1_t& _r1h, + vfloat32m1_t& _r2l, vfloat32m1_t& _r2m, vfloat32m1_t& _r2h, + vfloat32m1_t& _r3l, vfloat32m1_t& _r3m, vfloat32m1_t& _r3h, + vfloat32m1_t& _r4l, vfloat32m1_t& _r4m, vfloat32m1_t& _r4h, + vfloat32m1_t& _r5l, vfloat32m1_t& _r5m, vfloat32m1_t& _r5h, + vfloat32m1_t& _r6l, vfloat32m1_t& _r6m, vfloat32m1_t& _r6h, + vfloat32m1_t& _r7l, vfloat32m1_t& _r7m, vfloat32m1_t& _r7h, size_t vl) +{ + float tmp[12][8]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 8, _r0l, vl); + vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 8, _r0m, vl); + vsse32_v_f32m1(&tmp[8][0], sizeof(float) * 8, _r0h, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 8, _r1l, vl); + vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 8, _r1m, vl); + vsse32_v_f32m1(&tmp[8][0], sizeof(float) * 8, _r1h, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 8, _r2l, vl); + vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 8, _r2m, vl); + vsse32_v_f32m1(&tmp[8][2], sizeof(float) * 8, _r2h, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 8, _r3l, vl); + vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 8, _r3m, vl); + vsse32_v_f32m1(&tmp[8][3], sizeof(float) * 8, _r3h, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 8, _r4l, vl); + vsse32_v_f32m1(&tmp[4][4], sizeof(float) * 8, _r4m, vl); + vsse32_v_f32m1(&tmp[8][4], sizeof(float) * 8, _r4h, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 8, _r5l, vl); + vsse32_v_f32m1(&tmp[4][5], sizeof(float) * 8, _r5m, vl); + vsse32_v_f32m1(&tmp[8][5], sizeof(float) * 8, _r5h, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 8, _r6l, vl); + vsse32_v_f32m1(&tmp[4][6], sizeof(float) * 8, _r6m, vl); + vsse32_v_f32m1(&tmp[8][6], sizeof(float) * 8, _r6h, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 8, _r7l, vl); + vsse32_v_f32m1(&tmp[4][7], sizeof(float) * 8, _r7m, vl); + vsse32_v_f32m1(&tmp[8][7], sizeof(float) * 8, _r7h, vl); + float* ptr = (float*)tmp; + _r0l = vle32_v_f32m1(ptr + 0 * 4, vl); + _r0m = vle32_v_f32m1(ptr + 1 * 4, vl); + _r0h = vle32_v_f32m1(ptr + 2 * 4, vl); + _r1l = vle32_v_f32m1(ptr + 3 * 4, vl); + _r1m = vle32_v_f32m1(ptr + 4 * 4, vl); + _r1h = vle32_v_f32m1(ptr + 5 * 4, vl); + _r2l = vle32_v_f32m1(ptr + 6 * 4, vl); + _r2m = vle32_v_f32m1(ptr + 7 * 4, vl); + _r2h = vle32_v_f32m1(ptr + 8 * 4, vl); + _r3l = vle32_v_f32m1(ptr + 9 * 4, vl); + _r3m = vle32_v_f32m1(ptr + 10 * 4, vl); + _r3h = vle32_v_f32m1(ptr + 11 * 4, vl); + _r4l = vle32_v_f32m1(ptr + 12 * 4, vl); + _r4m = vle32_v_f32m1(ptr + 13 * 4, vl); + _r4h = vle32_v_f32m1(ptr + 14 * 4, vl); + _r5l = vle32_v_f32m1(ptr + 15 * 4, vl); + _r5m = vle32_v_f32m1(ptr + 16 * 4, vl); + _r5h = vle32_v_f32m1(ptr + 17 * 4, vl); + _r6l = vle32_v_f32m1(ptr + 18 * 4, vl); + _r6m = vle32_v_f32m1(ptr + 19 * 4, vl); + _r6h = vle32_v_f32m1(ptr + 20 * 4, vl); + _r7l = vle32_v_f32m1(ptr + 21 * 4, vl); + _r7m = vle32_v_f32m1(ptr + 22 * 4, vl); + _r7h = vle32_v_f32m1(ptr + 23 * 4, vl); +} + +static inline void transpose4x8_ps(vfloat32m1_t& _r0, vfloat32m1_t& _r1, vfloat32m1_t& _r2, vfloat32m1_t& _r3, vfloat32m1_t& _r4, vfloat32m1_t& _r5, vfloat32m1_t& _r6, vfloat32m1_t& _r7, size_t vl) +{ + float tmp[4][8]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 8, _r0, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 8, _r1, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 8, _r2, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 8, _r3, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 8, _r4, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 8, _r5, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 8, _r6, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 8, _r7, vl); + float* ptr = (float*)tmp; + _r0 = vle32_v_f32m1(ptr + 0 * 4, vl); + _r1 = vle32_v_f32m1(ptr + 1 * 4, vl); + _r2 = vle32_v_f32m1(ptr + 2 * 4, vl); + _r3 = vle32_v_f32m1(ptr + 3 * 4, vl); + _r4 = vle32_v_f32m1(ptr + 4 * 4, vl); + _r5 = vle32_v_f32m1(ptr + 5 * 4, vl); + _r6 = vle32_v_f32m1(ptr + 6 * 4, vl); + _r7 = vle32_v_f32m1(ptr + 7 * 4, vl); +} + +static inline void transpose4x12_ps(vfloat32m1_t& _r0, vfloat32m1_t& _r1, vfloat32m1_t& _r2, vfloat32m1_t& _r3, vfloat32m1_t& _r4, vfloat32m1_t& _r5, vfloat32m1_t& _r6, vfloat32m1_t& _r7, vfloat32m1_t& _r8, vfloat32m1_t& _r9, vfloat32m1_t& _ra, vfloat32m1_t& _rb, size_t vl) +{ + float tmp[4][12]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 12, _r0, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 12, _r1, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 12, _r2, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 12, _r3, vl); + vsse32_v_f32m1(&tmp[0][4], sizeof(float) * 12, _r4, vl); + vsse32_v_f32m1(&tmp[0][5], sizeof(float) * 12, _r5, vl); + vsse32_v_f32m1(&tmp[0][6], sizeof(float) * 12, _r6, vl); + vsse32_v_f32m1(&tmp[0][7], sizeof(float) * 12, _r7, vl); + vsse32_v_f32m1(&tmp[0][8], sizeof(float) * 12, _r8, vl); + vsse32_v_f32m1(&tmp[0][9], sizeof(float) * 12, _r9, vl); + vsse32_v_f32m1(&tmp[0][10], sizeof(float) * 12, _ra, vl); + vsse32_v_f32m1(&tmp[0][11], sizeof(float) * 12, _rb, vl); + float* ptr = (float*)tmp; + _r0 = vle32_v_f32m1(ptr + 0 * 4, vl); + _r1 = vle32_v_f32m1(ptr + 1 * 4, vl); + _r2 = vle32_v_f32m1(ptr + 2 * 4, vl); + _r3 = vle32_v_f32m1(ptr + 3 * 4, vl); + _r4 = vle32_v_f32m1(ptr + 4 * 4, vl); + _r5 = vle32_v_f32m1(ptr + 5 * 4, vl); + _r6 = vle32_v_f32m1(ptr + 6 * 4, vl); + _r7 = vle32_v_f32m1(ptr + 7 * 4, vl); + _r8 = vle32_v_f32m1(ptr + 8 * 4, vl); + _r9 = vle32_v_f32m1(ptr + 9 * 4, vl); + _ra = vle32_v_f32m1(ptr + 10 * 4, vl); + _rb = vle32_v_f32m1(ptr + 11 * 4, vl); +} + +static inline void transpose8x4_ps(vfloat32m1_t& _r0l, vfloat32m1_t& _r0h, + vfloat32m1_t& _r1l, vfloat32m1_t& _r1h, + vfloat32m1_t& _r2l, vfloat32m1_t& _r2h, + vfloat32m1_t& _r3l, vfloat32m1_t& _r3h, size_t vl) +{ + float tmp[8][4]; + vsse32_v_f32m1(&tmp[0][0], sizeof(float) * 4, _r0l, vl); + vsse32_v_f32m1(&tmp[4][0], sizeof(float) * 4, _r0h, vl); + vsse32_v_f32m1(&tmp[0][1], sizeof(float) * 4, _r1l, vl); + vsse32_v_f32m1(&tmp[4][1], sizeof(float) * 4, _r1h, vl); + vsse32_v_f32m1(&tmp[0][2], sizeof(float) * 4, _r2l, vl); + vsse32_v_f32m1(&tmp[4][2], sizeof(float) * 4, _r2h, vl); + vsse32_v_f32m1(&tmp[0][3], sizeof(float) * 4, _r3l, vl); + vsse32_v_f32m1(&tmp[4][3], sizeof(float) * 4, _r3h, vl); + float* ptr = (float*)tmp; + _r0l = vle32_v_f32m1(ptr + 0 * 4, vl); + _r0h = vle32_v_f32m1(ptr + 1 * 4, vl); + _r1l = vle32_v_f32m1(ptr + 2 * 4, vl); + _r1h = vle32_v_f32m1(ptr + 3 * 4, vl); + _r2l = vle32_v_f32m1(ptr + 4 * 4, vl); + _r2h = vle32_v_f32m1(ptr + 5 * 4, vl); + _r3l = vle32_v_f32m1(ptr + 6 * 4, vl); + _r3h = vle32_v_f32m1(ptr + 7 * 4, vl); +} + +static inline void store_float_v2(vfloat32m1_t& vector1, vfloat32m1_t& vector2, float* buf, size_t vl) +{ + vsse32_v_f32m1(buf + 0, sizeof(float) * 2, vector1, vl); + vsse32_v_f32m1(buf + 1, sizeof(float) * 2, vector2, vl); +} + +static inline void store_float_v4(vfloat32m1_t& vector1, vfloat32m1_t& vector2, vfloat32m1_t& vector3, vfloat32m1_t& vector4, float* buf, size_t vl) +{ + vsse32_v_f32m1(buf + 0, sizeof(float) * 4, vector1, vl); + vsse32_v_f32m1(buf + 1, sizeof(float) * 4, vector2, vl); + vsse32_v_f32m1(buf + 2, sizeof(float) * 4, vector3, vl); + vsse32_v_f32m1(buf + 3, sizeof(float) * 4, vector4, vl); +} + #if __riscv_zfh static inline vfloat16m8_t vle16_v_f16m8_f16m1(const __fp16* ptr) { From 84aaedbe964874ed0ef98357a57b9233603cb8fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=B5=E5=B0=8F=E5=87=A1?= <2672931+whyb@users.noreply.github.com> Date: Fri, 20 Oct 2023 15:46:51 +0800 Subject: [PATCH 14/16] Added 5 devices(AMD*2, Intel*2, NVIDIA*1) benchmark result. (#5085) --- benchmark/README.md | 250 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 250 insertions(+) diff --git a/benchmark/README.md b/benchmark/README.md index ebe977bd46e0..211480893b8f 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -6648,3 +6648,253 @@ cooling_down = 0 vision_transformer min = 600.83 max = 666.35 avg = 617.33 FastestDet min = 6.05 max = 6.72 avg = 6.23 ``` + +### AMD Ryzen 9 5950X 16-Core of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 -1 0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 2.68 max = 3.10 avg = 2.77 + squeezenet_int8 min = 3.57 max = 4.72 avg = 4.04 + mobilenet min = 3.09 max = 5.44 avg = 3.38 + mobilenet_int8 min = 2.36 max = 3.40 avg = 2.74 + mobilenet_v2 min = 4.24 max = 4.81 avg = 4.40 + mobilenet_v3 min = 3.46 max = 3.93 avg = 3.58 + shufflenet min = 3.21 max = 4.54 avg = 4.01 + shufflenet_v2 min = 2.99 max = 4.49 avg = 3.34 + mnasnet min = 3.62 max = 4.31 avg = 3.83 + proxylessnasnet min = 4.06 max = 5.70 avg = 4.23 + efficientnet_b0 min = 5.60 max = 6.55 avg = 5.81 + efficientnetv2_b0 min = 6.83 max = 8.82 avg = 7.12 + regnety_400m min = 8.02 max = 9.75 avg = 8.34 + blazeface min = 1.34 max = 1.77 avg = 1.46 + googlenet min = 11.62 max = 15.95 avg = 12.70 + googlenet_int8 min = 7.43 max = 10.06 avg = 7.92 + resnet18 min = 8.39 max = 10.39 avg = 9.04 + resnet18_int8 min = 6.23 max = 8.64 avg = 6.75 + alexnet min = 7.78 max = 12.51 avg = 8.51 + vgg16 min = 53.85 max = 63.39 avg = 56.36 + vgg16_int8 min = 35.61 max = 46.94 avg = 38.08 + resnet50 min = 18.55 max = 24.46 avg = 19.81 + resnet50_int8 min = 11.95 max = 23.21 avg = 13.51 + squeezenet_ssd min = 10.01 max = 13.16 avg = 10.69 + squeezenet_ssd_int8 min = 9.29 max = 14.02 avg = 10.47 + mobilenet_ssd min = 6.38 max = 10.26 avg = 7.15 + mobilenet_ssd_int8 min = 4.69 max = 6.98 avg = 5.42 + mobilenet_yolo min = 17.63 max = 22.59 avg = 19.45 + mobilenetv2_yolov3 min = 11.79 max = 15.67 avg = 12.76 + yolov4-tiny min = 21.53 max = 25.79 avg = 22.46 + nanodet_m min = 7.16 max = 9.99 avg = 8.01 + yolo-fastest-1.1 min = 3.66 max = 5.00 avg = 4.38 + yolo-fastestv2 min = 3.52 max = 5.20 avg = 4.60 + vision_transformer min = 67.01 max = 93.71 avg = 78.48 + FastestDet min = 4.44 max = 8.62 avg = 4.69 +``` + +### AMD Radeon RX 6900 XT of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 0 0 +[0 AMD Radeon RX 6900 XT] queueC=1[2] queueG=0[1] queueT=2[2] +[0 AMD Radeon RX 6900 XT] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 AMD Radeon RX 6900 XT] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 AMD Radeon RX 6900 XT] subgroup=64 basic/vote/ballot/shuffle=1/1/1/1 +[0 AMD Radeon RX 6900 XT] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = 0 +cooling_down = 0 + squeezenet min = 2.19 max = 2.70 avg = 2.47 + squeezenet_int8 min = 3.94 max = 4.51 avg = 4.18 + mobilenet min = 2.03 max = 2.63 avg = 2.28 + mobilenet_int8 min = 2.56 max = 3.34 avg = 2.69 + mobilenet_v2 min = 2.29 max = 2.98 avg = 2.62 + mobilenet_v3 min = 2.31 max = 3.10 avg = 2.75 + shufflenet min = 1.89 max = 2.61 avg = 2.30 + shufflenet_v2 min = 2.17 max = 3.04 avg = 2.59 + mnasnet min = 2.19 max = 2.98 avg = 2.69 + proxylessnasnet min = 2.12 max = 4.08 avg = 2.62 + efficientnet_b0 min = 3.62 max = 5.27 avg = 4.21 + efficientnetv2_b0 min = 6.09 max = 7.15 avg = 6.49 + regnety_400m min = 2.55 max = 3.82 avg = 3.00 + blazeface min = 1.93 max = 2.56 avg = 2.28 + googlenet min = 3.35 max = 4.46 avg = 3.75 + googlenet_int8 min = 8.02 max = 12.84 avg = 9.15 + resnet18 min = 2.46 max = 3.14 avg = 2.84 + resnet18_int8 min = 6.37 max = 9.15 avg = 7.30 + alexnet min = 2.31 max = 2.91 avg = 2.69 + vgg16 min = 4.76 max = 5.79 avg = 5.24 + vgg16_int8 min = 35.94 max = 46.27 avg = 39.05 + resnet50 min = 3.25 max = 4.09 avg = 3.75 + resnet50_int8 min = 12.04 max = 20.53 avg = 14.61 + squeezenet_ssd min = 3.03 max = 5.31 avg = 3.66 + squeezenet_ssd_int8 min = 9.74 max = 13.46 avg = 10.42 + mobilenet_ssd min = 2.82 max = 4.75 avg = 3.39 + mobilenet_ssd_int8 min = 4.67 max = 6.76 avg = 5.30 + mobilenet_yolo min = 3.01 max = 3.67 avg = 3.34 + mobilenetv2_yolov3 min = 4.04 max = 6.46 avg = 4.55 + yolov4-tiny min = 5.75 max = 8.05 avg = 6.52 + nanodet_m min = 10.16 max = 14.97 avg = 13.11 + yolo-fastest-1.1 min = 2.36 max = 3.80 avg = 2.88 + yolo-fastestv2 min = 2.24 max = 3.19 avg = 2.80 + vision_transformer min = 20.43 max = 25.06 avg = 21.07 + FastestDet min = 2.49 max = 3.18 avg = 2.93 +``` + +### NVIDIA GeForce RTX 3060 Ti of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 0 0 +[0 NVIDIA GeForce RTX 3060 Ti] queueC=2[8] queueG=0[16] queueT=1[2] +[0 NVIDIA GeForce RTX 3060 Ti] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-matrix-16_8_8/16_8_16/16_16_16=1/1/1 +[1 Intel(R) UHD Graphics 770] queueC=0[1] queueG=0[1] queueT=0[1] +[1 Intel(R) UHD Graphics 770] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[1 Intel(R) UHD Graphics 770] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[1 Intel(R) UHD Graphics 770] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[1 Intel(R) UHD Graphics 770] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = 0 +cooling_down = 0 + squeezenet min = 0.80 max = 2.51 avg = 0.89 + squeezenet_int8 min = 2.81 max = 3.51 avg = 2.96 + mobilenet min = 0.70 max = 0.79 avg = 0.71 + mobilenet_int8 min = 2.95 max = 3.44 avg = 3.03 + mobilenet_v2 min = 1.09 max = 1.25 avg = 1.12 + mobilenet_v3 min = 1.33 max = 2.04 avg = 1.56 + shufflenet min = 1.20 max = 1.39 avg = 1.27 + shufflenet_v2 min = 1.50 max = 1.66 avg = 1.57 + mnasnet min = 1.11 max = 1.22 avg = 1.15 + proxylessnasnet min = 1.20 max = 1.63 avg = 1.24 + efficientnet_b0 min = 2.38 max = 3.21 avg = 2.61 + efficientnetv2_b0 min = 9.16 max = 11.35 avg = 9.63 + regnety_400m min = 1.86 max = 2.03 avg = 1.94 + blazeface min = 0.70 max = 1.10 avg = 0.76 + googlenet min = 2.11 max = 2.40 avg = 2.30 + googlenet_int8 min = 6.91 max = 7.88 avg = 7.17 + resnet18 min = 1.14 max = 1.47 avg = 1.19 + resnet18_int8 min = 4.96 max = 6.82 avg = 5.40 + alexnet min = 1.10 max = 1.85 avg = 1.19 + vgg16 min = 2.27 max = 3.97 avg = 2.46 + vgg16_int8 min = 19.02 max = 22.20 avg = 20.28 + resnet50 min = 2.00 max = 2.99 avg = 2.10 + resnet50_int8 min = 10.66 max = 13.30 avg = 11.29 + squeezenet_ssd min = 2.74 max = 3.44 avg = 2.90 + squeezenet_ssd_int8 min = 6.93 max = 7.95 avg = 7.19 + mobilenet_ssd min = 1.86 max = 2.07 avg = 1.96 + mobilenet_ssd_int8 min = 5.92 max = 6.48 avg = 6.09 + mobilenet_yolo min = 1.65 max = 2.58 avg = 1.78 + mobilenetv2_yolov3 min = 3.85 max = 4.11 avg = 3.96 + yolov4-tiny min = 6.54 max = 7.05 avg = 6.69 + nanodet_m min = 2.38 max = 3.28 avg = 2.72 + yolo-fastest-1.1 min = 1.73 max = 2.07 avg = 1.83 + yolo-fastestv2 min = 1.72 max = 1.92 avg = 1.80 + vision_transformer min = 53.91 max = 56.59 avg = 55.27 + FastestDet min = 1.48 max = 1.83 avg = 1.69 +``` + +### Intel(R) UHD Graphics 770 of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 1 0 +[0 NVIDIA GeForce RTX 3060 Ti] queueC=2[8] queueG=0[16] queueT=1[2] +[0 NVIDIA GeForce RTX 3060 Ti] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[0 NVIDIA GeForce RTX 3060 Ti] fp16-matrix-16_8_8/16_8_16/16_16_16=1/1/1 +[1 Intel(R) UHD Graphics 770] queueC=0[1] queueG=0[1] queueT=0[1] +[1 Intel(R) UHD Graphics 770] bugsbn1=0 bugbilz=0 bugcopc=0 bugihfa=0 +[1 Intel(R) UHD Graphics 770] fp16-p/s/a=1/1/1 int8-p/s/a=1/1/1 +[1 Intel(R) UHD Graphics 770] subgroup=32 basic/vote/ballot/shuffle=1/1/1/1 +[1 Intel(R) UHD Graphics 770] fp16-matrix-16_8_8/16_8_16/16_16_16=0/0/0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = 1 +cooling_down = 0 + squeezenet min = 3.11 max = 4.47 avg = 3.45 + squeezenet_int8 min = 1.89 max = 2.84 avg = 2.23 + mobilenet min = 4.98 max = 5.67 avg = 5.18 + mobilenet_int8 min = 2.54 max = 3.17 avg = 2.98 + mobilenet_v2 min = 4.03 max = 4.89 avg = 4.37 + mobilenet_v3 min = 4.45 max = 5.68 avg = 4.86 + shufflenet min = 3.42 max = 4.42 avg = 3.79 + shufflenet_v2 min = 3.00 max = 4.01 avg = 3.30 + mnasnet min = 4.21 max = 5.12 avg = 4.51 + proxylessnasnet min = 4.62 max = 5.64 avg = 4.90 + efficientnet_b0 min = 7.82 max = 8.63 avg = 8.10 + efficientnetv2_b0 min = 34.52 max = 36.34 avg = 35.29 + regnety_400m min = 6.07 max = 7.31 avg = 6.44 + blazeface min = 1.54 max = 1.67 avg = 1.59 + googlenet min = 11.53 max = 12.64 avg = 11.89 + googlenet_int8 min = 13.71 max = 15.52 avg = 14.38 + resnet18 min = 10.75 max = 12.94 avg = 11.07 + resnet18_int8 min = 9.04 max = 11.05 avg = 9.53 + alexnet min = 13.64 max = 14.37 avg = 13.98 + vgg16 min = 38.53 max = 40.16 avg = 39.22 + vgg16_int8 min = 16.04 max = 21.16 avg = 19.35 + resnet50 min = 25.61 max = 28.22 avg = 26.62 + resnet50_int8 min = 7.72 max = 12.83 avg = 10.29 + squeezenet_ssd min = 10.34 max = 15.88 avg = 14.75 + squeezenet_ssd_int8 min = 4.63 max = 7.13 avg = 5.66 + mobilenet_ssd min = 11.35 max = 13.06 avg = 12.44 + mobilenet_ssd_int8 min = 4.21 max = 6.31 avg = 5.32 + mobilenet_yolo min = 20.14 max = 22.92 avg = 21.94 + mobilenetv2_yolov3 min = 12.58 max = 14.88 avg = 14.21 + yolov4-tiny min = 20.62 max = 25.58 avg = 24.39 + nanodet_m min = 7.75 max = 12.49 avg = 11.42 + yolo-fastest-1.1 min = 3.68 max = 6.49 avg = 5.54 + yolo-fastestv2 min = 4.32 max = 5.39 avg = 4.51 + vision_transformer min = 796.51 max = 805.29 avg = 802.39 + FastestDet min = 2.89 max = 4.83 avg = 3.95 +``` + +### IntelĀ® Coreā„¢ i7-13700K of Desktop[2023-10-12] +``` +E:\github\ncnn\build-ncnn-vs2019\benchmark\Release>benchncnn.exe 100 16 0 -1 0 +loop_count = 100 +num_threads = 16 +powersave = 0 +gpu_device = -1 +cooling_down = 0 + squeezenet min = 1.69 max = 2.63 avg = 2.12 + squeezenet_int8 min = 1.83 max = 3.03 avg = 2.26 + mobilenet min = 1.69 max = 2.64 avg = 2.24 + mobilenet_int8 min = 2.47 max = 3.06 avg = 2.84 + mobilenet_v2 min = 1.94 max = 3.47 avg = 2.47 + mobilenet_v3 min = 1.49 max = 2.74 avg = 1.87 + shufflenet min = 1.57 max = 3.00 avg = 1.82 + shufflenet_v2 min = 1.41 max = 1.72 avg = 1.51 + mnasnet min = 1.73 max = 2.94 avg = 2.13 + proxylessnasnet min = 2.08 max = 3.31 avg = 2.69 + efficientnet_b0 min = 3.20 max = 4.99 avg = 3.78 + efficientnetv2_b0 min = 3.51 max = 5.16 avg = 4.08 + regnety_400m min = 4.51 max = 10.29 avg = 6.18 + blazeface min = 0.52 max = 0.92 avg = 0.59 + googlenet min = 5.49 max = 7.48 avg = 6.26 + googlenet_int8 min = 4.83 max = 7.54 avg = 5.90 + resnet18 min = 4.05 max = 6.61 avg = 4.83 + resnet18_int8 min = 3.77 max = 5.70 avg = 4.57 + alexnet min = 3.60 max = 5.09 avg = 4.26 + vgg16 min = 25.19 max = 28.79 avg = 26.81 + vgg16_int8 min = 17.52 max = 21.79 avg = 19.80 + resnet50 min = 9.23 max = 13.15 avg = 11.34 + resnet50_int8 min = 7.77 max = 12.00 avg = 10.18 + squeezenet_ssd min = 4.33 max = 6.73 avg = 4.96 + squeezenet_ssd_int8 min = 4.77 max = 7.62 avg = 5.71 + mobilenet_ssd min = 3.70 max = 6.43 avg = 4.53 + mobilenet_ssd_int8 min = 4.16 max = 6.53 avg = 5.38 + mobilenet_yolo min = 11.27 max = 14.93 avg = 12.90 + mobilenetv2_yolov3 min = 7.41 max = 11.52 avg = 9.11 + yolov4-tiny min = 12.05 max = 18.96 avg = 14.15 + nanodet_m min = 3.39 max = 5.77 avg = 4.07 + yolo-fastest-1.1 min = 1.95 max = 3.85 avg = 2.30 + yolo-fastestv2 min = 1.91 max = 3.52 avg = 2.27 + vision_transformer min = 79.50 max = 99.93 avg = 88.91 + FastestDet min = 1.92 max = 2.72 avg = 2.19 +``` \ No newline at end of file From dc251281953ebbab6d58aacd3641a42f456b5d37 Mon Sep 17 00:00:00 2001 From: FhqTreap <45459183+FhqTreap@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:51:16 +0800 Subject: [PATCH 15/16] Vulkan conv1d (#5060) --- src/layer/vulkan/convolution1d_vulkan.cpp | 423 ++++++++++++++++++ src/layer/vulkan/convolution1d_vulkan.h | 53 +++ src/layer/vulkan/shader/convolution1d.comp | 177 ++++++++ .../vulkan/shader/convolution1d_pack1to4.comp | 177 ++++++++ .../vulkan/shader/convolution1d_pack1to8.comp | 186 ++++++++ .../vulkan/shader/convolution1d_pack4.comp | 208 +++++++++ .../vulkan/shader/convolution1d_pack4to1.comp | 177 ++++++++ .../vulkan/shader/convolution1d_pack4to8.comp | 270 +++++++++++ .../vulkan/shader/convolution1d_pack8.comp | 270 +++++++++++ .../vulkan/shader/convolution1d_pack8to1.comp | 178 ++++++++ .../vulkan/shader/convolution1d_pack8to4.comp | 220 +++++++++ 11 files changed, 2339 insertions(+) create mode 100644 src/layer/vulkan/convolution1d_vulkan.cpp create mode 100644 src/layer/vulkan/convolution1d_vulkan.h create mode 100644 src/layer/vulkan/shader/convolution1d.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack1to4.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack1to8.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack4.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack4to1.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack4to8.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack8.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack8to1.comp create mode 100644 src/layer/vulkan/shader/convolution1d_pack8to4.comp diff --git a/src/layer/vulkan/convolution1d_vulkan.cpp b/src/layer/vulkan/convolution1d_vulkan.cpp new file mode 100644 index 000000000000..f9d9135ceb90 --- /dev/null +++ b/src/layer/vulkan/convolution1d_vulkan.cpp @@ -0,0 +1,423 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convolution1d_vulkan.h" + +#include "layer_shader_type.h" +#include "layer_type.h" + +namespace ncnn { + +Convolution1D_vulkan::Convolution1D_vulkan() +{ + support_vulkan = true; + support_image_storage = true; + + padding = 0; + + pipeline_convolution1d = 0; +} + +int Convolution1D_vulkan::create_pipeline(const Option& _opt) +{ + if (dynamic_weight) + { + support_vulkan = false; + support_image_storage = false; + return 0; + } + + Option opt = _opt; + + const int maxk = kernel_w; + int num_input = weight_data_size / maxk / num_output; + + int elempack = opt.use_shader_pack8 && num_input % 8 == 0 ? 8 : num_input % 4 == 0 ? 4 : 1; + int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + + { + padding = ncnn::create_layer(ncnn::LayerType::Padding); + padding->vkdev = vkdev; + + ncnn::ParamDict pd; + pd.set(0, 0); + pd.set(1, 0); + pd.set(2, pad_left); + pd.set(3, pad_right); + pd.set(4, 0); + pd.set(5, pad_value); + + padding->load_param(pd); + + padding->create_pipeline(opt); + } + + { + Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output); + + weight_data_packed.create(maxk, num_input / elempack, num_output / out_elempack, (size_t)4 * elempack * out_elempack, elempack * out_elempack); + + for (int q = 0; q + (out_elempack - 1) < num_output; q += out_elempack) + { + float* g00 = weight_data_packed.channel(q / out_elempack); + + for (int p = 0; p + (elempack - 1) < num_input; p += elempack) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < out_elempack; i++) + { + const Mat k0 = weight_data_r2.channel(q + i); + + for (int j = 0; j < elempack; j++) + { + const float* k00 = k0.row(p + j); + g00[0] = k00[k]; + g00++; + } + } + } + } + } + } + + if (bias_term) + { + convert_packing(bias_data, bias_data_packed, out_elempack, opt); + } + + { + std::vector specializations(7 + 10); + specializations[0].i = kernel_w; + specializations[1].i = dilation_w; + specializations[2].i = stride_w; + specializations[3].i = bias_term; + specializations[4].i = activation_type; + specializations[5].f = activation_params.w >= 1 ? activation_params[0] : 0.f; + specializations[6].f = activation_params.w == 2 ? activation_params[1] : 0.f; + specializations[7 + 0].i = 0; + specializations[7 + 1].i = 0; + specializations[7 + 2].i = 0; + specializations[7 + 3].i = 0; + specializations[7 + 4].i = 0; + specializations[7 + 5].i = 0; + specializations[7 + 6].i = 0; + specializations[7 + 7].i = 0; + specializations[7 + 8].i = 0; + specializations[7 + 9].i = 0; + + int shader_type_index = -1; + if (elempack == 1 && out_elempack == 1) shader_type_index = LayerShaderType::convolution1d; + if (elempack == 4 && out_elempack == 4) shader_type_index = LayerShaderType::convolution1d_pack4; + if (elempack == 1 && out_elempack == 4) shader_type_index = LayerShaderType::convolution1d_pack1to4; + if (elempack == 4 && out_elempack == 1) shader_type_index = LayerShaderType::convolution1d_pack4to1; + if (elempack == 8 && out_elempack == 8) shader_type_index = LayerShaderType::convolution1d_pack8; + if (elempack == 1 && out_elempack == 8) shader_type_index = LayerShaderType::convolution1d_pack1to8; + if (elempack == 8 && out_elempack == 1) shader_type_index = LayerShaderType::convolution1d_pack8to1; + if (elempack == 4 && out_elempack == 8) shader_type_index = LayerShaderType::convolution1d_pack4to8; + if (elempack == 8 && out_elempack == 4) shader_type_index = LayerShaderType::convolution1d_pack8to4; + + pipeline_convolution1d = new Pipeline(vkdev); + pipeline_convolution1d->set_optimal_local_size_xyz(1, 1, 1); + pipeline_convolution1d->create(shader_type_index, opt, specializations); + } + + return 0; +} + +int Convolution1D_vulkan::destroy_pipeline(const Option& opt) +{ + if (padding) + { + padding->destroy_pipeline(opt); + delete padding; + padding = 0; + } + + delete pipeline_convolution1d; + pipeline_convolution1d = 0; + + return 0; +} + +int Convolution1D_vulkan::upload_model(VkTransfer& cmd, const Option& opt) +{ + if (padding) + { + padding->upload_model(cmd, opt); + } + + if (support_image_storage && opt.use_image_storage) + { + cmd.record_upload(weight_data_packed, weight_data_gpu_image, opt); + } + else + { + cmd.record_upload(weight_data_packed, weight_data_gpu, opt); + } + + weight_data_packed.release(); + + if (bias_term) + { + if (support_image_storage && opt.use_image_storage) + { + cmd.record_upload(bias_data_packed, bias_data_gpu_image, opt); + } + else + { + cmd.record_upload(bias_data_packed, bias_data_gpu, opt); + } + + bias_data_packed.release(); + } + + return 0; +} + +int Convolution1D_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + + VkMat bottom_blob_bordered = bottom_blob; + if (pad_left > 0 || pad_right > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + padding->forward(bottom_blob, bottom_blob_bordered, cmd, opt_pad); + } + else if (pad_left == -233 && pad_right == -233) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad / 2; + padding_params[3] = wpad - wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + else if (pad_left == -234 && pad_right == -234) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad - wpad / 2; + padding_params[3] = wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + + int outw = (bottom_blob_bordered.w - kernel_extent_w) / stride_w + 1; + + int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (out_elempack == 8) out_elemsize = 8 * 2u; + if (out_elempack == 4) out_elemsize = 4 * 2u; + if (out_elempack == 1) out_elemsize = 4u; + } + + top_blob.create(outw, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator); + if (top_blob.empty()) + return -100; + + std::vector bindings(4); + bindings[0] = bottom_blob_bordered; + bindings[1] = top_blob; + bindings[2] = weight_data_gpu; + bindings[3] = bias_data_gpu; + + std::vector constants(10); + constants[0].i = bottom_blob_bordered.dims; + constants[1].i = bottom_blob_bordered.w; + constants[2].i = bottom_blob_bordered.h; + constants[3].i = bottom_blob_bordered.c; + constants[4].i = bottom_blob_bordered.cstep; + constants[5].i = top_blob.dims; + constants[6].i = top_blob.w; + constants[7].i = top_blob.h; + constants[8].i = top_blob.c; + constants[9].i = top_blob.cstep; + + VkMat dispatcher; + dispatcher.w = (top_blob.w + 1) / 2; + dispatcher.h = (top_blob.h + 1) / 2; + dispatcher.c = (top_blob.c + 1) / 2; + + cmd.record_pipeline(pipeline_convolution1d, bindings, constants, dispatcher); + + return 0; +} + +int Convolution1D_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; + + VkImageMat bottom_blob_bordered = bottom_blob; + if (pad_left > 0 || pad_right > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + padding->forward(bottom_blob, bottom_blob_bordered, cmd, opt_pad); + } + else if (pad_left == -233 && pad_right == -233) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkImageMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad / 2; + padding_params[3] = wpad - wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + else if (pad_left == -234 && pad_right == -234) + { + int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w; + if (wpad > 0) + { + Option opt_pad = opt; + opt_pad.blob_vkallocator = opt.workspace_vkallocator; + + VkImageMat padding_param_blob(6, (size_t)4u, 1, opt.staging_vkallocator); + int* padding_params = padding_param_blob.mapped(); + + padding_params[0] = 0; + padding_params[1] = 0; + padding_params[2] = wpad - wpad / 2; + padding_params[3] = wpad / 2; + padding_params[4] = 0; + padding_params[5] = 0; + + std::vector padding_inputs(2); + padding_inputs[0] = bottom_blob; + padding_inputs[1] = padding_param_blob; + + std::vector padding_outputs(1); + padding->forward(padding_inputs, padding_outputs, cmd, opt_pad); + bottom_blob_bordered = padding_outputs[0]; + } + } + + int outw = (bottom_blob_bordered.w - kernel_extent_w) / stride_w + 1; + + int out_elempack = opt.use_shader_pack8 && num_output % 8 == 0 ? 8 : num_output % 4 == 0 ? 4 : 1; + + size_t out_elemsize = elemsize / elempack * out_elempack; + + if (opt.use_fp16_packed && !opt.use_fp16_storage) + { + if (out_elempack == 8) out_elemsize = 8 * 2u; + if (out_elempack == 4) out_elemsize = 4 * 2u; + if (out_elempack == 1) out_elemsize = 4u; + } + + top_blob.create(outw, num_output / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator); + if (top_blob.empty()) + return -100; + + std::vector bindings(4); + bindings[0] = bottom_blob_bordered; + bindings[1] = top_blob; + bindings[2] = weight_data_gpu_image; + bindings[3] = bias_data_gpu_image; + + std::vector constants(10); + constants[0].i = bottom_blob_bordered.dims; + constants[1].i = bottom_blob_bordered.w; + constants[2].i = bottom_blob_bordered.h; + constants[3].i = bottom_blob_bordered.c; + constants[4].i = 0; //bottom_blob_bordered.cstep; + constants[5].i = top_blob.dims; + constants[6].i = top_blob.w; + constants[7].i = top_blob.h; + constants[8].i = top_blob.c; + constants[9].i = 0; //top_blob.cstep; + + VkImageMat dispatcher; + dispatcher.w = (top_blob.w + 1) / 2; + dispatcher.h = (top_blob.h + 1) / 2; + dispatcher.c = (top_blob.c + 1) / 2; + + cmd.record_pipeline(pipeline_convolution1d, bindings, constants, dispatcher); + + return 0; +} + +} // namespace ncnn \ No newline at end of file diff --git a/src/layer/vulkan/convolution1d_vulkan.h b/src/layer/vulkan/convolution1d_vulkan.h new file mode 100644 index 000000000000..4fb22040daa2 --- /dev/null +++ b/src/layer/vulkan/convolution1d_vulkan.h @@ -0,0 +1,53 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_CONVOLUTION1D_VULKAN_H +#define LAYER_CONVOLUTION1D_VULKAN_H + +#include "convolution1d.h" + +namespace ncnn { + +class Convolution1D_vulkan : virtual public Convolution1D +{ +public: + Convolution1D_vulkan(); + + virtual int create_pipeline(const Option& opt); + virtual int destroy_pipeline(const Option& opt); + + virtual int upload_model(VkTransfer& cmd, const Option& opt); + + using Convolution1D::forward; + virtual int forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, const Option& opt) const; + virtual int forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, VkCompute& cmd, const Option& opt) const; + +public: + ncnn::Layer* padding; + + Mat weight_data_packed; + Mat bias_data_packed; + + VkMat weight_data_gpu; + VkMat bias_data_gpu; + + VkImageMat weight_data_gpu_image; + VkImageMat bias_data_gpu_image; + + Pipeline* pipeline_convolution1d; +}; + +} // namespace ncnn + +#endif // LAYER_CONVOLUTION1D_VULKAN_H diff --git a/src/layer/vulkan/shader/convolution1d.comp b/src/layer/vulkan/shader/convolution1d.comp new file mode 100644 index 000000000000..3403a50f1939 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d.comp @@ -0,0 +1,177 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfp weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfp bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afp sum0 = afp(0.0f); + afp sum1 = afp(0.0f); + afp sum2 = afp(0.0f); + afp sum3 = afp(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld1(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld1(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld1(bias_data, gy2.x); + sum2 = buffer_ld1(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afp v0 = image3d_ld1(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afp v1 = image3d_ld1(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afp k0 = image3d_ld1(weight_blob, ivec3(wx, y, gy2.x)); + afp k1 = image3d_ld1(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afp v0 = buffer_ld1(bottom_blob_data, v_offset.x + x * dilation_w); + afp v1 = buffer_ld1(bottom_blob_data, v_offset.y + x * dilation_w); + + afp k0 = buffer_ld1(weight_data, w_offset.x + x); + afp k1 = buffer_ld1(weight_data, w_offset.y + x); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afp(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afp(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afp(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afp(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st1(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st1(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st1(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st1(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st1(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st1(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack1to4.comp b/src/layer/vulkan/shader/convolution1d_pack1to4.comp new file mode 100644 index 000000000000..98e6fadd3c18 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack1to4.comp @@ -0,0 +1,177 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec4 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec4 sum0 = afpvec4(0.0f); + afpvec4 sum1 = afpvec4(0.0f); + afpvec4 sum2 = afpvec4(0.0f); + afpvec4 sum3 = afpvec4(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld4(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld4(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld4(bias_data, gy2.x); + sum2 = buffer_ld4(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afp v0 = image3d_ld1(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afp v1 = image3d_ld1(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec4 k0 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.x)); + afpvec4 k1 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afp v0 = buffer_ld1(bottom_blob_data, v_offset.x + x * dilation_w); + afp v1 = buffer_ld1(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec4 k0 = buffer_ld4(weight_data, w_offset.x + x); + afpvec4 k1 = buffer_ld4(weight_data, w_offset.y + x); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st4(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st4(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st4(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st4(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st4(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st4(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack1to8.comp b/src/layer/vulkan/shader/convolution1d_pack1to8.comp new file mode 100644 index 000000000000..c32bc2114e58 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack1to8.comp @@ -0,0 +1,186 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec8 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec8 sum0 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum1 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum2 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum3 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld8(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld8(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld8(bias_data, gy2.x); + sum2 = buffer_ld8(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afp v0 = image3d_ld1(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afp v1 = image3d_ld1(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.y)); + + sum0[0] += v0 * k0[0]; + sum0[1] += v0 * k0[1]; + sum1[0] += v1 * k0[0]; + sum1[1] += v1 * k0[1]; + sum2[0] += v0 * k1[0]; + sum2[1] += v0 * k1[1]; + sum3[0] += v1 * k1[0]; + sum3[1] += v1 * k1[1]; + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afp v0 = buffer_ld1(bottom_blob_data, v_offset.x + x * dilation_w); + afp v1 = buffer_ld1(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, w_offset.x + x); + afpvec8 k1 = buffer_ld8(weight_data, w_offset.y + x); + + sum0[0] += v0 * k0[0]; + sum0[1] += v0 * k0[1]; + sum1[0] += v1 * k0[0]; + sum1[1] += v1 * k0[1]; + sum2[0] += v0 * k1[0]; + sum2[1] += v0 * k1[1]; + sum3[0] += v1 * k1[0]; + sum3[1] += v1 * k1[1]; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec8(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec8(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec8(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec8(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st8(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st8(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st8(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st8(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st8(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st8(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack4.comp b/src/layer/vulkan/shader/convolution1d_pack4.comp new file mode 100644 index 000000000000..f1e125867852 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack4.comp @@ -0,0 +1,208 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +#if NCNN_fp16_packed || (NCNN_fp16_storage && !NCNN_fp16_arithmetic) +// GL_EXT_shader_16bit_storage does not define f16mat4 type :( +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +#else +layout (binding = 2) readonly buffer weight_blob { sfpmat4 weight_data[]; }; +#endif +layout (binding = 3) readonly buffer bias_blob { sfpvec4 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec4 sum0 = afpvec4(0.0f); + afpvec4 sum1 = afpvec4(0.0f); + afpvec4 sum2 = afpvec4(0.0f); + afpvec4 sum3 = afpvec4(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld4(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld4(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld4(bias_data, gy2.x); + sum2 = buffer_ld4(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpmat4 k0 = afpmat4( + image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.x)), + image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.x)), + image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.x)), + image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.x)) + ); + afpmat4 k1 = afpmat4( + image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.y)), + image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.y)), + image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.y)), + image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.y)) + ); + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + + wx += 4; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = buffer_ld4(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec4 v1 = buffer_ld4(bottom_blob_data, v_offset.y + x * dilation_w); + +#if NCNN_fp16_packed || (NCNN_fp16_storage && !NCNN_fp16_arithmetic) + // GL_EXT_shader_16bit_storage does not define f16mat4 type :( + afpmat4 k0 = afpmat4( + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 0), + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 1), + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 2), + buffer_ld4(weight_data, (w_offset.x + x) * 4 + 3) + ); + afpmat4 k1 = afpmat4( + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 0), + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 1), + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 2), + buffer_ld4(weight_data, (w_offset.y + x) * 4 + 3) + ); +#else + afpmat4 k0 = sfp2afpmat4(weight_data[w_offset.x + x]); + afpmat4 k1 = sfp2afpmat4(weight_data[w_offset.y + x]); +#endif + + sum0 += v0 * k0; + sum1 += v1 * k0; + sum2 += v0 * k1; + sum3 += v1 * k1; + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st4(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st4(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st4(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st4(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st4(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st4(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack4to1.comp b/src/layer/vulkan/shader/convolution1d_pack4to1.comp new file mode 100644 index 000000000000..1f5c87e1835b --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack4to1.comp @@ -0,0 +1,177 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfp bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afp sum0 = afp(0.0f); + afp sum1 = afp(0.0f); + afp sum2 = afp(0.0f); + afp sum3 = afp(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld1(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld1(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld1(bias_data, gy2.x); + sum2 = buffer_ld1(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec4 k0 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.x)); + afpvec4 k1 = image3d_ld4(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += dot(v0, k0); + sum1 += dot(v1, k0); + sum2 += dot(v0, k1); + sum3 += dot(v1, k1); + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = buffer_ld4(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec4 v1 = buffer_ld4(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec4 k0 = buffer_ld4(weight_data, w_offset.x + x); + afpvec4 k1 = buffer_ld4(weight_data, w_offset.y + x); + + sum0 += dot(v0, k0); + sum1 += dot(v1, k0); + sum2 += dot(v0, k1); + sum3 += dot(v1, k1); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afp(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afp(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afp(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afp(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st1(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st1(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st1(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st1(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st1(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st1(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack4to8.comp b/src/layer/vulkan/shader/convolution1d_pack4to8.comp new file mode 100644 index 000000000000..1133b097ac09 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack4to8.comp @@ -0,0 +1,270 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec4 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec4 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec8 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec8 sum0 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum1 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum2 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum3 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld8(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld8(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld8(bias_data, gy2.x); + sum2 = buffer_ld8(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = image3d_ld4(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec4 v1 = image3d_ld4(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec4 k0 = image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.x)); + afpvec4 k1 = image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.x)); + afpvec4 k2 = image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.x)); + afpvec4 k3 = image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.x)); + afpvec4 k4 = image3d_ld4(weight_blob, ivec3(wx + 4, y, gy2.x)); + afpvec4 k5 = image3d_ld4(weight_blob, ivec3(wx + 5, y, gy2.x)); + afpvec4 k6 = image3d_ld4(weight_blob, ivec3(wx + 6, y, gy2.x)); + afpvec4 k7 = image3d_ld4(weight_blob, ivec3(wx + 7, y, gy2.x)); + + afpvec4 k8 = image3d_ld4(weight_blob, ivec3(wx + 0, y, gy2.y)); + afpvec4 k9 = image3d_ld4(weight_blob, ivec3(wx + 1, y, gy2.y)); + afpvec4 ka = image3d_ld4(weight_blob, ivec3(wx + 2, y, gy2.y)); + afpvec4 kb = image3d_ld4(weight_blob, ivec3(wx + 3, y, gy2.y)); + afpvec4 kc = image3d_ld4(weight_blob, ivec3(wx + 4, y, gy2.y)); + afpvec4 kd = image3d_ld4(weight_blob, ivec3(wx + 5, y, gy2.y)); + afpvec4 ke = image3d_ld4(weight_blob, ivec3(wx + 6, y, gy2.y)); + afpvec4 kf = image3d_ld4(weight_blob, ivec3(wx + 7, y, gy2.y)); + + sum0[0].r += dot(v0, k0); + sum0[0].g += dot(v0, k1); + sum0[0].b += dot(v0, k2); + sum0[0].a += dot(v0, k3); + sum0[1].r += dot(v0, k4); + sum0[1].g += dot(v0, k5); + sum0[1].b += dot(v0, k6); + sum0[1].a += dot(v0, k7); + + sum1[0].r += dot(v1, k0); + sum1[0].g += dot(v1, k1); + sum1[0].b += dot(v1, k2); + sum1[0].a += dot(v1, k3); + sum1[1].r += dot(v1, k4); + sum1[1].g += dot(v1, k5); + sum1[1].b += dot(v1, k6); + sum1[1].a += dot(v1, k7); + + sum2[0].r += dot(v0, k8); + sum2[0].g += dot(v0, k9); + sum2[0].b += dot(v0, ka); + sum2[0].a += dot(v0, kb); + sum2[1].r += dot(v0, kc); + sum2[1].g += dot(v0, kd); + sum2[1].b += dot(v0, ke); + sum2[1].a += dot(v0, kf); + + sum3[0].r += dot(v1, k8); + sum3[0].g += dot(v1, k9); + sum3[0].b += dot(v1, ka); + sum3[0].a += dot(v1, kb); + sum3[1].r += dot(v1, kc); + sum3[1].g += dot(v1, kd); + sum3[1].b += dot(v1, ke); + sum3[1].a += dot(v1, kf); + + wx += 8; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec4 v0 = buffer_ld4(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec4 v1 = buffer_ld4(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec4 k0 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 0); + afpvec4 k1 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 1); + afpvec4 k2 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 2); + afpvec4 k3 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 3); + afpvec4 k4 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 4); + afpvec4 k5 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 5); + afpvec4 k6 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 6); + afpvec4 k7 = buffer_ld4(weight_data, (w_offset.x + x) * 8 + 7); + + afpvec4 k8 = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 0); + afpvec4 k9 = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 1); + afpvec4 ka = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 2); + afpvec4 kb = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 3); + afpvec4 kc = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 4); + afpvec4 kd = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 5); + afpvec4 ke = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 6); + afpvec4 kf = buffer_ld4(weight_data, (w_offset.y + x) * 8 + 7); + + sum0[0].r += dot(v0, k0); + sum0[0].g += dot(v0, k1); + sum0[0].b += dot(v0, k2); + sum0[0].a += dot(v0, k3); + sum0[1].r += dot(v0, k4); + sum0[1].g += dot(v0, k5); + sum0[1].b += dot(v0, k6); + sum0[1].a += dot(v0, k7); + + sum1[0].r += dot(v1, k0); + sum1[0].g += dot(v1, k1); + sum1[0].b += dot(v1, k2); + sum1[0].a += dot(v1, k3); + sum1[1].r += dot(v1, k4); + sum1[1].g += dot(v1, k5); + sum1[1].b += dot(v1, k6); + sum1[1].a += dot(v1, k7); + + sum2[0].r += dot(v0, k8); + sum2[0].g += dot(v0, k9); + sum2[0].b += dot(v0, ka); + sum2[0].a += dot(v0, kb); + sum2[1].r += dot(v0, kc); + sum2[1].g += dot(v0, kd); + sum2[1].b += dot(v0, ke); + sum2[1].a += dot(v0, kf); + + sum3[0].r += dot(v1, k8); + sum3[0].g += dot(v1, k9); + sum3[0].b += dot(v1, ka); + sum3[0].a += dot(v1, kb); + sum3[1].r += dot(v1, kc); + sum3[1].g += dot(v1, kd); + sum3[1].b += dot(v1, ke); + sum3[1].a += dot(v1, kf); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec8(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec8(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec8(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec8(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st8(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st8(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st8(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st8(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st8(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st8(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack8.comp b/src/layer/vulkan/shader/convolution1d_pack8.comp new file mode 100644 index 000000000000..fff72ade8290 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack8.comp @@ -0,0 +1,270 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec8 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec8 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec8 sum0 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum1 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum2 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + afpvec8 sum3 = afpvec8(afpvec4(0.0f), afpvec4(0.0f)); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld8(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld8(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld8(bias_data, gy2.x); + sum2 = buffer_ld8(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = image3d_ld8(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec8 v1 = image3d_ld8(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.x)); + afpvec8 k2 = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.x)); + afpvec8 k3 = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.x)); + afpvec8 k4 = image3d_ld8(weight_blob, ivec3(wx + 4, y, gy2.x)); + afpvec8 k5 = image3d_ld8(weight_blob, ivec3(wx + 5, y, gy2.x)); + afpvec8 k6 = image3d_ld8(weight_blob, ivec3(wx + 6, y, gy2.x)); + afpvec8 k7 = image3d_ld8(weight_blob, ivec3(wx + 7, y, gy2.x)); + + afpvec8 k8 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.y)); + afpvec8 k9 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.y)); + afpvec8 ka = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.y)); + afpvec8 kb = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.y)); + afpvec8 kc = image3d_ld8(weight_blob, ivec3(wx + 4, y, gy2.y)); + afpvec8 kd = image3d_ld8(weight_blob, ivec3(wx + 5, y, gy2.y)); + afpvec8 ke = image3d_ld8(weight_blob, ivec3(wx + 6, y, gy2.y)); + afpvec8 kf = image3d_ld8(weight_blob, ivec3(wx + 7, y, gy2.y)); + + sum0[0].r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0[0].g += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum0[0].b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0[0].a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + sum0[1].r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum0[1].g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum0[1].b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum0[1].a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum1[0].r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1[0].g += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + sum1[0].b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1[0].a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + sum1[1].r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum1[1].g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum1[1].b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum1[1].a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + + sum2[0].r += dot(v0[0], k8[0]) + dot(v0[1], k8[1]); + sum2[0].g += dot(v0[0], k9[0]) + dot(v0[1], k9[1]); + sum2[0].b += dot(v0[0], ka[0]) + dot(v0[1], ka[1]); + sum2[0].a += dot(v0[0], kb[0]) + dot(v0[1], kb[1]); + sum2[1].r += dot(v0[0], kc[0]) + dot(v0[1], kc[1]); + sum2[1].g += dot(v0[0], kd[0]) + dot(v0[1], kd[1]); + sum2[1].b += dot(v0[0], ke[0]) + dot(v0[1], ke[1]); + sum2[1].a += dot(v0[0], kf[0]) + dot(v0[1], kf[1]); + + sum3[0].r += dot(v1[0], k8[0]) + dot(v1[1], k8[1]); + sum3[0].g += dot(v1[0], k9[0]) + dot(v1[1], k9[1]); + sum3[0].b += dot(v1[0], ka[0]) + dot(v1[1], ka[1]); + sum3[0].a += dot(v1[0], kb[0]) + dot(v1[1], kb[1]); + sum3[1].r += dot(v1[0], kc[0]) + dot(v1[1], kc[1]); + sum3[1].g += dot(v1[0], kd[0]) + dot(v1[1], kd[1]); + sum3[1].b += dot(v1[0], ke[0]) + dot(v1[1], ke[1]); + sum3[1].a += dot(v1[0], kf[0]) + dot(v1[1], kf[1]); + + wx += 8; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = buffer_ld8(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec8 v1 = buffer_ld8(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 0); + afpvec8 k1 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 1); + afpvec8 k2 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 2); + afpvec8 k3 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 3); + afpvec8 k4 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 4); + afpvec8 k5 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 5); + afpvec8 k6 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 6); + afpvec8 k7 = buffer_ld8(weight_data, (w_offset.x + x) * 8 + 7); + + afpvec8 k8 = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 0); + afpvec8 k9 = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 1); + afpvec8 ka = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 2); + afpvec8 kb = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 3); + afpvec8 kc = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 4); + afpvec8 kd = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 5); + afpvec8 ke = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 6); + afpvec8 kf = buffer_ld8(weight_data, (w_offset.y + x) * 8 + 7); + + sum0[0].r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0[0].g += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum0[0].b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0[0].a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + sum0[1].r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum0[1].g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum0[1].b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum0[1].a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum1[0].r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1[0].g += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + sum1[0].b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1[0].a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + sum1[1].r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum1[1].g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum1[1].b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum1[1].a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + + sum2[0].r += dot(v0[0], k8[0]) + dot(v0[1], k8[1]); + sum2[0].g += dot(v0[0], k9[0]) + dot(v0[1], k9[1]); + sum2[0].b += dot(v0[0], ka[0]) + dot(v0[1], ka[1]); + sum2[0].a += dot(v0[0], kb[0]) + dot(v0[1], kb[1]); + sum2[1].r += dot(v0[0], kc[0]) + dot(v0[1], kc[1]); + sum2[1].g += dot(v0[0], kd[0]) + dot(v0[1], kd[1]); + sum2[1].b += dot(v0[0], ke[0]) + dot(v0[1], ke[1]); + sum2[1].a += dot(v0[0], kf[0]) + dot(v0[1], kf[1]); + + sum3[0].r += dot(v1[0], k8[0]) + dot(v1[1], k8[1]); + sum3[0].g += dot(v1[0], k9[0]) + dot(v1[1], k9[1]); + sum3[0].b += dot(v1[0], ka[0]) + dot(v1[1], ka[1]); + sum3[0].a += dot(v1[0], kb[0]) + dot(v1[1], kb[1]); + sum3[1].r += dot(v1[0], kc[0]) + dot(v1[1], kc[1]); + sum3[1].g += dot(v1[0], kd[0]) + dot(v1[1], kd[1]); + sum3[1].b += dot(v1[0], ke[0]) + dot(v1[1], ke[1]); + sum3[1].a += dot(v1[0], kf[0]) + dot(v1[1], kf[1]); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec8(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec8(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec8(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec8(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st8(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st8(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st8(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st8(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st8(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st8(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st8(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack8to1.comp b/src/layer/vulkan/shader/convolution1d_pack8to1.comp new file mode 100644 index 000000000000..9d08d3b11af7 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack8to1.comp @@ -0,0 +1,178 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec8 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfp bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afp sum0 = afp(0.0f); + afp sum1 = afp(0.0f); + afp sum2 = afp(0.0f); + afp sum3 = afp(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld1(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld1(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld1(bias_data, gy2.x); + sum2 = buffer_ld1(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = image3d_ld8(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec8 v1 = image3d_ld8(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx, y, gy2.y)); + + sum0 += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum1 += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum2 += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum3 += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + + wx += 1; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = buffer_ld8(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec8 v1 = buffer_ld8(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, w_offset.x + x); + afpvec8 k1 = buffer_ld8(weight_data, w_offset.y + x); + + sum0 += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum1 += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum2 += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum3 += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afp(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afp(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afp(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afp(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st1(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st1(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st1(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st1(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st1(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st1(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st1(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file diff --git a/src/layer/vulkan/shader/convolution1d_pack8to4.comp b/src/layer/vulkan/shader/convolution1d_pack8to4.comp new file mode 100644 index 000000000000..86ca696d5840 --- /dev/null +++ b/src/layer/vulkan/shader/convolution1d_pack8to4.comp @@ -0,0 +1,220 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#version 450 + +#if NCNN_fp16_storage +#extension GL_EXT_shader_16bit_storage: require +struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; }; +#endif +#if NCNN_fp16_arithmetic +#extension GL_EXT_shader_explicit_arithmetic_types_float16: require +#endif + +#extension GL_GOOGLE_include_directive: enable +#include "vulkan_activation.comp" + +layout (constant_id = 0) const int kernel_w = 1; +layout (constant_id = 1) const int dilation_w = 1; +layout (constant_id = 2) const int stride_w = 1; +layout (constant_id = 3) const int bias_term = 0; +layout (constant_id = 4) const int activation_type = 0; +layout (constant_id = 5) const float activation_param_0 = 0; +layout (constant_id = 6) const float activation_param_1 = 0; + +#define shape_constant_id_offset 7 +layout (constant_id = shape_constant_id_offset + 0) const int dims = 0; +layout (constant_id = shape_constant_id_offset + 1) const int w = 0; +layout (constant_id = shape_constant_id_offset + 2) const int h = 0; +layout (constant_id = shape_constant_id_offset + 3) const int c = 0; +layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0; + +layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0; +layout (constant_id = shape_constant_id_offset + 6) const int outw = 0; +layout (constant_id = shape_constant_id_offset + 7) const int outh = 0; +layout (constant_id = shape_constant_id_offset + 8) const int outc = 0; +layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0; + +#if NCNN_image_shader +layout (binding = 0) uniform unfp sampler3D bottom_blob; +layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob; +layout (binding = 2) uniform unfp sampler3D weight_blob; +layout (binding = 3) uniform unfp sampler3D bias_blob; +#else +layout (binding = 0) readonly buffer bottom_blob { sfpvec8 bottom_blob_data[]; }; +layout (binding = 1) writeonly buffer top_blob { sfpvec4 top_blob_data[]; }; +layout (binding = 2) readonly buffer weight_blob { sfpvec8 weight_data[]; }; +layout (binding = 3) readonly buffer bias_blob { sfpvec4 bias_data[]; }; +#endif + +layout (push_constant) uniform parameter +{ + int dims; + int w; + int h; + int c; + int cstep; + + int outdims; + int outw; + int outh; + int outc; + int outcstep; +} p; + +void main() +{ + int gx = int(gl_GlobalInvocationID.x) * 2; + int gy = int(gl_GlobalInvocationID.y) * 2; + + if (gx >= psc(outw) || gy >= psc(outh)) + return; + + const ivec2 gx2 = gx + ivec2(0, 1); + const ivec2 gy2 = gy + ivec2(0, 1); + + afpvec4 sum0 = afpvec4(0.0f); + afpvec4 sum1 = afpvec4(0.0f); + afpvec4 sum2 = afpvec4(0.0f); + afpvec4 sum3 = afpvec4(0.0f); + + if (bias_term == 1) + { +#if NCNN_image_shader + sum0 = image3d_ld4(bias_blob, ivec3(gy2.x, 0, 0)); + sum2 = image3d_ld4(bias_blob, ivec3(gy2.y, 0, 0)); +#else + sum0 = buffer_ld4(bias_data, gy2.x); + sum2 = buffer_ld4(bias_data, gy2.y); +#endif + sum1 = sum0; + sum3 = sum2; + } + +#if NCNN_image_shader + + ivec2 v_offset = gx2 * stride_w; + + for (int y = 0; y < psc(h); y++) + { + int wx = 0; + + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = image3d_ld8(bottom_blob, ivec3(v_offset.x + x * dilation_w, y, 0)); + afpvec8 v1 = image3d_ld8(bottom_blob, ivec3(v_offset.y + x * dilation_w, y, 0)); + + afpvec8 k0 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.x)); + afpvec8 k1 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.x)); + afpvec8 k2 = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.x)); + afpvec8 k3 = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.x)); + afpvec8 k4 = image3d_ld8(weight_blob, ivec3(wx + 0, y, gy2.y)); + afpvec8 k5 = image3d_ld8(weight_blob, ivec3(wx + 1, y, gy2.y)); + afpvec8 k6 = image3d_ld8(weight_blob, ivec3(wx + 2, y, gy2.y)); + afpvec8 k7 = image3d_ld8(weight_blob, ivec3(wx + 3, y, gy2.y)); + + sum0.r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0.g += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum0.b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0.a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + + sum1.r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1.g += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + sum1.b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1.a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + + sum2.r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum2.g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum2.b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum2.a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum3.r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum3.g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum3.b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum3.a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + + wx += 4; + } + } + +#else + + ivec2 v_offset = gx2 * stride_w; + ivec2 w_offset = gy2 * psc(h) * kernel_w; + + for (int y = 0; y < psc(h); y++) + { + for (int x = 0; x < kernel_w; x++) + { + afpvec8 v0 = buffer_ld8(bottom_blob_data, v_offset.x + x * dilation_w); + afpvec8 v1 = buffer_ld8(bottom_blob_data, v_offset.y + x * dilation_w); + + afpvec8 k0 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 0); + afpvec8 k1 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 1); + afpvec8 k2 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 2); + afpvec8 k3 = buffer_ld8(weight_data, (w_offset.x + x) * 4 + 3); + afpvec8 k4 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 0); + afpvec8 k5 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 1); + afpvec8 k6 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 2); + afpvec8 k7 = buffer_ld8(weight_data, (w_offset.y + x) * 4 + 3); + + sum0.r += dot(v0[0], k0[0]) + dot(v0[1], k0[1]); + sum0.g += dot(v0[0], k1[0]) + dot(v0[1], k1[1]); + sum0.b += dot(v0[0], k2[0]) + dot(v0[1], k2[1]); + sum0.a += dot(v0[0], k3[0]) + dot(v0[1], k3[1]); + + sum1.r += dot(v1[0], k0[0]) + dot(v1[1], k0[1]); + sum1.g += dot(v1[0], k1[0]) + dot(v1[1], k1[1]); + sum1.b += dot(v1[0], k2[0]) + dot(v1[1], k2[1]); + sum1.a += dot(v1[0], k3[0]) + dot(v1[1], k3[1]); + + sum2.r += dot(v0[0], k4[0]) + dot(v0[1], k4[1]); + sum2.g += dot(v0[0], k5[0]) + dot(v0[1], k5[1]); + sum2.b += dot(v0[0], k6[0]) + dot(v0[1], k6[1]); + sum2.a += dot(v0[0], k7[0]) + dot(v0[1], k7[1]); + + sum3.r += dot(v1[0], k4[0]) + dot(v1[1], k4[1]); + sum3.g += dot(v1[0], k5[0]) + dot(v1[1], k5[1]); + sum3.b += dot(v1[0], k6[0]) + dot(v1[1], k6[1]); + sum3.a += dot(v1[0], k7[0]) + dot(v1[1], k7[1]); + } + v_offset += psc(w); + w_offset += kernel_w; + } + +#endif + + sum0 = activation_afpvec4(sum0, activation_type, activation_param_0, activation_param_1); + sum1 = activation_afpvec4(sum1, activation_type, activation_param_0, activation_param_1); + sum2 = activation_afpvec4(sum2, activation_type, activation_param_0, activation_param_1); + sum3 = activation_afpvec4(sum3, activation_type, activation_param_0, activation_param_1); + +#if NCNN_image_shader + + image3d_st4(top_blob, ivec3(gx2.x, gy2.x, 0), sum0); + image3d_st4(top_blob, ivec3(gx2.y, gy2.x, 0), sum1); + image3d_st4(top_blob, ivec3(gx2.x, gy2.y, 0), sum2); + image3d_st4(top_blob, ivec3(gx2.y, gy2.y, 0), sum3); + +#else + + const int gi = gy * psc(outw) + gx; + + buffer_st4(top_blob_data, gi, sum0); + if (gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + 1, sum1); + if (gy + 1 < psc(outh)) buffer_st4(top_blob_data, gi + psc(outw), sum2); + if (gy + 1 < psc(outh) && gx + 1 < psc(outw)) buffer_st4(top_blob_data, gi + psc(outw) + 1, sum3); + +#endif +} \ No newline at end of file From 3f437d3f3d1e0d306ea105ae06f29b11c7bec83c Mon Sep 17 00:00:00 2001 From: Yoh Date: Fri, 20 Oct 2023 19:09:59 +0800 Subject: [PATCH 16/16] Grid sample op (#4373) * pnnx support grid_sample op * complete the permute and gridsample operator fusion * spilt calculation into two stages and support permute fusion --- docs/developer-guide/operators.md | 1 + src/layer/gridsample.cpp | 321 ++++++--- src/layer/gridsample.h | 18 +- .../gridsample_bicubic_apply_interpolation.h | 288 ++++++++ .../x86/gridsample_bicubic_compute_blob.h | 299 +++++++++ .../gridsample_bilinear_apply_interpolation.h | 373 +++++++++++ .../x86/gridsample_bilinear_compute_blob.h | 623 ++++++++++++++++++ src/layer/x86/gridsample_compute_blob.h | 145 ++++ .../gridsample_nearest_apply_interpolation.h | 126 ++++ .../x86/gridsample_nearest_compute_blob.h | 315 +++++++++ src/layer/x86/gridsample_x86.cpp | 455 +++++++++++++ src/layer/x86/gridsample_x86.h | 32 + src/layer/x86/x86_usability.h | 192 ++++++ tests/test_gridsample.cpp | 189 ++++-- tools/modelwriter.h | 1 + tools/pnnx/src/pass_ncnn/F_grid_sample.cpp | 105 ++- tools/pnnx/tests/ncnn/test_F_grid_sample.py | 92 +-- 17 files changed, 3374 insertions(+), 201 deletions(-) create mode 100644 src/layer/x86/gridsample_bicubic_apply_interpolation.h create mode 100644 src/layer/x86/gridsample_bicubic_compute_blob.h create mode 100644 src/layer/x86/gridsample_bilinear_apply_interpolation.h create mode 100644 src/layer/x86/gridsample_bilinear_compute_blob.h create mode 100644 src/layer/x86/gridsample_compute_blob.h create mode 100644 src/layer/x86/gridsample_nearest_apply_interpolation.h create mode 100644 src/layer/x86/gridsample_nearest_compute_blob.h create mode 100644 src/layer/x86/gridsample_x86.cpp create mode 100644 src/layer/x86/gridsample_x86.h diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index b0975ee46877..4cabb049340f 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -914,6 +914,7 @@ This function is often used in conjunction with affine_grid() to build Spatial T | 0 | sample_type | int | 1 | | | 1 | padding_mode | int | 1 | | | 2 | align_corner | int | 0 | | +| 3 | permute_fusion| int | 0 | fuse with permute | Sample type: diff --git a/src/layer/gridsample.cpp b/src/layer/gridsample.cpp index 31405047ec13..e8579cf4aefe 100644 --- a/src/layer/gridsample.cpp +++ b/src/layer/gridsample.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // coord compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,6 @@ // specific language governing permissions and limitations under the License. #include "gridsample.h" - #include namespace ncnn { @@ -29,6 +28,7 @@ int GridSample::load_param(const ParamDict& pd) sample_type = pd.get(0, 1); padding_mode = pd.get(1, 1); align_corner = pd.get(2, 0); + permute_fusion = pd.get(3, 0); if (sample_type < 1 || sample_type > 3) { @@ -59,19 +59,19 @@ static float grid_sample_unormalize(int w, float coordx, int align_corner) return align_corner ? (coordx + 1) / 2.f * (w - 1) : ((coordx + 1) * w - 1) / 2.f; } -static float border_coord(int x, int border) +static float border_coord(float x, float border) { - return std::min(border, std::max(x, 0)); + return std::min(border, std::max(x, 0.0f)); } static float reflect_coord(float x, int high) { - x = abs(x); - x = high - abs(x - high); + x = fabs(x); + x = high - fabs(x - high); return x; } -static int compute_coord(int sx, int w, int padding_mode, int align_corner) +static float compute_coord(float sx, int w, int padding_mode, int align_corner) { if (padding_mode == 2) // border { @@ -85,7 +85,7 @@ static int compute_coord(int sx, int w, int padding_mode, int align_corner) } else { - sx = static_cast(reflect_coord(sx + 0.5f, w) - 0.5f); + sx = reflect_coord(sx + 0.5, w) - 0.5; sx = border_coord(sx, w - 1); } } @@ -110,7 +110,7 @@ static float get_value_bounded(const Mat& image, int x, int y) static float get_value_bounded(const Mat& image, int x, int y, int z) { - return in_bounds(image, x, y, z) ? image.channel(z).row(y)[x] : 0.f; + return in_bounds(image, x, y, z) ? image.depth(z).row(y)[x] : 0.f; } static float get_value_bounded(const Mat& image, int x, int y, int padding_mode, int align_corner) @@ -121,15 +121,6 @@ static float get_value_bounded(const Mat& image, int x, int y, int padding_mode, return get_value_bounded(image, x, y); } -static float get_value_bounded(const Mat& image, int x, int y, int z, int padding_mode, int align_corner) -{ - x = compute_coord(x, image.w, padding_mode, align_corner); - y = compute_coord(y, image.h, padding_mode, align_corner); - z = compute_coord(z, image.c, padding_mode, align_corner); - - return get_value_bounded(image, x, y, z); -} - static inline void interpolate_cubic(float fx, float* coeffs) { const float A = -0.75f; @@ -160,45 +151,102 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& if (dims == 3) { - int outw = grid.h; - int outh = grid.c; + int outw = permute_fusion == 0 ? grid.h : grid.w; + int outh = permute_fusion == 0 ? grid.c : grid.h; top_blob.create(outw, outh, channels, elemsize, opt.blob_allocator); - if (top_blob.empty()) + + Mat offset_blob; + offset_blob.create(outw, outh, grid.c, elemsize, opt.workspace_allocator); + + if (top_blob.empty() || offset_blob.empty()) return -100; - if (sample_type == 1) // bilinear + //pre-calculate all interpolation offsets for each x y, unpack grid on-the-fly + if (permute_fusion == 0) + { + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + + for (int y = 0; y < outh; y++) + { + const float* gridptr = grid.channel(y); + for (int x = 0; x < outw; x++) + { + float sample_x = gridptr[0]; + float sample_y = gridptr[1]; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + + gridptr += 2; + offsetptr_x++; + offsetptr_y++; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + + for (int y = 0; y < outh; y++) + { + for (int x = 0; x < outw; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + + gridptr_x++; + gridptr_y++; + offsetptr_x++; + offsetptr_y++; + } + } + } + + if (sample_type == Interpolation_BILINEAR) // bilinear { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); for (int y = 0; y < outh; y++) { - const float* gridptr = grid.channel(y); - for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; // bilinear interpolate float v; { - int x0 = (int)floor(sample_x); - int y0 = (int)floor(sample_y); + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); + int x0 = floor(sample_x); + int y0 = floor(sample_y); int x1 = x0 + 1; int y1 = y0 + 1; - float v00 = get_value_bounded(image, x0, y0, padding_mode, align_corner); - float v01 = get_value_bounded(image, x1, y0, padding_mode, align_corner); - float v10 = get_value_bounded(image, x0, y1, padding_mode, align_corner); - float v11 = get_value_bounded(image, x1, y1, padding_mode, align_corner); + float v00 = get_value_bounded(image, x0, y0); + float v01 = get_value_bounded(image, x1, y0); + float v10 = get_value_bounded(image, x0, y1); + float v11 = get_value_bounded(image, x1, y1); float alpha = sample_x - x0; float beta = sample_y - y0; @@ -212,63 +260,61 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& outptr[0] = v; outptr += 1; - gridptr += 2; + offsetptr_x++; + offsetptr_y++; } } } } - else if (sample_type == 2) // nearest + else if (sample_type == Interpolation_NEAREST) // nearest { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); for (int y = 0; y < outh; y++) { - const float* gridptr = grid.channel(y); - for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); - int x0 = static_cast(round(sample_x)); - int y0 = static_cast(round(sample_y)); + int x0 = static_cast(floor(sample_x + 0.5f)); + int y0 = static_cast(floor(sample_y + 0.5f)); - float v = get_value_bounded(image, x0, y0, padding_mode, align_corner); + float v = get_value_bounded(image, x0, y0); outptr[0] = v; outptr += 1; - gridptr += 2; + offsetptr_x++; + offsetptr_y++; } } } } - else if (sample_type == 3) // bicubic + else if (sample_type == Interpolation_BICUBIC) // bicubic { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); for (int y = 0; y < outh; y++) { - const float* gridptr = grid.channel(y); - for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; // bicubic interpolate float v; @@ -315,7 +361,8 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& outptr[0] = v; outptr += 1; - gridptr += 2; + offsetptr_x++; + offsetptr_y++; } } } @@ -324,37 +371,120 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& if (dims == 4) { - int outw = grid.h; - int outh = grid.d; - int outd = grid.c; + int outw = permute_fusion == 0 ? grid.h : grid.w; + int outh = permute_fusion == 0 ? grid.d : grid.h; + int outd = permute_fusion == 0 ? grid.c : grid.d; top_blob.create(outw, outh, outd, channels, elemsize, opt.blob_allocator); - if (top_blob.empty()) + + Mat offset_blob; + offset_blob.create(outw, outh, outd, grid.c, elemsize, opt.workspace_allocator); + + if (top_blob.empty() || offset_blob.empty()) return -100; - if (sample_type == 1) // bilinear + //pre-calculate all interpolation offsets for each x y, unpack grid on-the-fly + if (permute_fusion == 0) + { + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + float* offsetptr_z = offset_blob.channel(2); + + for (int z = 0; z < outd; z++) + { + const float* gridptr = grid.channel(z); + for (int y = 0; y < outh; y++) + { + for (int x = 0; x < outw; x++) + { + float sample_x = gridptr[0]; + float sample_y = gridptr[1]; + float sample_z = gridptr[2]; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); + + sample_z = grid_sample_unormalize(d, sample_z, align_corner); + sample_z = compute_coord(sample_z, d, padding_mode, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + *offsetptr_z = sample_z; + + gridptr += 3; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; + } + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + const float* gridptr_z = grid.channel(2); + float* offsetptr_x = offset_blob.channel(0); + float* offsetptr_y = offset_blob.channel(1); + float* offsetptr_z = offset_blob.channel(2); + + for (int z = 0; z < outd; z++) + { + for (int y = 0; y < outh; y++) + { + for (int x = 0; x < outw; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + float sample_z = *gridptr_z; + + sample_x = grid_sample_unormalize(w, sample_x, align_corner); + sample_x = compute_coord(sample_x, w, padding_mode, align_corner); + + sample_y = grid_sample_unormalize(h, sample_y, align_corner); + sample_y = compute_coord(sample_y, h, padding_mode, align_corner); + + sample_z = grid_sample_unormalize(d, sample_z, align_corner); + sample_z = compute_coord(sample_z, d, padding_mode, align_corner); + + *offsetptr_x = sample_x; + *offsetptr_y = sample_y; + *offsetptr_z = sample_z; + + gridptr_x++; + gridptr_y++; + gridptr_z++; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; + } + } + } + } + + if (sample_type == Interpolation_BILINEAR) // bilinear { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); + const float* offsetptr_z = offset_blob.channel(2); for (int z = 0; z < outd; z++) { - const float* gridptr = grid.channel(z); - for (int y = 0; y < outh; y++) { for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - float sample_z = gridptr[2]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); - sample_z = grid_sample_unormalize(d, sample_z, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; + float sample_z = *offsetptr_z; // bilinear interpolate float v; @@ -366,14 +496,14 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& int y1 = y0 + 1; int z1 = z0 + 1; - float v000 = get_value_bounded(image, x0, y0, z0, padding_mode, align_corner); - float v001 = get_value_bounded(image, x1, y0, z0, padding_mode, align_corner); - float v010 = get_value_bounded(image, x0, y1, z0, padding_mode, align_corner); - float v011 = get_value_bounded(image, x1, y1, z0, padding_mode, align_corner); - float v100 = get_value_bounded(image, x0, y0, z1, padding_mode, align_corner); - float v101 = get_value_bounded(image, x1, y0, z1, padding_mode, align_corner); - float v110 = get_value_bounded(image, x0, y1, z1, padding_mode, align_corner); - float v111 = get_value_bounded(image, x1, y1, z1, padding_mode, align_corner); + float v000 = get_value_bounded(image, x0, y0, z0); + float v001 = get_value_bounded(image, x1, y0, z0); + float v010 = get_value_bounded(image, x0, y1, z0); + float v011 = get_value_bounded(image, x1, y1, z0); + float v100 = get_value_bounded(image, x0, y0, z1); + float v101 = get_value_bounded(image, x1, y0, z1); + float v110 = get_value_bounded(image, x0, y1, z1); + float v111 = get_value_bounded(image, x1, y1, z1); float alpha = sample_x - x0; float beta = sample_y - y0; @@ -393,46 +523,47 @@ int GridSample::forward(const std::vector& bottom_blobs, std::vector& outptr[0] = v; outptr += 1; - gridptr += 3; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; } } } } } - else if (sample_type == 2) // nearest + else if (sample_type == Interpolation_NEAREST) // nearest { #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const Mat image = bottom_blob.channel(q); float* outptr = top_blob.channel(q); + const float* offsetptr_x = offset_blob.channel(0); + const float* offsetptr_y = offset_blob.channel(1); + const float* offsetptr_z = offset_blob.channel(2); for (int z = 0; z < outd; z++) { - const float* gridptr = grid.channel(z); - for (int y = 0; y < outh; y++) { for (int x = 0; x < outw; x++) { - float sample_x = gridptr[0]; - float sample_y = gridptr[1]; - float sample_z = gridptr[2]; - - sample_x = grid_sample_unormalize(w, sample_x, align_corner); - sample_y = grid_sample_unormalize(h, sample_y, align_corner); - sample_z = grid_sample_unormalize(d, sample_z, align_corner); + float sample_x = *offsetptr_x; + float sample_y = *offsetptr_y; + float sample_z = *offsetptr_z; - int x0 = static_cast(round(sample_x)); - int y0 = static_cast(round(sample_y)); - int z0 = static_cast(round(sample_z)); + int x0 = static_cast(floor(sample_x + 0.5f)); + int y0 = static_cast(floor(sample_y + 0.5f)); + int z0 = static_cast(floor(sample_z + 0.5f)); - float v = get_value_bounded(image, x0, y0, z0, padding_mode, align_corner); + float v = get_value_bounded(image, x0, y0, z0); outptr[0] = v; outptr += 1; - gridptr += 3; + offsetptr_x++; + offsetptr_y++; + offsetptr_z++; } } } diff --git a/src/layer/gridsample.h b/src/layer/gridsample.h index 0ea540eb4baf..f6e17c9d2f49 100644 --- a/src/layer/gridsample.h +++ b/src/layer/gridsample.h @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -28,11 +28,27 @@ class GridSample : public Layer virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + enum InterpolationMode // 1=bilinear 2=nearest 3=bicubic + { + Interpolation_BILINEAR = 1, + Interpolation_NEAREST = 2, + Interpolation_BICUBIC = 3 + }; + + enum PaddingMode // 1=zeros 2=border 3=reflection + { + Padding_ZEROS = 1, + Padding_BORDER = 2, + Padding_REFLECTION = 3 + }; + public: // param int sample_type; // 1=bilinear 2=nearest 3=bicubic int padding_mode; // 1=zeros 2=border 3=reflection int align_corner; + + int permute_fusion; }; } // namespace ncnn diff --git a/src/layer/x86/gridsample_bicubic_apply_interpolation.h b/src/layer/x86/gridsample_bicubic_apply_interpolation.h new file mode 100644 index 000000000000..0b7be771d3b7 --- /dev/null +++ b/src/layer/x86/gridsample_bicubic_apply_interpolation.h @@ -0,0 +1,288 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ +static void cubic_interp1d_p16(__m512& coeffs0, __m512& coeffs1, __m512& coeffs2, __m512& coeffs3, const __m512& tx) +{ + const __m512 A = _mm512_set1_ps(-0.75f); + + const __m512 x0 = _mm512_add_ps(tx, _mm512_set1_ps(1.0f)); + const __m512& x1 = tx; + const __m512 x2 = _mm512_sub_ps(_mm512_set1_ps(1.0f), tx); + + coeffs0 = _mm512_sub_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_mul_ps(A, x0), _mm512_mul_ps(_mm512_set1_ps(5.0f), A)), x0), _mm512_mul_ps(_mm512_set1_ps(8.0f), A)), x0), _mm512_mul_ps(_mm512_set1_ps(4), A)); + coeffs1 = _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_mul_ps(_mm512_add_ps(A, _mm512_set1_ps(2.0f)), x1), _mm512_add_ps(A, _mm512_set1_ps(3.0f))), x1), x1), _mm512_set1_ps(1.0f)); + coeffs2 = _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_mul_ps(_mm512_add_ps(A, _mm512_set1_ps(2.0f)), x2), _mm512_add_ps(A, _mm512_set1_ps(3.0f))), x2), x2), _mm512_set1_ps(1.0f)); + coeffs3 = _mm512_sub_ps(_mm512_sub_ps(_mm512_sub_ps(_mm512_set1_ps(1.0f), coeffs0), coeffs1), coeffs2); +} + +static void gridsample_2d_bicubic_apply_interpolation_p16(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m512 x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + __m512 y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + __m512 value_f[4]; + cubic_interp1d_p16(x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3, _mm512_set1_ps(offset_value_ptr[0])); + cubic_interp1d_p16(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, _mm512_set1_ps(offset_value_ptr[1])); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + __m512 x0_val = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + __m512 x1_val = offset_ptr[1] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[1]) : _mm512_set1_ps(0); + __m512 x2_val = offset_ptr[2] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[2]) : _mm512_set1_ps(0); + __m512 x3_val = offset_ptr[3] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[3]) : _mm512_set1_ps(0); + + value_f[ii] = _mm512_mul_ps(x_coeffs0, x0_val); + value_f[ii] = _mm512_fmadd_ps(x_coeffs1, x1_val, value_f[ii]); + value_f[ii] = _mm512_fmadd_ps(x_coeffs2, x2_val, value_f[ii]); + value_f[ii] = _mm512_fmadd_ps(x_coeffs3, x3_val, value_f[ii]); + + offset_ptr += 4; + } + + __m512 _v = _mm512_mul_ps(y_coeffs0, value_f[0]); + _v = _mm512_fmadd_ps(y_coeffs1, value_f[1], _v); + _v = _mm512_fmadd_ps(y_coeffs2, value_f[2], _v); + _v = _mm512_fmadd_ps(y_coeffs3, value_f[3], _v); + _mm512_storeu_ps(dstptr, _v); + + dstptr += 16; + offset_value_ptr += 18; + } + } +} + +#endif // __AVX512F__ +static void cubic_interp1d_p8(__m256& coeffs0, __m256& coeffs1, __m256& coeffs2, __m256& coeffs3, const __m256& tx) +{ + const __m256 A = _mm256_set1_ps(-0.75f); + + const __m256 x0 = _mm256_add_ps(tx, _mm256_set1_ps(1)); + const __m256& x1 = tx; + const __m256 x2 = _mm256_sub_ps(_mm256_set1_ps(1), tx); + //const __m256 x3 = _mm256_add_ps(x2, _mm256_set1_ps(1)); + + coeffs0 = _mm256_sub_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_mul_ps(A, x0), _mm256_mul_ps(_mm256_set1_ps(5.0f), A)), x0), _mm256_mul_ps(_mm256_set1_ps(8.0f), A)), x0), _mm256_mul_ps(_mm256_set1_ps(4), A)); + coeffs1 = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_mul_ps(_mm256_add_ps(A, _mm256_set1_ps(2.0f)), x1), _mm256_add_ps(A, _mm256_set1_ps(3.0f))), x1), x1), _mm256_set1_ps(1)); + coeffs2 = _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_mul_ps(_mm256_add_ps(A, _mm256_set1_ps(2.0f)), x2), _mm256_add_ps(A, _mm256_set1_ps(3.0f))), x2), x2), _mm256_set1_ps(1)); + coeffs3 = _mm256_sub_ps(_mm256_sub_ps(_mm256_sub_ps(_mm256_set1_ps(1), coeffs0), coeffs1), coeffs2); +} + +static void gridsample_2d_bicubic_apply_interpolation_p8(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m256 x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + __m256 y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + __m256 value_f[4]; + cubic_interp1d_p8(x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3, _mm256_set1_ps(offset_value_ptr[0])); + cubic_interp1d_p8(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, _mm256_set1_ps(offset_value_ptr[1])); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + __m256 x0_val = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + __m256 x1_val = offset_ptr[1] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[1]) : _mm256_set1_ps(0); + __m256 x2_val = offset_ptr[2] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[2]) : _mm256_set1_ps(0); + __m256 x3_val = offset_ptr[3] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[3]) : _mm256_set1_ps(0); + + value_f[ii] = _mm256_mul_ps(x_coeffs0, x0_val); + value_f[ii] = _mm256_comp_fmadd_ps(x_coeffs1, x1_val, value_f[ii]); + value_f[ii] = _mm256_comp_fmadd_ps(x_coeffs2, x2_val, value_f[ii]); + value_f[ii] = _mm256_comp_fmadd_ps(x_coeffs3, x3_val, value_f[ii]); + + offset_ptr += 4; + } + + __m256 _v = _mm256_mul_ps(y_coeffs0, value_f[0]); + _v = _mm256_comp_fmadd_ps(y_coeffs1, value_f[1], _v); + _v = _mm256_comp_fmadd_ps(y_coeffs2, value_f[2], _v); + _v = _mm256_comp_fmadd_ps(y_coeffs3, value_f[3], _v); + _mm256_storeu_ps(dstptr, _v); + + dstptr += 8; + offset_value_ptr += 18; + } + } +} + +#endif // __AVX__ +static void cubic_interp1d_p4(__m128& coeffs0, __m128& coeffs1, __m128& coeffs2, __m128& coeffs3, const __m128& tx) +{ + const __m128 A = _mm_set_ps1(-0.75f); + + const __m128 x0 = _mm_add_ps(tx, _mm_set_ps1(1.0f)); + const __m128& x1 = tx; + const __m128 x2 = _mm_sub_ps(_mm_set_ps1(1.0f), tx); + //const __m128 x3 = _mm_add_ps(x2, _mm_set_ps1(1.0f)); + + coeffs0 = _mm_sub_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_sub_ps(_mm_mul_ps(A, x0), _mm_mul_ps(_mm_set_ps1(5.0f), A)), x0), _mm_mul_ps(_mm_set_ps1(8.0f), A)), x0), _mm_mul_ps(_mm_set_ps1(4), A)); + coeffs1 = _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_mm_sub_ps(_mm_mul_ps(_mm_add_ps(A, _mm_set_ps1(2.0f)), x1), _mm_add_ps(A, _mm_set_ps1(3.0f))), x1), x1), _mm_set_ps1(1.0f)); + coeffs2 = _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_mm_sub_ps(_mm_mul_ps(_mm_add_ps(A, _mm_set_ps1(2.0f)), x2), _mm_add_ps(A, _mm_set_ps1(3.0f))), x2), x2), _mm_set_ps1(1.0f)); + coeffs3 = _mm_sub_ps(_mm_sub_ps(_mm_sub_ps(_mm_set_ps1(1.0f), coeffs0), coeffs1), coeffs2); +} + +static void gridsample_2d_bicubic_apply_interpolation_p4(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + __m128 x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + __m128 y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + __m128 value_f[4]; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + cubic_interp1d_p4(x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3, _mm_set_ps1(offset_value_ptr[0])); + cubic_interp1d_p4(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, _mm_set_ps1(offset_value_ptr[1])); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + __m128 x0_val = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + __m128 x1_val = offset_ptr[1] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[1]) : _mm_set1_ps(0); + __m128 x2_val = offset_ptr[2] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[2]) : _mm_set1_ps(0); + __m128 x3_val = offset_ptr[3] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[3]) : _mm_set1_ps(0); + + value_f[ii] = _mm_mul_ps(x_coeffs0, x0_val); + value_f[ii] = _mm_comp_fmadd_ps(x_coeffs1, x1_val, value_f[ii]); + value_f[ii] = _mm_comp_fmadd_ps(x_coeffs2, x2_val, value_f[ii]); + value_f[ii] = _mm_comp_fmadd_ps(x_coeffs3, x3_val, value_f[ii]); + + offset_ptr += 4; + } + + __m128 _v = _mm_mul_ps(y_coeffs0, value_f[0]); + _v = _mm_comp_fmadd_ps(y_coeffs1, value_f[1], _v); + _v = _mm_comp_fmadd_ps(y_coeffs2, value_f[2], _v); + _v = _mm_comp_fmadd_ps(y_coeffs3, value_f[3], _v); + _mm_storeu_ps(dstptr, _v); + + dstptr += 4; + offset_value_ptr += 18; + } + } +} + +#endif // __SSE2__ + +static inline void cubic_interp1d(float& coeffs0, float& coeffs1, float& coeffs2, float& coeffs3, float fx) +{ + const float A = -0.75f; + + float fx0 = fx + 1; + float fx1 = fx; + float fx2 = 1 - fx; + // float fx3 = 2 - fx; + + coeffs0 = A * fx0 * fx0 * fx0 - 5 * A * fx0 * fx0 + 8 * A * fx0 - 4 * A; + coeffs1 = (A + 2) * fx1 * fx1 * fx1 - (A + 3) * fx1 * fx1 + 1; + coeffs2 = (A + 2) * fx2 * fx2 * fx2 - (A + 3) * fx2 * fx2 + 1; + coeffs3 = 1.f - coeffs0 - coeffs1 - coeffs2; +} + +static void gridsample_2d_bicubic_apply_interpolation_p1(const Mat& src, Mat& dst, Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + float x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3; + float y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3; + float value_f[4]; + cubic_interp1d(x_coeffs0, x_coeffs1, x_coeffs2, x_coeffs3, offset_value_ptr[0]); + cubic_interp1d(y_coeffs0, y_coeffs1, y_coeffs2, y_coeffs3, offset_value_ptr[1]); + + const int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int ii = 0; ii < 4; ii++) + { + float x0_val = offset_ptr[0] >= 0 ? *(srcptr + offset_ptr[0]) : 0; + float x1_val = offset_ptr[1] >= 0 ? *(srcptr + offset_ptr[1]) : 0; + float x2_val = offset_ptr[2] >= 0 ? *(srcptr + offset_ptr[2]) : 0; + float x3_val = offset_ptr[3] >= 0 ? *(srcptr + offset_ptr[3]) : 0; + + value_f[ii] = x_coeffs0 * x0_val; + value_f[ii] = x_coeffs1 * x1_val + value_f[ii]; + value_f[ii] = x_coeffs2 * x2_val + value_f[ii]; + value_f[ii] = x_coeffs3 * x3_val + value_f[ii]; + + offset_ptr += 4; + } + + float _v = y_coeffs0 * value_f[0]; + _v = y_coeffs1 * value_f[1] + _v; + _v = y_coeffs2 * value_f[2] + _v; + _v = y_coeffs3 * value_f[3] + _v; + *dstptr = _v; + + dstptr++; + offset_value_ptr += 18; + } + } +} \ No newline at end of file diff --git a/src/layer/x86/gridsample_bicubic_compute_blob.h b/src/layer/x86/gridsample_bicubic_compute_blob.h new file mode 100644 index 000000000000..9006153d9d26 --- /dev/null +++ b/src/layer/x86/gridsample_bicubic_compute_blob.h @@ -0,0 +1,299 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +template +void gridsample_2d_bicubic_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h; + + float* offset_value_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 15 < grid_size; x += 16) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + + transpose2x8_ps(gx, gy); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gy = unormalize(_mm256_set1_ps(src.h), gy); + + __m256 gx_floor = _mm256_floor_ps(gx); + __m256 gy_floor = _mm256_floor_ps(gy); + + __m256 tx = _mm256_sub_ps(gx, gx_floor); + __m256 ty = _mm256_sub_ps(gy, gy_floor); + + __m256 gx0 = _mm256_add_ps(gx_floor, _mm256_set1_ps(-1)); + __m256 gx1 = gx_floor; + __m256 gx2 = _mm256_add_ps(gx_floor, _mm256_set1_ps(1)); + __m256 gx3 = _mm256_add_ps(gx2, _mm256_set1_ps(1)); + + gx0 = get_coord(_mm256_set1_ps(src.w), gx0); + gx1 = get_coord(_mm256_set1_ps(src.w), gx1); + gx2 = get_coord(_mm256_set1_ps(src.w), gx2); + gx3 = get_coord(_mm256_set1_ps(src.w), gx3); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(gx0, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx0, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(gx1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx1, _CMP_GT_OS)); + __m256 x2_in_range = _mm256_and_ps(_mm256_cmp_ps(gx2, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx2, _CMP_GT_OS)); + __m256 x3_in_range = _mm256_and_ps(_mm256_cmp_ps(gx3, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx3, _CMP_GT_OS)); + __m256 v0_offset_f[4], v1_offset_f[4], v2_offset_f[4], v3_offset_f[4]; + for (int i = 0; i < 4; i++) + { + gy = _mm256_add_ps(gy_floor, _mm256_set1_ps(-1.0f + i)); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 y_in_range = _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS)); + + __m256 gy_offset = _mm256_mul_ps(gy, _mm256_set1_ps(src.w)); + + v0_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx0), _mm256_set1_ps(src.elempack)); + v1_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx1), _mm256_set1_ps(src.elempack)); + v2_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx2), _mm256_set1_ps(src.elempack)); + v3_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx3), _mm256_set1_ps(src.elempack)); + + v0_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v0_offset_f[i], _mm256_and_ps(x0_in_range, y_in_range)); + v1_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v1_offset_f[i], _mm256_and_ps(x1_in_range, y_in_range)); + v2_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v2_offset_f[i], _mm256_and_ps(x2_in_range, y_in_range)); + v3_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v3_offset_f[i], _mm256_and_ps(x3_in_range, y_in_range)); + + v0_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v0_offset_f[i])); + v1_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v1_offset_f[i])); + v2_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v2_offset_f[i])); + v3_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v3_offset_f[i])); + } + + transpose8x18_ps(tx, ty, v0_offset_f[0], v1_offset_f[0], v2_offset_f[0], v3_offset_f[0], v0_offset_f[1], v1_offset_f[1], v2_offset_f[1], v3_offset_f[1], v0_offset_f[2], v1_offset_f[2], v2_offset_f[2], v3_offset_f[2], v0_offset_f[3], v1_offset_f[3], v2_offset_f[3], v3_offset_f[3]); + + _mm256_storeu_ps(offset_value_ptr, tx); + _mm256_storeu_ps(offset_value_ptr + 8, ty); + offset_value_ptr += 16; + for (int i = 0; i < 4; i++) + { + _mm256_storeu_ps(offset_value_ptr, v0_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 8, v1_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 16, v2_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 24, v3_offset_f[i]); + offset_value_ptr += 32; + } + gridptr += 16; + } + +#endif // __AVX__ + + for (; x < grid_size; x += 2) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + + sample_x = unormalize(src.w, sample_x); + sample_y = unormalize(src.h, sample_y); + + int x1 = floorf(sample_x); + int y1 = floorf(sample_y); + int x0 = x1 - 1; + int x2 = x1 + 1; + int x3 = x1 + 2; + + offset_value_ptr[0] = sample_x - static_cast(x1); + offset_value_ptr[1] = sample_y - static_cast(y1); + + x1 = get_coord(src.w, x1); + x0 = get_coord(src.w, x0); + x2 = get_coord(src.w, x2); + x3 = get_coord(src.w, x3); + + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool x2_in_range = (x2 > -1) & (x2 < src.w); + bool x3_in_range = (x3 > -1) & (x3 < src.w); + + int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int i = 0; i < 4; i++) + { + int gy = y1 + i - 1; + gy = get_coord(src.h, gy); + int offset_y = gy * src.w; + + bool y_in_range = (gy > -1) & (gy < src.h); + + bool v0_in_bound = (x0_in_range & y_in_range); + bool v1_in_bound = (x1_in_range & y_in_range); + bool v2_in_bound = (x2_in_range & y_in_range); + bool v3_in_bound = (x3_in_range & y_in_range); + + offset_ptr[0] = v0_in_bound ? (offset_y + x0) * src.elempack : -1.0; + offset_ptr[1] = v1_in_bound ? (offset_y + x1) * src.elempack : -1.0; + offset_ptr[2] = v2_in_bound ? (offset_y + x2) * src.elempack : -1.0; + offset_ptr[3] = v3_in_bound ? (offset_y + x3) * src.elempack : -1.0; + + offset_ptr += 4; + } + + gridptr += 2; + offset_value_ptr += 18; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gy = unormalize(_mm256_set1_ps(src.h), gy); + + __m256 gx_floor = _mm256_floor_ps(gx); + __m256 gy_floor = _mm256_floor_ps(gy); + + __m256 tx = _mm256_sub_ps(gx, gx_floor); + __m256 ty = _mm256_sub_ps(gy, gy_floor); + + __m256 gx0 = _mm256_add_ps(gx_floor, _mm256_set1_ps(-1)); + __m256 gx1 = gx_floor; + __m256 gx2 = _mm256_add_ps(gx_floor, _mm256_set1_ps(1)); + __m256 gx3 = _mm256_add_ps(gx2, _mm256_set1_ps(1)); + + gx0 = get_coord(_mm256_set1_ps(src.w), gx0); + gx1 = get_coord(_mm256_set1_ps(src.w), gx1); + gx2 = get_coord(_mm256_set1_ps(src.w), gx2); + gx3 = get_coord(_mm256_set1_ps(src.w), gx3); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(gx0, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx0, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(gx1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx1, _CMP_GT_OS)); + __m256 x2_in_range = _mm256_and_ps(_mm256_cmp_ps(gx2, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx2, _CMP_GT_OS)); + __m256 x3_in_range = _mm256_and_ps(_mm256_cmp_ps(gx3, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx3, _CMP_GT_OS)); + + __m256 v0_offset_f[4], v1_offset_f[4], v2_offset_f[4], v3_offset_f[4]; + for (int i = 0; i < 4; i++) + { + gy = _mm256_add_ps(gy_floor, _mm256_set1_ps(-1.0f + i)); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 y_in_range = _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS)); + + __m256 gy_offset = _mm256_mul_ps(gy, _mm256_set1_ps(src.w)); + + v0_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx0), _mm256_set1_ps(src.elempack)); + v1_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx1), _mm256_set1_ps(src.elempack)); + v2_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx2), _mm256_set1_ps(src.elempack)); + v3_offset_f[i] = _mm256_mul_ps(_mm256_add_ps(gy_offset, gx3), _mm256_set1_ps(src.elempack)); + + v0_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v0_offset_f[i], _mm256_and_ps(x0_in_range, y_in_range)); + v1_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v1_offset_f[i], _mm256_and_ps(x1_in_range, y_in_range)); + v2_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v2_offset_f[i], _mm256_and_ps(x2_in_range, y_in_range)); + v3_offset_f[i] = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), v3_offset_f[i], _mm256_and_ps(x3_in_range, y_in_range)); + + v0_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v0_offset_f[i])); + v1_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v1_offset_f[i])); + v2_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v2_offset_f[i])); + v3_offset_f[i] = _mm256_castsi256_ps(_mm256_cvtps_epi32(v3_offset_f[i])); + } + + transpose8x18_ps(tx, ty, v0_offset_f[0], v1_offset_f[0], v2_offset_f[0], v3_offset_f[0], v0_offset_f[1], v1_offset_f[1], v2_offset_f[1], v3_offset_f[1], v0_offset_f[2], v1_offset_f[2], v2_offset_f[2], v3_offset_f[2], v0_offset_f[3], v1_offset_f[3], v2_offset_f[3], v3_offset_f[3]); + + _mm256_storeu_ps(offset_value_ptr, tx); + _mm256_storeu_ps(offset_value_ptr + 8, ty); + offset_value_ptr += 16; + for (int i = 0; i < 4; i++) + { + _mm256_storeu_ps(offset_value_ptr, v0_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 8, v1_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 16, v2_offset_f[i]); + _mm256_storeu_ps(offset_value_ptr + 24, v3_offset_f[i]); + offset_value_ptr += 32; + } + + gridptr_x += 8; + gridptr_y += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = unormalize(src.w, sample_x); + sample_y = unormalize(src.h, sample_y); + + int x1 = floorf(sample_x); + int y1 = floorf(sample_y); + int x0 = x1 - 1; + int x2 = x1 + 1; + int x3 = x1 + 2; + + offset_value_ptr[0] = sample_x - static_cast(x1); + offset_value_ptr[1] = sample_y - static_cast(y1); + + x1 = get_coord(src.w, x1); + x0 = get_coord(src.w, x0); + x2 = get_coord(src.w, x2); + x3 = get_coord(src.w, x3); + + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool x2_in_range = (x2 > -1) & (x2 < src.w); + bool x3_in_range = (x3 > -1) & (x3 < src.w); + + int* offset_ptr = (int*)offset_value_ptr + 2; + + for (int i = 0; i < 4; i++) + { + int gy = y1 + i - 1; + gy = get_coord(src.h, gy); + int offset_y = gy * src.w; + + bool y_in_range = (gy > -1) & (gy < src.h); + + bool v0_in_bound = (x0_in_range & y_in_range); + bool v1_in_bound = (x1_in_range & y_in_range); + bool v2_in_bound = (x2_in_range & y_in_range); + bool v3_in_bound = (x3_in_range & y_in_range); + + offset_ptr[0] = v0_in_bound ? (offset_y + x0) * src.elempack : -1.0; + offset_ptr[1] = v1_in_bound ? (offset_y + x1) * src.elempack : -1.0; + offset_ptr[2] = v2_in_bound ? (offset_y + x2) * src.elempack : -1.0; + offset_ptr[3] = v3_in_bound ? (offset_y + x3) * src.elempack : -1.0; + + offset_ptr += 4; + } + + gridptr_x++; + gridptr_y++; + + offset_value_ptr += 18; + } + } +} diff --git a/src/layer/x86/gridsample_bilinear_apply_interpolation.h b/src/layer/x86/gridsample_bilinear_apply_interpolation.h new file mode 100644 index 000000000000..0af661b6ca0e --- /dev/null +++ b/src/layer/x86/gridsample_bilinear_apply_interpolation.h @@ -0,0 +1,373 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ +static void gridsample_2d_bilinear_apply_interpolation_p16(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + __m512 v00_val = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + __m512 v01_val = offset_ptr[1] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[1]) : _mm512_set1_ps(0); + __m512 v10_val = offset_ptr[2] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[2]) : _mm512_set1_ps(0); + __m512 v11_val = offset_ptr[3] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[3]) : _mm512_set1_ps(0); + + __m512 value1 = _mm512_set1_ps(value_ptr[0]); + __m512 v0 = _mm512_fmadd_ps(v01_val, value1, _mm512_fnmadd_ps(v00_val, value1, v00_val)); + __m512 v1 = _mm512_fmadd_ps(v11_val, value1, _mm512_fnmadd_ps(v10_val, value1, v10_val)); + + __m512 value2 = _mm512_set1_ps(value_ptr[1]); + __m512 _v = _mm512_fmadd_ps(v1, value2, _mm512_fnmadd_ps(v0, value2, v0)); + _mm512_storeu_ps(dstptr, _v); + + dstptr += 16; + offset_value_ptr += 6; + } + } +} + +static void gridsample_3d_bilinear_apply_interpolation_p16(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + __m512 v000_val = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + __m512 v001_val = offset_ptr[1] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[1]) : _mm512_set1_ps(0); + __m512 v010_val = offset_ptr[2] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[2]) : _mm512_set1_ps(0); + __m512 v011_val = offset_ptr[3] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[3]) : _mm512_set1_ps(0); + + __m512 v100_val = offset_ptr[4] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[4]) : _mm512_set1_ps(0); + __m512 v101_val = offset_ptr[5] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[5]) : _mm512_set1_ps(0); + __m512 v110_val = offset_ptr[6] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[6]) : _mm512_set1_ps(0); + __m512 v111_val = offset_ptr[7] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[7]) : _mm512_set1_ps(0); + + __m512 value = _mm512_set1_ps(value_ptr[0]); + __m512 v00 = _mm512_fmadd_ps(v001_val, value, _mm512_fnmadd_ps(v000_val, value, v000_val)); + __m512 v01 = _mm512_fmadd_ps(v011_val, value, _mm512_fnmadd_ps(v010_val, value, v010_val)); + __m512 v10 = _mm512_fmadd_ps(v101_val, value, _mm512_fnmadd_ps(v100_val, value, v100_val)); + __m512 v11 = _mm512_fmadd_ps(v111_val, value, _mm512_fnmadd_ps(v110_val, value, v110_val)); + + value = _mm512_set1_ps(value_ptr[1]); + __m512 v0 = _mm512_fmadd_ps(v01, value, _mm512_fnmadd_ps(v00, value, v00)); + __m512 v1 = _mm512_fmadd_ps(v11, value, _mm512_fnmadd_ps(v10, value, v10)); + + value = _mm512_set1_ps(value_ptr[2]); + __m512 _v = _mm512_fmadd_ps(v1, value, _mm512_fnmadd_ps(v0, value, v0)); + _mm512_storeu_ps(dstptr, _v); + + dstptr += 16; + offset_value_ptr += 11; + } + } +} + +#endif // __AVX512F__ + +static void gridsample_2d_bilinear_apply_interpolation_p8(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + __m256 v00_val = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + __m256 v01_val = offset_ptr[1] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[1]) : _mm256_set1_ps(0); + __m256 v10_val = offset_ptr[2] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[2]) : _mm256_set1_ps(0); + __m256 v11_val = offset_ptr[3] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[3]) : _mm256_set1_ps(0); + + __m256 value1 = _mm256_set1_ps(value_ptr[0]); + __m256 v0 = _mm256_comp_fmadd_ps(v01_val, value1, _mm256_comp_fnmadd_ps(v00_val, value1, v00_val)); + __m256 v1 = _mm256_comp_fmadd_ps(v11_val, value1, _mm256_comp_fnmadd_ps(v10_val, value1, v10_val)); + + __m256 value2 = _mm256_set1_ps(value_ptr[1]); + __m256 _v = _mm256_comp_fmadd_ps(v1, value2, _mm256_comp_fnmadd_ps(v0, value2, v0)); + _mm256_storeu_ps(dstptr, _v); + + dstptr += 8; + offset_value_ptr += 6; + } + } +} +static void gridsample_3d_bilinear_apply_interpolation_p8(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + __m256 v000_val = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + __m256 v001_val = offset_ptr[1] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[1]) : _mm256_set1_ps(0); + __m256 v010_val = offset_ptr[2] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[2]) : _mm256_set1_ps(0); + __m256 v011_val = offset_ptr[3] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[3]) : _mm256_set1_ps(0); + + __m256 v100_val = offset_ptr[4] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[4]) : _mm256_set1_ps(0); + __m256 v101_val = offset_ptr[5] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[5]) : _mm256_set1_ps(0); + __m256 v110_val = offset_ptr[6] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[6]) : _mm256_set1_ps(0); + __m256 v111_val = offset_ptr[7] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[7]) : _mm256_set1_ps(0); + + __m256 value = _mm256_set1_ps(value_ptr[0]); + __m256 v00 = _mm256_comp_fmadd_ps(v001_val, value, _mm256_comp_fnmadd_ps(v000_val, value, v000_val)); + __m256 v01 = _mm256_comp_fmadd_ps(v011_val, value, _mm256_comp_fnmadd_ps(v010_val, value, v010_val)); + __m256 v10 = _mm256_comp_fmadd_ps(v101_val, value, _mm256_comp_fnmadd_ps(v100_val, value, v100_val)); + __m256 v11 = _mm256_comp_fmadd_ps(v111_val, value, _mm256_comp_fnmadd_ps(v110_val, value, v110_val)); + + value = _mm256_set1_ps(value_ptr[1]); + __m256 v0 = _mm256_comp_fmadd_ps(v01, value, _mm256_comp_fnmadd_ps(v00, value, v00)); + __m256 v1 = _mm256_comp_fmadd_ps(v11, value, _mm256_comp_fnmadd_ps(v10, value, v10)); + + value = _mm256_set1_ps(value_ptr[2]); + __m256 _v = _mm256_comp_fmadd_ps(v1, value, _mm256_comp_fnmadd_ps(v0, value, v0)); + _mm256_storeu_ps(dstptr, _v); + + dstptr += 8; + offset_value_ptr += 11; + } + } +} +#endif // __AVX__ +static void gridsample_2d_bilinear_apply_interpolation_p4(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + __m128 v00_val = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + __m128 v01_val = offset_ptr[1] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[1]) : _mm_set1_ps(0); + __m128 v10_val = offset_ptr[2] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[2]) : _mm_set1_ps(0); + __m128 v11_val = offset_ptr[3] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[3]) : _mm_set1_ps(0); + + __m128 value1 = _mm_set1_ps(value_ptr[0]); + __m128 v0 = _mm_comp_fmadd_ps(v01_val, value1, _mm_comp_fnmadd_ps(v00_val, value1, v00_val)); + __m128 v1 = _mm_comp_fmadd_ps(v11_val, value1, _mm_comp_fnmadd_ps(v10_val, value1, v10_val)); + + __m128 value2 = _mm_set1_ps(value_ptr[1]); + __m128 _v = _mm_comp_fmadd_ps(v1, value2, _mm_comp_fnmadd_ps(v0, value2, v0)); + _mm_storeu_ps(dstptr, _v); + + dstptr += 4; + offset_value_ptr += 6; + } + } +} +static void gridsample_3d_bilinear_apply_interpolation_p4(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + __m128 v000_val = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + __m128 v001_val = offset_ptr[1] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[1]) : _mm_set1_ps(0); + __m128 v010_val = offset_ptr[2] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[2]) : _mm_set1_ps(0); + __m128 v011_val = offset_ptr[3] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[3]) : _mm_set1_ps(0); + + __m128 v100_val = offset_ptr[4] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[4]) : _mm_set1_ps(0); + __m128 v101_val = offset_ptr[5] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[5]) : _mm_set1_ps(0); + __m128 v110_val = offset_ptr[6] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[6]) : _mm_set1_ps(0); + __m128 v111_val = offset_ptr[7] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[7]) : _mm_set1_ps(0); + + __m128 value = _mm_set1_ps(value_ptr[0]); + __m128 v00 = _mm_comp_fmadd_ps(v001_val, value, _mm_comp_fnmadd_ps(v000_val, value, v000_val)); + __m128 v01 = _mm_comp_fmadd_ps(v011_val, value, _mm_comp_fnmadd_ps(v010_val, value, v010_val)); + __m128 v10 = _mm_comp_fmadd_ps(v101_val, value, _mm_comp_fnmadd_ps(v100_val, value, v100_val)); + __m128 v11 = _mm_comp_fmadd_ps(v111_val, value, _mm_comp_fnmadd_ps(v110_val, value, v110_val)); + + value = _mm_set1_ps(value_ptr[1]); + __m128 v0 = _mm_comp_fmadd_ps(v01, value, _mm_comp_fnmadd_ps(v00, value, v00)); + __m128 v1 = _mm_comp_fmadd_ps(v11, value, _mm_comp_fnmadd_ps(v10, value, v10)); + + value = _mm_set1_ps(value_ptr[2]); + __m128 _v = _mm_comp_fmadd_ps(v1, value, _mm_comp_fnmadd_ps(v0, value, v0)); + _mm_storeu_ps(dstptr, _v); + + dstptr += 4; + offset_value_ptr += 11; + } + } +} +#pragma fenv_access(off) + +#pragma float_control(precise, off) +#endif // __SSE2__ + +static void gridsample_2d_bilinear_apply_interpolation_p1(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int grid_size = outw * outh; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 4; + + float v00 = offset_ptr[0] >= 0 ? *(srcptr + offset_ptr[0]) : 0; + float v01 = offset_ptr[1] >= 0 ? *(srcptr + offset_ptr[1]) : 0; + float v10 = offset_ptr[2] >= 0 ? *(srcptr + offset_ptr[2]) : 0; + float v11 = offset_ptr[3] >= 0 ? *(srcptr + offset_ptr[3]) : 0; + + float v0 = v00 * (1 - value_ptr[0]) + v01 * value_ptr[0]; + float v1 = v10 * (1 - value_ptr[0]) + v11 * value_ptr[0]; + + *dstptr = v0 * (1 - value_ptr[1]) + v1 * value_ptr[1]; + + dstptr++; + offset_value_ptr += 6; + } + } +} + +static void gridsample_3d_bilinear_apply_interpolation_p1(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const float* offset_value_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + const int* offset_ptr = (int*)offset_value_ptr; + const float* value_ptr = offset_value_ptr + 8; + + float v000 = offset_ptr[0] >= 0 ? *(srcptr + offset_ptr[0]) : 0; + float v001 = offset_ptr[1] >= 0 ? *(srcptr + offset_ptr[1]) : 0; + float v010 = offset_ptr[2] >= 0 ? *(srcptr + offset_ptr[2]) : 0; + float v011 = offset_ptr[3] >= 0 ? *(srcptr + offset_ptr[3]) : 0; + + float v100 = offset_ptr[4] >= 0 ? *(srcptr + offset_ptr[4]) : 0; + float v101 = offset_ptr[5] >= 0 ? *(srcptr + offset_ptr[5]) : 0; + float v110 = offset_ptr[6] >= 0 ? *(srcptr + offset_ptr[6]) : 0; + float v111 = offset_ptr[7] >= 0 ? *(srcptr + offset_ptr[7]) : 0; + + float v00 = v000 * (1 - value_ptr[0]) + v001 * value_ptr[0]; + float v01 = v010 * (1 - value_ptr[0]) + v011 * value_ptr[0]; + float v10 = v100 * (1 - value_ptr[0]) + v101 * value_ptr[0]; + float v11 = v110 * (1 - value_ptr[0]) + v111 * value_ptr[0]; + + float v0 = v00 * (1 - value_ptr[1]) + v01 * value_ptr[1]; + float v1 = v10 * (1 - value_ptr[1]) + v11 * value_ptr[1]; + + *dstptr = v0 * (1 - value_ptr[2]) + v1 * value_ptr[2]; + + dstptr++; + offset_value_ptr += 11; + } + } +} \ No newline at end of file diff --git a/src/layer/x86/gridsample_bilinear_compute_blob.h b/src/layer/x86/gridsample_bilinear_compute_blob.h new file mode 100644 index 000000000000..f78e017c1c75 --- /dev/null +++ b/src/layer/x86/gridsample_bilinear_compute_blob.h @@ -0,0 +1,623 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +template +void gridsample_2d_bilinear_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h; + + float* offset_value_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 15 < grid_size; x += 16) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + + transpose2x8_ps(gx, gy); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, y1_in_range); + + __m256 nw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w), _mm256_set1_ps(src.elempack)); + __m256 ne_offset = _mm256_add_ps(nw_offset, _mm256_set1_ps(src.elempack)); + __m256 sw_offset = _mm256_comp_fmadd_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack), nw_offset); + __m256 se_offset = _mm256_add_ps(sw_offset, _mm256_set1_ps(src.elempack)); + + nw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), nw_offset, v00_in_range); + ne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), ne_offset, v01_in_range); + sw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), sw_offset, v10_in_range); + se_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), se_offset, v11_in_range); + + nw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(nw_offset)); + ne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(ne_offset)); + sw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(sw_offset)); + se_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(se_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + + transpose8x6_ps(nw_offset, ne_offset, sw_offset, se_offset, alpha, beta); + + _mm256_storeu_ps(offset_value_ptr, nw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, ne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, sw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, se_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, alpha); + _mm256_storeu_ps(offset_value_ptr + 40, beta); + + gridptr += 16; + offset_value_ptr += 48; + } +#endif // __AVX__ + + for (; x < grid_size; x += 2) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int x1 = x0 + 1; + int y1 = y0 + 1; + + bool x0_in_bound = (x0 > -1) & (x0 < src.w); + bool x1_in_bound = (x1 > -1) & (x1 < src.w); + bool y0_in_bound = (y0 > -1) & (y0 < src.h); + bool y1_in_bound = (y1 > -1) & (y1 < src.h); + + bool in_bound_00 = x0_in_bound & y0_in_bound; + bool in_bound_01 = x1_in_bound & y0_in_bound; + bool in_bound_10 = x0_in_bound & y1_in_bound; + bool in_bound_11 = x1_in_bound & y1_in_bound; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 4; + + offset_ptr[0] = in_bound_00 ? (x0 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[1] = in_bound_01 ? (x1 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[2] = in_bound_10 ? (x0 + y1 * src.w) * src.elempack : -1.0; + offset_ptr[3] = in_bound_11 ? (x1 + y1 * src.w) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + + gridptr += 2; + offset_value_ptr += 6; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, y1_in_range); + + __m256 nw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w), _mm256_set1_ps(src.elempack)); + __m256 ne_offset = _mm256_add_ps(nw_offset, _mm256_set1_ps(src.elempack)); + __m256 sw_offset = _mm256_comp_fmadd_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack), nw_offset); + __m256 se_offset = _mm256_add_ps(sw_offset, _mm256_set1_ps(src.elempack)); + + nw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), nw_offset, v00_in_range); + ne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), ne_offset, v01_in_range); + sw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), sw_offset, v10_in_range); + se_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), se_offset, v11_in_range); + + nw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(nw_offset)); + ne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(ne_offset)); + sw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(sw_offset)); + se_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(se_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + + transpose8x6_ps(nw_offset, ne_offset, sw_offset, se_offset, alpha, beta); + + _mm256_storeu_ps(offset_value_ptr, nw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, ne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, sw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, se_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, alpha); + _mm256_storeu_ps(offset_value_ptr + 40, beta); + + gridptr_x += 8; + gridptr_y += 8; + offset_value_ptr += 48; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int x1 = x0 + 1; + int y1 = y0 + 1; + + bool x0_in_bound = (x0 > -1) & (x0 < src.w); + bool x1_in_bound = (x1 > -1) & (x1 < src.w); + bool y0_in_bound = (y0 > -1) & (y0 < src.h); + bool y1_in_bound = (y1 > -1) & (y1 < src.h); + + bool in_bound_00 = x0_in_bound & y0_in_bound; + bool in_bound_01 = x1_in_bound & y0_in_bound; + bool in_bound_10 = x0_in_bound & y1_in_bound; + bool in_bound_11 = x1_in_bound & y1_in_bound; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 4; + + offset_ptr[0] = in_bound_00 ? (x0 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[1] = in_bound_01 ? (x1 + y0 * src.w) * src.elempack : -1.0; + offset_ptr[2] = in_bound_10 ? (x0 + y1 * src.w) * src.elempack : -1.0; + offset_ptr[3] = in_bound_11 ? (x1 + y1 * src.w) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + + gridptr_x++; + gridptr_y++; + offset_value_ptr += 6; + } + } +} + +template +void gridsample_3d_bilinear_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h * grid.d; + + float* offset_value_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 23 < grid_size; x += 24) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + __m256 gz = _mm256_loadu_ps(gridptr + 16); + + transpose3x8_ps(gx, gy, gz); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + __m256 z_t = _mm256_floor_ps(gz); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + __m256 z1 = _mm256_add_ps(z_t, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + __m256 z0_in_range = _mm256_and_ps(_mm256_cmp_ps(z_t, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z_t, _CMP_GT_OS)); + __m256 z1_in_range = _mm256_and_ps(_mm256_cmp_ps(z1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z1, _CMP_GT_OS)); + + __m256 v000_in_range, v010_in_range, v100_in_range, v110_in_range, v001_in_range, v011_in_range, v101_in_range, v111_in_range; + { + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, y1_in_range); + + v000_in_range = _mm256_and_ps(v00_in_range, z0_in_range); + v001_in_range = _mm256_and_ps(v01_in_range, z0_in_range); + v010_in_range = _mm256_and_ps(v10_in_range, z0_in_range); + v011_in_range = _mm256_and_ps(v11_in_range, z0_in_range); + + v100_in_range = _mm256_and_ps(v00_in_range, z1_in_range); + v101_in_range = _mm256_and_ps(v01_in_range, z1_in_range); + v110_in_range = _mm256_and_ps(v10_in_range, z1_in_range); + v111_in_range = _mm256_and_ps(v11_in_range, z1_in_range); + } + + __m256 tnw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), z_t, + _mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w)), + _mm256_set1_ps(src.elempack)); + __m256 tne_offset = _mm256_add_ps(tnw_offset, _mm256_set1_ps(src.elempack)); + __m256 tsw_offset = _mm256_add_ps(tnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 tse_offset = _mm256_add_ps(tsw_offset, _mm256_set1_ps(src.elempack)); + + __m256 bnw_offset = _mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), _mm256_set1_ps(src.elempack), tnw_offset); + __m256 bne_offset = _mm256_add_ps(bnw_offset, _mm256_set1_ps(src.elempack)); + __m256 bsw_offset = _mm256_add_ps(bnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 bse_offset = _mm256_add_ps(bsw_offset, _mm256_set1_ps(src.elempack)); + + tnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tnw_offset, v000_in_range); + tne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tne_offset, v001_in_range); + tsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tsw_offset, v010_in_range); + tse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tse_offset, v011_in_range); + + bnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bnw_offset, v100_in_range); + bne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bne_offset, v101_in_range); + bsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bsw_offset, v110_in_range); + bse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bse_offset, v111_in_range); + + tnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tnw_offset)); + tne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tne_offset)); + tsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tsw_offset)); + tse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tse_offset)); + + bnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bnw_offset)); + bne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bne_offset)); + bsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bsw_offset)); + bse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bse_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + __m256 gamma = _mm256_sub_ps(gz, z_t); + + transpose8x11_ps(tnw_offset, tne_offset, tsw_offset, tse_offset, bnw_offset, bne_offset, bsw_offset, bse_offset, alpha, beta, gamma); + + _mm256_storeu_ps(offset_value_ptr, tnw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, tne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, tsw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, tse_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, bnw_offset); + _mm256_storeu_ps(offset_value_ptr + 40, bne_offset); + _mm256_storeu_ps(offset_value_ptr + 48, bsw_offset); + _mm256_storeu_ps(offset_value_ptr + 56, bse_offset); + + _mm256_storeu_ps(offset_value_ptr + 64, alpha); + _mm256_storeu_ps(offset_value_ptr + 72, beta); + _mm256_storeu_ps(offset_value_ptr + 80, gamma); + + gridptr += 24; + + offset_value_ptr += 88; + } +#endif // __AVX__ + + for (; x < grid_size; x += 3) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + float sample_z = *(gridptr + 2); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int z0 = (int)floorf(sample_z); + int x1 = x0 + 1; + int y1 = y0 + 1; + int z1 = z0 + 1; + + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool y0_in_range = (y0 > -1) & (y0 < src.h); + bool z0_in_range = (z0 > -1) & (z0 < src.d); + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool y1_in_range = (y1 > -1) & (y1 < src.h); + bool z1_in_range = (z1 > -1) & (z1 < src.d); + + bool v00_in_range = x0_in_range & y0_in_range; + bool v01_in_range = x1_in_range & y0_in_range; + bool v10_in_range = x0_in_range & y1_in_range; + bool v11_in_range = x1_in_range & y1_in_range; + + bool in_bound_000 = v00_in_range & z0_in_range; + bool in_bound_001 = v01_in_range & z0_in_range; + bool in_bound_010 = v10_in_range & z0_in_range; + bool in_bound_011 = v11_in_range & z0_in_range; + + bool in_bound_100 = v00_in_range & z1_in_range; + bool in_bound_101 = v01_in_range & z1_in_range; + bool in_bound_110 = v10_in_range & z1_in_range; + bool in_bound_111 = v11_in_range & z1_in_range; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 8; + + offset_ptr[0] = in_bound_000 ? (x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[1] = in_bound_001 ? (x1 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[2] = in_bound_010 ? (x0 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[3] = in_bound_011 ? (x1 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + offset_ptr[4] = in_bound_100 ? (x0 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[5] = in_bound_101 ? (x1 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[6] = in_bound_110 ? (x0 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[7] = in_bound_111 ? (x1 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + value_ptr[2] = sample_z - z0; + + gridptr += 3; + offset_value_ptr += 11; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + const float* gridptr_z = grid.channel(2); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + __m256 gz = _mm256_loadu_ps(gridptr_z); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + __m256 x_w = _mm256_floor_ps(gx); + __m256 y_n = _mm256_floor_ps(gy); + __m256 z_t = _mm256_floor_ps(gz); + + __m256 x1 = _mm256_add_ps(x_w, _mm256_set1_ps(1)); + __m256 y1 = _mm256_add_ps(y_n, _mm256_set1_ps(1)); + __m256 z1 = _mm256_add_ps(z_t, _mm256_set1_ps(1)); + + __m256 x0_in_range = _mm256_and_ps(_mm256_cmp_ps(x_w, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x_w, _CMP_GT_OS)); + __m256 x1_in_range = _mm256_and_ps(_mm256_cmp_ps(x1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), x1, _CMP_GT_OS)); + __m256 y0_in_range = _mm256_and_ps(_mm256_cmp_ps(y_n, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y_n, _CMP_GT_OS)); + __m256 y1_in_range = _mm256_and_ps(_mm256_cmp_ps(y1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), y1, _CMP_GT_OS)); + __m256 z0_in_range = _mm256_and_ps(_mm256_cmp_ps(z_t, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z_t, _CMP_GT_OS)); + __m256 z1_in_range = _mm256_and_ps(_mm256_cmp_ps(z1, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), z1, _CMP_GT_OS)); + + __m256 v000_in_range, v010_in_range, v100_in_range, v110_in_range, v001_in_range, v011_in_range, v101_in_range, v111_in_range; + { + __m256 v00_in_range = _mm256_and_ps(x0_in_range, y0_in_range); + __m256 v01_in_range = _mm256_and_ps(x1_in_range, y0_in_range); + __m256 v10_in_range = _mm256_and_ps(x0_in_range, y1_in_range); + __m256 v11_in_range = _mm256_and_ps(x1_in_range, y1_in_range); + + v000_in_range = _mm256_and_ps(v00_in_range, z0_in_range); + v001_in_range = _mm256_and_ps(v01_in_range, z0_in_range); + v010_in_range = _mm256_and_ps(v10_in_range, z0_in_range); + v011_in_range = _mm256_and_ps(v11_in_range, z0_in_range); + + v100_in_range = _mm256_and_ps(v00_in_range, z1_in_range); + v101_in_range = _mm256_and_ps(v01_in_range, z1_in_range); + v110_in_range = _mm256_and_ps(v10_in_range, z1_in_range); + v111_in_range = _mm256_and_ps(v11_in_range, z1_in_range); + } + + __m256 tnw_offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), z_t, + _mm256_comp_fmadd_ps(y_n, _mm256_set1_ps(src.w), x_w)), + _mm256_set1_ps(src.elempack)); + __m256 tne_offset = _mm256_add_ps(tnw_offset, _mm256_set1_ps(src.elempack)); + __m256 tsw_offset = _mm256_add_ps(tnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 tse_offset = _mm256_add_ps(tsw_offset, _mm256_set1_ps(src.elempack)); + + __m256 bnw_offset = _mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), _mm256_set1_ps(src.elempack), tnw_offset); + __m256 bne_offset = _mm256_add_ps(bnw_offset, _mm256_set1_ps(src.elempack)); + __m256 bsw_offset = _mm256_add_ps(bnw_offset, _mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.elempack))); + __m256 bse_offset = _mm256_add_ps(bsw_offset, _mm256_set1_ps(src.elempack)); + + tnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tnw_offset, v000_in_range); + tne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tne_offset, v001_in_range); + tsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tsw_offset, v010_in_range); + tse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), tse_offset, v011_in_range); + + bnw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bnw_offset, v100_in_range); + bne_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bne_offset, v101_in_range); + bsw_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bsw_offset, v110_in_range); + bse_offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), bse_offset, v111_in_range); + + tnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tnw_offset)); + tne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tne_offset)); + tsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tsw_offset)); + tse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(tse_offset)); + + bnw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bnw_offset)); + bne_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bne_offset)); + bsw_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bsw_offset)); + bse_offset = _mm256_castsi256_ps(_mm256_cvtps_epi32(bse_offset)); + + __m256 alpha = _mm256_sub_ps(gx, x_w); + __m256 beta = _mm256_sub_ps(gy, y_n); + __m256 gamma = _mm256_sub_ps(gz, z_t); + + transpose8x11_ps(tnw_offset, tne_offset, tsw_offset, tse_offset, bnw_offset, bne_offset, bsw_offset, bse_offset, alpha, beta, gamma); + + _mm256_storeu_ps(offset_value_ptr, tnw_offset); + _mm256_storeu_ps(offset_value_ptr + 8, tne_offset); + _mm256_storeu_ps(offset_value_ptr + 16, tsw_offset); + _mm256_storeu_ps(offset_value_ptr + 24, tse_offset); + + _mm256_storeu_ps(offset_value_ptr + 32, bnw_offset); + _mm256_storeu_ps(offset_value_ptr + 40, bne_offset); + _mm256_storeu_ps(offset_value_ptr + 48, bsw_offset); + _mm256_storeu_ps(offset_value_ptr + 56, bse_offset); + + _mm256_storeu_ps(offset_value_ptr + 64, alpha); + _mm256_storeu_ps(offset_value_ptr + 72, beta); + _mm256_storeu_ps(offset_value_ptr + 80, gamma); + + gridptr_x += 8; + gridptr_y += 8; + gridptr_z += 8; + + offset_value_ptr += 88; + } +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + float sample_z = *gridptr_z; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = (int)floorf(sample_x); + int y0 = (int)floorf(sample_y); + int z0 = (int)floorf(sample_z); + int x1 = x0 + 1; + int y1 = y0 + 1; + int z1 = z0 + 1; + + bool x0_in_range = (x0 > -1) & (x0 < src.w); + bool y0_in_range = (y0 > -1) & (y0 < src.h); + bool z0_in_range = (z0 > -1) & (z0 < src.d); + bool x1_in_range = (x1 > -1) & (x1 < src.w); + bool y1_in_range = (y1 > -1) & (y1 < src.h); + bool z1_in_range = (z1 > -1) & (z1 < src.d); + + bool v00_in_range = x0_in_range & y0_in_range; + bool v01_in_range = x1_in_range & y0_in_range; + bool v10_in_range = x0_in_range & y1_in_range; + bool v11_in_range = x1_in_range & y1_in_range; + + bool in_bound_000 = v00_in_range & z0_in_range; + bool in_bound_001 = v01_in_range & z0_in_range; + bool in_bound_010 = v10_in_range & z0_in_range; + bool in_bound_011 = v11_in_range & z0_in_range; + + bool in_bound_100 = v00_in_range & z1_in_range; + bool in_bound_101 = v01_in_range & z1_in_range; + bool in_bound_110 = v10_in_range & z1_in_range; + bool in_bound_111 = v11_in_range & z1_in_range; + + int* offset_ptr = (int*)offset_value_ptr; + float* value_ptr = offset_value_ptr + 8; + + offset_ptr[0] = in_bound_000 ? (x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[1] = in_bound_001 ? (x1 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[2] = in_bound_010 ? (x0 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[3] = in_bound_011 ? (x1 + y1 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + offset_ptr[4] = in_bound_100 ? (x0 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[5] = in_bound_101 ? (x1 + y0 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[6] = in_bound_110 ? (x0 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + offset_ptr[7] = in_bound_111 ? (x1 + y1 * src.w + z1 * src.w * src.h) * src.elempack : -1.0; + + value_ptr[0] = sample_x - x0; + value_ptr[1] = sample_y - y0; + value_ptr[2] = sample_z - z0; + + gridptr_x++; + gridptr_y++; + gridptr_z++; + offset_value_ptr += 11; + } + } +} diff --git a/src/layer/x86/gridsample_compute_blob.h b/src/layer/x86/gridsample_compute_blob.h new file mode 100644 index 000000000000..4fb41f2cc247 --- /dev/null +++ b/src/layer/x86/gridsample_compute_blob.h @@ -0,0 +1,145 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "x86_usability.h" + +template +struct grid_sample_unormalize; + +template<> +struct grid_sample_unormalize +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + return _mm256_mul_ps(_mm256_div_ps(_mm256_add_ps(coord, _mm256_set1_ps(1)), _mm256_set1_ps(2)), _mm256_sub_ps(length, _mm256_set1_ps(1))); + } +#endif // __AVX__ + float operator()(int length, float coord) + { + return (coord + 1) / 2.f * (length - 1); + } +}; + +template<> +struct grid_sample_unormalize +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + return _mm256_div_ps(_mm256_comp_fmsub_ps(_mm256_add_ps(coord, _mm256_set1_ps(1)), length, _mm256_set1_ps(1)), _mm256_set1_ps(2)); + } +#endif // __AVX__ + float operator()(int length, float coord) + { + return ((coord + 1) * length - 1) / 2.f; + } +}; + +template +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + return coord; + } +#endif // __AVX__ + float operator()(int length, float coord) + { + return coord; + } +}; + +template +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + const __m256 border_x = _mm256_sub_ps(length, _mm256_set1_ps(1)); + + coord = _mm256_min_ps(border_x, _mm256_max_ps(coord, _mm256_setzero_ps())); + + return coord; + } +#endif // __AVX__ + float operator()(int length, float coord) + { + return std::min(length - 1.0f, std::max(coord, 0.0f)); + } +}; + +template<> +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + const __m256 border_x = _mm256_sub_ps(length, _mm256_set1_ps(1)); + + coord = abs256_ps(coord); + + __m256 reflectx_v = abs256_ps(_mm256_sub_ps(coord, border_x)); + coord = _mm256_sub_ps(border_x, reflectx_v); + + return coord; + } +#endif // __AVX__ + float operator()(int length, float coord) + { + coord = fabs(coord); + coord = (length - 1) - fabs(coord - (length - 1)); + + return std::min(length - 1.0f, std::max(coord, 0.0f)); + } +}; + +template<> +struct compute_coord +{ +#if __AVX__ + __m256 operator()(__m256 length, __m256 coord) + { + const __m256 border_x = _mm256_sub_ps(length, _mm256_set1_ps(1)); + + __m256 v0p5fp8 = _mm256_set1_ps(0.5f); + coord = _mm256_add_ps(coord, v0p5fp8); + + coord = abs256_ps(coord); + + __m256 reflectx_v = abs256_ps(_mm256_sub_ps(coord, length)); + coord = _mm256_sub_ps(length, reflectx_v); + + coord = _mm256_sub_ps(coord, v0p5fp8); + + _mm256_sub_ps(coord, v0p5fp8); + + coord = _mm256_min_ps(border_x, _mm256_max_ps(coord, _mm256_setzero_ps())); + + return coord; + } +#endif // __AVX__ + float operator()(int length, float coord) + { + coord = fabs(coord + 0.5f); + coord = length - fabs(coord - length) - 0.5; + + return std::min(length - 1.0f, std::max(coord, 0.0f)); + } +}; + +#include "gridsample_bilinear_compute_blob.h" +#include "gridsample_bicubic_compute_blob.h" +#include "gridsample_nearest_compute_blob.h" \ No newline at end of file diff --git a/src/layer/x86/gridsample_nearest_apply_interpolation.h b/src/layer/x86/gridsample_nearest_apply_interpolation.h new file mode 100644 index 000000000000..e84cdc7de25e --- /dev/null +++ b/src/layer/x86/gridsample_nearest_apply_interpolation.h @@ -0,0 +1,126 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ +static void gridsample_nearest_apply_interpolation_p16(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m512 _v = offset_ptr[0] >= 0 ? _mm512_loadu_ps(srcptr + offset_ptr[0]) : _mm512_set1_ps(0); + offset_ptr++; + + _mm512_storeu_ps(dstptr, _v); + dstptr += 16; + } + } +} +#endif // __AVX512F__ + +static void gridsample_nearest_apply_interpolation_p8(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m256 _v = offset_ptr[0] >= 0 ? _mm256_loadu_ps(srcptr + offset_ptr[0]) : _mm256_set1_ps(0); + offset_ptr++; + + _mm256_storeu_ps(dstptr, _v); + dstptr += 8; + } + } +} +#endif // __AVX__ +static void gridsample_nearest_apply_interpolation_p4(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int i = 0; i < grid_size; i++) + { + __m128 _v = offset_ptr[0] >= 0 ? _mm_loadu_ps(srcptr + offset_ptr[0]) : _mm_set1_ps(0); + offset_ptr++; + + _mm_storeu_ps(dstptr, _v); + dstptr += 4; + } + } +} + +#endif // __SSE2__ + +static void gridsample_nearest_apply_interpolation_p1(const Mat& src, Mat& dst, const Mat& offset_value, const Option& opt) +{ + const int channels = dst.c; + const int outw = dst.w; + const int outh = dst.h; + const int outd = dst.d; + const int grid_size = outw * outh * outd; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + const float* srcptr = src.channel(q); + float* dstptr = dst.channel(q); + + const int* offset_ptr = offset_value.channel(0); + + for (int x = 0; x < grid_size; x++) + { + *dstptr = offset_ptr[0] >= 0 ? *(srcptr + offset_ptr[0]) : 0; + + offset_ptr++; + dstptr++; + } + } +} diff --git a/src/layer/x86/gridsample_nearest_compute_blob.h b/src/layer/x86/gridsample_nearest_compute_blob.h new file mode 100644 index 000000000000..a7a12066d216 --- /dev/null +++ b/src/layer/x86/gridsample_nearest_compute_blob.h @@ -0,0 +1,315 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +template +void gridsample_2d_nearest_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h; + + float* offset_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 15 < grid_size; x += 16) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + + transpose2x8_ps(gx, gy); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx), _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr += 16; + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x += 2) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? (x0 + y0 * src.w) * src.elempack : -1.0; + + gridptr += 2; + offset_ptr++; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx), _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr_x += 8; + gridptr_y += 8; + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? (x0 + y0 * src.w) * src.elempack : -1.0; + + gridptr_x++; + gridptr_y++; + + offset_ptr++; + } + } +} + +template +void gridsample_3d_nearest_compute_blob(const Mat& src, const Mat& grid, Mat& offset_value, int permute_fusion) +{ + const int grid_size = grid.w * grid.h * grid.d; + + float* offset_ptr = offset_value.channel(0); + + grid_sample_unormalize unormalize; + compute_coord get_coord; + + if (permute_fusion == 0) + { + for (int y = 0; y < grid.c; y++) + { + const float* gridptr = grid.channel(y); + int x = 0; +#if __AVX__ + for (; x + 23 < grid_size; x += 24) + { + __m256 gx = _mm256_loadu_ps(gridptr); + __m256 gy = _mm256_loadu_ps(gridptr + 8); + __m256 gz = _mm256_loadu_ps(gridptr + 16); + + transpose3x8_ps(gx, gy, gz); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + gz = _mm256_floor_ps(_mm256_add_ps(gz, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + v_in_range = _mm256_and_ps(v_in_range, _mm256_and_ps(_mm256_cmp_ps(gz, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), gz, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), gz, + _mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx)), + _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr += 24; + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x += 3) + { + float sample_x = *gridptr; + float sample_y = *(gridptr + 1); + float sample_z = *(gridptr + 2); + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + int z0 = static_cast(floorf(sample_z + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h) & (z0 > -1) & (z0 < src.d)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? (x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + gridptr += 3; + offset_ptr++; + } + } + } + else + { + const float* gridptr_x = grid.channel(0); + const float* gridptr_y = grid.channel(1); + const float* gridptr_z = grid.channel(2); + + int x = 0; +#if __AVX__ + for (; x + 7 < grid_size; x += 8) + { + __m256 gx = _mm256_loadu_ps(gridptr_x); + __m256 gy = _mm256_loadu_ps(gridptr_y); + __m256 gz = _mm256_loadu_ps(gridptr_z); + + gx = unormalize(_mm256_set1_ps(src.w), gx); + gx = get_coord(_mm256_set1_ps(src.w), gx); + + gy = unormalize(_mm256_set1_ps(src.h), gy); + gy = get_coord(_mm256_set1_ps(src.h), gy); + + gz = unormalize(_mm256_set1_ps(src.d), gz); + gz = get_coord(_mm256_set1_ps(src.d), gz); + + gx = _mm256_floor_ps(_mm256_add_ps(gx, _mm256_set1_ps(0.5f))); + gy = _mm256_floor_ps(_mm256_add_ps(gy, _mm256_set1_ps(0.5f))); + gz = _mm256_floor_ps(_mm256_add_ps(gz, _mm256_set1_ps(0.5f))); + + __m256 v_in_range = _mm256_and_ps(_mm256_and_ps(_mm256_cmp_ps(gx, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.w), gx, _CMP_GT_OS)), + _mm256_and_ps(_mm256_cmp_ps(gy, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.h), gy, _CMP_GT_OS))); + v_in_range = _mm256_and_ps(v_in_range, _mm256_and_ps(_mm256_cmp_ps(gz, _mm256_set1_ps(-1), _CMP_GT_OS), _mm256_cmp_ps(_mm256_set1_ps(src.d), gz, _CMP_GT_OS))); + + __m256 offset = _mm256_mul_ps(_mm256_comp_fmadd_ps(_mm256_mul_ps(_mm256_set1_ps(src.w), _mm256_set1_ps(src.h)), gz, + _mm256_comp_fmadd_ps(gy, _mm256_set1_ps(src.w), gx)), + _mm256_set1_ps(src.elempack)); + + offset = _mm256_blendv_ps(_mm256_set1_ps(-1.0f), _mm256_castsi256_ps(_mm256_cvtps_epi32(offset)), v_in_range); + + _mm256_storeu_ps(offset_ptr, offset); + + gridptr_x += 8; + gridptr_y += 8; + gridptr_z += 8; + + offset_ptr += 8; + } + +#endif // __AVX__ + + for (; x < grid_size; x++) + { + float sample_x = *gridptr_x; + float sample_y = *gridptr_y; + float sample_z = *gridptr_z; + + sample_x = unormalize(src.w, sample_x); + sample_x = get_coord(src.w, sample_x); + + sample_y = unormalize(src.h, sample_y); + sample_y = get_coord(src.h, sample_y); + + sample_z = unormalize(src.d, sample_z); + sample_z = get_coord(src.d, sample_z); + + int x0 = static_cast(floorf(sample_x + 0.5f)); + int y0 = static_cast(floorf(sample_y + 0.5f)); + int z0 = static_cast(floorf(sample_z + 0.5f)); + + bool in_bound = ((x0 > -1) & (x0 < src.w) & (y0 > -1) & (y0 < src.h) & (z0 > -1) & (z0 < src.d)); + + int* iptr = (int*)offset_ptr; + *iptr = in_bound ? (x0 + y0 * src.w + z0 * src.w * src.h) * src.elempack : -1.0; + + gridptr_x++; + gridptr_y++; + gridptr_z++; + + offset_ptr++; + } + } +} diff --git a/src/layer/x86/gridsample_x86.cpp b/src/layer/x86/gridsample_x86.cpp new file mode 100644 index 000000000000..004bc4d0895b --- /dev/null +++ b/src/layer/x86/gridsample_x86.cpp @@ -0,0 +1,455 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "gridsample_x86.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" +#include "cpu.h" + +namespace ncnn { + +#include "gridsample_compute_blob.h" +#include "gridsample_bilinear_apply_interpolation.h" +#include "gridsample_bicubic_apply_interpolation.h" +#include "gridsample_nearest_apply_interpolation.h" + +GridSample_x86::GridSample_x86() +{ +#if __SSE2__ + support_packing = true; +#endif // __SSE2__ +} + +int GridSample_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& grid = bottom_blobs[1]; + Mat& top_blob = top_blobs[0]; + int elempack = bottom_blob.elempack; + + int channels = bottom_blob.c; + int dims = bottom_blob.dims; + size_t elemsize = bottom_blob.elemsize; + + int outw, outh, outd; + Mat offset_value_blob; + + Mat grid_p1; + if (grid.elempack != 1) + { + convert_packing(grid, grid_p1, 1, opt); + } + else + { + grid_p1 = grid; + } + + if (dims == 3) + { + outw = permute_fusion == 0 ? grid_p1.h : grid_p1.w; + outh = permute_fusion == 0 ? grid_p1.c : grid_p1.h; + + top_blob.create(outw, outh, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (sample_type == GridSample::Interpolation_BILINEAR) + { + offset_value_blob.create(outw, outh, elemsize * 6, 6, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == GridSample::Interpolation_NEAREST) + { + offset_value_blob.create(outw, outh, 1, elemsize, 1, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == GridSample::Interpolation_BICUBIC) + { + offset_value_blob.create(outw, outh, elemsize * 18, 18, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_2d_bicubic_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + } + + if (dims == 4) + { + outw = permute_fusion == 0 ? grid_p1.h : grid_p1.w; + outh = permute_fusion == 0 ? grid_p1.d : grid_p1.h; + outd = permute_fusion == 0 ? grid_p1.c : grid_p1.d; + + top_blob.create(outw, outh, outd, channels, elemsize, elempack, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + if (sample_type == GridSample::Interpolation_BILINEAR) + { + offset_value_blob.create(outw, outh, outd, elemsize * 11, 11, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_bilinear_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == GridSample::Interpolation_NEAREST) + { + offset_value_blob.create(outw, outh, outd, 1, elemsize, 1, opt.workspace_allocator); + if (offset_value_blob.empty()) + return -100; + + if (padding_mode == GridSample::Padding_ZEROS) + { + if (align_corner == 0) + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_BORDER) + { + if (align_corner == 0) + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else if (padding_mode == GridSample::Padding_REFLECTION) + { + if (align_corner == 0) + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + else + { + gridsample_3d_nearest_compute_blob(bottom_blob, grid_p1, offset_value_blob, permute_fusion); + } + } + else + { + NCNN_LOGE("gridsample padding_mode error\n"); + return -100; + } + } + + if (sample_type == 3) + { + NCNN_LOGE("unsupported bicubic when dims == 4"); + return -100; + } + } + +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p16(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } +#endif // __AVX512F__ + if (elempack == 8) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p8(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } + +#endif // __AVX__ + if (elempack == 4) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p4(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } + +#endif // __SSE2__ + + if (elempack == 1) + { + if (dims == 3) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_2d_bilinear_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_BICUBIC) + { + gridsample_2d_bicubic_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + } + else if (dims == 4) + { + if (sample_type == GridSample::Interpolation_BILINEAR) + { + gridsample_3d_bilinear_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + else if (sample_type == GridSample::Interpolation_NEAREST) + { + gridsample_nearest_apply_interpolation_p1(bottom_blob, top_blob, offset_value_blob, opt); + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/x86/gridsample_x86.h b/src/layer/x86/gridsample_x86.h new file mode 100644 index 000000000000..826414eefc9b --- /dev/null +++ b/src/layer/x86/gridsample_x86.h @@ -0,0 +1,32 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_GRIDSAMPLE_X86_H +#define LAYER_GRIDSAMPLE_X86_H + +#include "gridsample.h" + +namespace ncnn { + +class GridSample_x86 : virtual public GridSample +{ +public: + GridSample_x86(); + + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +}; + +} // namespace ncnn + +#endif // LAYER_GRIDSAMPLE_X86_H diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index e75e78c0c255..1571cdf49280 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -273,6 +273,14 @@ static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128 { return _mm_sub_ps(_c, _mm_mul_ps(_a, _b)); } +static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_sub_ps(_mm_mul_ps(_a, _b), _c); +} +static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_sub_ps(_c, _mm_mul_ps(_mm_mul_ps(_a, _b), _mm_set1_ps(-1))); +} #else static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(const __m128& _a, const __m128& _b, const __m128& _c) { @@ -283,6 +291,14 @@ static NCNN_FORCEINLINE __m128 _mm_comp_fnmadd_ps(const __m128& _a, const __m128 // return -a * b + c return _mm_fnmadd_ps(_a, _b, _c); } +static NCNN_FORCEINLINE __m128 _mm_comp_fmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_fmsub_ps(_a, _b, _c); +} +static NCNN_FORCEINLINE __m128 _mm_comp_fnmsub_ps(const __m128& _a, const __m128& _b, const __m128& _c) +{ + return _mm_fnmsub_ps(_a, _b, _c); +} #endif // !__FMA__ #if __AVX__ @@ -295,9 +311,18 @@ static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m { return _mm256_sub_ps(_c, _mm256_mul_ps(_a, _b)); } +static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + return _mm256_sub_ps(_mm256_mul_ps(_a, _b), _c); +} +static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + return _mm256_sub_ps(_c, _mm256_mul_ps(_mm256_mul_ps(_a, _b), _mm256_set1_ps(-1))); +} #else static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) { + // return a * b + c return _mm256_fmadd_ps(_a, _b, _c); } static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m256& _b, const __m256& _c) @@ -305,6 +330,16 @@ static NCNN_FORCEINLINE __m256 _mm256_comp_fnmadd_ps(const __m256& _a, const __m // return -a * b + c return _mm256_fnmadd_ps(_a, _b, _c); } +static NCNN_FORCEINLINE __m256 _mm256_comp_fmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + // return a * b - c + return _mm256_fmsub_ps(_a, _b, _c); +} +static NCNN_FORCEINLINE __m256 _mm256_comp_fnmsub_ps(const __m256& _a, const __m256& _b, const __m256& _c) +{ + // return -(a * b) - c + return _mm256_fnmsub_ps(_a, _b, _c); +} #endif static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(const __m256& a, const __m256& b, float c) @@ -418,6 +453,163 @@ static NCNN_FORCEINLINE void transpose8x2_ps(__m256& _r0, __m256& _r1) _r1 = _mm256_permute2f128_ps(_tmp0, _tmp1, _MM_SHUFFLE(0, 3, 0, 1)); } +static NCNN_FORCEINLINE void transpose2x8_ps(__m256& _r0, __m256& _r1) +{ + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 3, 0, 1)); + + _r0 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + _r1 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); +} + +static NCNN_FORCEINLINE void transpose3x8_ps(__m256& _r0, __m256& _r1, __m256& _r2) +{ + __m256 _tmp0 = _mm256_permute2f128_ps(_r0, _r1, _MM_SHUFFLE(0, 3, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_r0, _r2, _MM_SHUFFLE(0, 2, 0, 1)); + __m256 _tmp2 = _mm256_permute2f128_ps(_r1, _r2, _MM_SHUFFLE(0, 3, 0, 0)); + + __m256 _tmp4 = _mm256_shuffle_ps(_tmp0, _tmp1, _MM_SHUFFLE(1, 0, 2, 1)); + __m256 _tmp5 = _mm256_shuffle_ps(_tmp1, _tmp2, _MM_SHUFFLE(2, 1, 3, 2)); + + _r0 = _mm256_shuffle_ps(_tmp0, _tmp5, _MM_SHUFFLE(2, 0, 3, 0)); + _r1 = _mm256_shuffle_ps(_tmp4, _tmp5, _MM_SHUFFLE(3, 1, 2, 0)); + _r2 = _mm256_shuffle_ps(_tmp4, _tmp2, _MM_SHUFFLE(3, 0, 3, 1)); +} + +static NCNN_FORCEINLINE void transpose8x6_ps(__m256& _r0, __m256& _r1, __m256& _r2, __m256& _r3, __m256& _r4, __m256& _r5) +{ + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + + __m256 _tmp6 = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmp7 = _mm256_shuffle_ps(_tmp4, _tmp0, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmp8 = _mm256_shuffle_ps(_tmp2, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmp9 = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpa = _mm256_shuffle_ps(_tmp5, _tmp1, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpb = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmp8, _tmp9, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpa, _tmpb, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmp6, _tmp7, _MM_SHUFFLE(0, 3, 0, 1)); + _r4 = _mm256_permute2f128_ps(_tmp8, _tmp9, _MM_SHUFFLE(0, 3, 0, 1)); + _r5 = _mm256_permute2f128_ps(_tmpa, _tmpb, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static NCNN_FORCEINLINE void transpose8x11_ps(__m256& _r0, __m256& _r1, __m256& _r2, __m256& _r3, __m256& _r4, __m256& _r5, __m256& _r6, __m256& _r7, __m256& _r8, __m256& _r9, __m256& _ra) +{ + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _r0); + __m256 _tmpb = _mm256_shuffle_ps(_ra, _tmp1, _MM_SHUFFLE(3, 2, 1, 2)); + __m256 _tmpc = _mm256_unpacklo_ps(_r1, _r2); + __m256 _tmpd = _mm256_unpackhi_ps(_r1, _r2); + __m256 _tmpe = _mm256_unpacklo_ps(_r3, _r4); + __m256 _tmpf = _mm256_unpackhi_ps(_r3, _r4); + __m256 _tmpg = _mm256_unpacklo_ps(_r5, _r6); + __m256 _tmph = _mm256_unpackhi_ps(_r5, _r6); + __m256 _tmpi = _mm256_unpacklo_ps(_r7, _r8); + __m256 _tmpj = _mm256_unpackhi_ps(_r7, _r8); + __m256 _tmpk = _mm256_unpacklo_ps(_r9, _ra); + __m256 _tmpl = _mm256_unpackhi_ps(_r9, _ra); + + __m256 _tmpm = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpo = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(3, 0, 1, 0)); + __m256 _tmpp = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpq = _mm256_shuffle_ps(_tmpg, _tmpi, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpr = _mm256_shuffle_ps(_tmpk, _tmp1, _MM_SHUFFLE(1, 0, 3, 2)); + __m256 _tmps = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpt = _mm256_shuffle_ps(_tmp7, _tmp9, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpu = _mm256_shuffle_ps(_tmpb, _tmpd, _MM_SHUFFLE(3, 2, 2, 0)); + __m256 _tmpv = _mm256_shuffle_ps(_tmpf, _tmph, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpw = _mm256_shuffle_ps(_tmpj, _tmpl, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmpw, _tmpm, _MM_SHUFFLE(0, 3, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpn, _tmpo, _MM_SHUFFLE(0, 3, 0, 1)); + _r7 = _mm256_permute2f128_ps(_tmpp, _tmpq, _MM_SHUFFLE(0, 3, 0, 1)); + _r8 = _mm256_permute2f128_ps(_tmpr, _tmps, _MM_SHUFFLE(0, 3, 0, 1)); + _r9 = _mm256_permute2f128_ps(_tmpt, _tmpu, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmpv, _tmpw, _MM_SHUFFLE(0, 3, 0, 1)); +} + +static void transpose8x18_ps(__m256& _r0, __m256& _r1, __m256& _r2, __m256& _r3, __m256& _r4, __m256& _r5, __m256& _r6, __m256& _r7, __m256& _r8, __m256& _r9, __m256& _ra, __m256& _rb, __m256& _rc, __m256& _rd, __m256& _re, __m256& _rf, __m256& _rg, __m256& _rh) +{ + __m256 _tmp0 = _mm256_unpacklo_ps(_r0, _r1); + __m256 _tmp1 = _mm256_unpackhi_ps(_r0, _r1); + __m256 _tmp2 = _mm256_unpacklo_ps(_r2, _r3); + __m256 _tmp3 = _mm256_unpackhi_ps(_r2, _r3); + __m256 _tmp4 = _mm256_unpacklo_ps(_r4, _r5); + __m256 _tmp5 = _mm256_unpackhi_ps(_r4, _r5); + __m256 _tmp6 = _mm256_unpacklo_ps(_r6, _r7); + __m256 _tmp7 = _mm256_unpackhi_ps(_r6, _r7); + __m256 _tmp8 = _mm256_unpacklo_ps(_r8, _r9); + __m256 _tmp9 = _mm256_unpackhi_ps(_r8, _r9); + __m256 _tmpa = _mm256_unpacklo_ps(_ra, _rb); + __m256 _tmpb = _mm256_unpackhi_ps(_ra, _rb); + __m256 _tmpc = _mm256_unpacklo_ps(_rc, _rd); + __m256 _tmpd = _mm256_unpackhi_ps(_rc, _rd); + __m256 _tmpe = _mm256_unpacklo_ps(_re, _rf); + __m256 _tmpf = _mm256_unpackhi_ps(_re, _rf); + __m256 _tmpg = _mm256_unpacklo_ps(_rg, _rh); + __m256 _tmph = _mm256_unpackhi_ps(_rg, _rh); + + __m256 _tmpi = _mm256_shuffle_ps(_tmp0, _tmp2, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpj = _mm256_shuffle_ps(_tmp4, _tmp6, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpk = _mm256_shuffle_ps(_tmp8, _tmpa, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpl = _mm256_shuffle_ps(_tmpc, _tmpe, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpm = _mm256_shuffle_ps(_tmpg, _tmp0, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpn = _mm256_shuffle_ps(_tmp2, _tmp4, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpo = _mm256_shuffle_ps(_tmp6, _tmp8, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpp = _mm256_shuffle_ps(_tmpa, _tmpc, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpq = _mm256_shuffle_ps(_tmpe, _tmpg, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpr = _mm256_shuffle_ps(_tmp1, _tmp3, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmps = _mm256_shuffle_ps(_tmp5, _tmp7, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpt = _mm256_shuffle_ps(_tmp9, _tmpb, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpu = _mm256_shuffle_ps(_tmpd, _tmpf, _MM_SHUFFLE(1, 0, 1, 0)); + __m256 _tmpv = _mm256_shuffle_ps(_tmph, _tmp1, _MM_SHUFFLE(3, 2, 1, 0)); + __m256 _tmpw = _mm256_shuffle_ps(_tmp3, _tmp5, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpx = _mm256_shuffle_ps(_tmp7, _tmp9, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpy = _mm256_shuffle_ps(_tmpb, _tmpd, _MM_SHUFFLE(3, 2, 3, 2)); + __m256 _tmpz = _mm256_shuffle_ps(_tmpf, _tmph, _MM_SHUFFLE(3, 2, 3, 2)); + + _r0 = _mm256_permute2f128_ps(_tmpi, _tmpj, _MM_SHUFFLE(0, 2, 0, 0)); + _r1 = _mm256_permute2f128_ps(_tmpk, _tmpl, _MM_SHUFFLE(0, 2, 0, 0)); + _r2 = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 2, 0, 0)); + _r3 = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 2, 0, 0)); + _r4 = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 2, 0, 0)); + _r5 = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 2, 0, 0)); + _r6 = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 2, 0, 0)); + _r7 = _mm256_permute2f128_ps(_tmpw, _tmpx, _MM_SHUFFLE(0, 2, 0, 0)); + _r8 = _mm256_permute2f128_ps(_tmpy, _tmpz, _MM_SHUFFLE(0, 2, 0, 0)); + _r9 = _mm256_permute2f128_ps(_tmpi, _tmpj, _MM_SHUFFLE(0, 3, 0, 1)); + _ra = _mm256_permute2f128_ps(_tmpk, _tmpl, _MM_SHUFFLE(0, 3, 0, 1)); + _rb = _mm256_permute2f128_ps(_tmpm, _tmpn, _MM_SHUFFLE(0, 3, 0, 1)); + _rc = _mm256_permute2f128_ps(_tmpo, _tmpp, _MM_SHUFFLE(0, 3, 0, 1)); + _rd = _mm256_permute2f128_ps(_tmpq, _tmpr, _MM_SHUFFLE(0, 3, 0, 1)); + _re = _mm256_permute2f128_ps(_tmps, _tmpt, _MM_SHUFFLE(0, 3, 0, 1)); + _rf = _mm256_permute2f128_ps(_tmpu, _tmpv, _MM_SHUFFLE(0, 3, 0, 1)); + _rg = _mm256_permute2f128_ps(_tmpw, _tmpx, _MM_SHUFFLE(0, 3, 0, 1)); + _rh = _mm256_permute2f128_ps(_tmpy, _tmpz, _MM_SHUFFLE(0, 3, 0, 1)); +} + static NCNN_FORCEINLINE __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7) { const __m256 s01 = _mm256_hadd_ps(v0, v1); diff --git a/tests/test_gridsample.cpp b/tests/test_gridsample.cpp index 70c96b304805..0e3841153521 100644 --- a/tests/test_gridsample.cpp +++ b/tests/test_gridsample.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -15,12 +15,13 @@ #include "layer/gridsample.h" #include "testutil.h" -static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample_type, int padding_mode, int align_corner) +static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample_type, int padding_mode, int align_corner, int permute_fusion) { ncnn::ParamDict pd; pd.set(0, sample_type); pd.set(1, padding_mode); pd.set(2, align_corner); + pd.set(3, permute_fusion); std::vector weights(0); @@ -31,9 +32,9 @@ static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample int ret = test_layer("GridSample", pd, weights, as); if (ret != 0) { - fprintf(stderr, "test_gridsample failed a.dims=%d a=(%d %d %d %d) grid.dims=%d grid=(%d %d %d %d) sample_type=%d padding_mode=%d align_corner=%d", + fprintf(stderr, "test_gridsample failed a.dims=%d a=(%d %d %d %d) grid.dims=%d grid=(%d %d %d %d) sample_type=%d padding_mode=%d align_corner=%d permute_fusion=%d", a.dims, a.w, a.h, a.d, a.c, grid.dims, grid.w, grid.h, grid.d, grid.c, - sample_type, padding_mode, align_corner); + sample_type, padding_mode, align_corner, permute_fusion); } return ret; @@ -42,81 +43,141 @@ static int test_gridsample(const ncnn::Mat& a, const ncnn::Mat& grid, int sample static int test_gridsample_0() { return 0 - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 2, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 16, 12), 3, 3, 1); + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 1, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 1, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 2, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 2, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 3, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 1, 3, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 1, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 1, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 2, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 2, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 3, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 2, 3, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 1, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 1, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 2, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 2, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 3, 0, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(2, 11, 13), 3, 3, 1, 0) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 1, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 1, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 2, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 2, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 3, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 1, 3, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 1, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 1, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 2, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 2, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 3, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 2, 3, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 1, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 1, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 2, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 2, 1, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 3, 0, 1) + || test_gridsample(RandomMat(3, 7, 1), RandomMat(11, 13, 2), 3, 3, 1, 1); } static int test_gridsample_1() { return 0 - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 2, 3, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 1, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 1, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 2, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 2, 1) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 3, 0) - || test_gridsample(RandomMat(16, 12, 3), RandomMat(2, 27, 21), 3, 3, 1); + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 1, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 1, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 2, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 2, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 3, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 1, 3, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 1, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 1, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 2, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 2, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 3, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 2, 3, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 1, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 1, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 2, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 2, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 3, 0, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(2, 24, 16), 3, 3, 1, 0) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 1, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 1, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 2, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 2, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 3, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 1, 3, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 1, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 1, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 2, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 2, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 3, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 2, 3, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 1, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 1, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 2, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 2, 1, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 3, 0, 1) + || test_gridsample(RandomMat(8, 12, 16), RandomMat(24, 16, 2), 3, 3, 1, 1); } static int test_gridsample_2() { return 0 - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 27, 21, 10), 2, 3, 1); + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 1, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 1, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 2, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 2, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 3, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 1, 3, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 1, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 1, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 2, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 2, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 3, 0, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(3, 17, 11, 13), 2, 3, 1, 0) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 1, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 1, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 2, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 2, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 3, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 1, 3, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 1, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 1, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 2, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 2, 1, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 3, 0, 1) + || test_gridsample(RandomMat(5, 7, 11, 13), RandomMat(17, 11, 13, 3), 2, 3, 1, 1); } static int test_gridsample_3() { return 0 - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 1, 3, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 1, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 1, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 2, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 2, 1) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 3, 0) - || test_gridsample(RandomMat(16, 12, 10, 5), RandomMat(3, 16, 12, 10), 2, 3, 1); + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 1, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 1, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 2, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 2, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 3, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 1, 3, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 1, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 1, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 2, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 2, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 3, 0, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(3, 11, 12, 16), 2, 3, 1, 0) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 1, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 1, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 2, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 2, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 3, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 1, 3, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 1, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 1, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 2, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 2, 1, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 3, 0, 1) + || test_gridsample(RandomMat(16, 12, 11, 16), RandomMat(11, 12, 16, 3), 2, 3, 1, 1); } int main() diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 3d09ec1859dc..fd5105e612fe 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -1734,6 +1734,7 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 0=%d", sample_type) fprintf_param_value(" 1=%d", padding_mode) fprintf_param_value(" 2=%d", align_corner) + fprintf_param_value(" 3=%d", permute_fusion) } else if (layer->type == "GroupNorm") { diff --git a/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp b/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp index 41dfc65ee39c..b71e64dcbde0 100644 --- a/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp +++ b/tools/pnnx/src/pass_ncnn/F_grid_sample.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -15,7 +15,6 @@ #include "pass_ncnn.h" namespace pnnx { - namespace ncnn { class F_grid_sample : public GraphRewriterPass @@ -61,11 +60,113 @@ pnnx.Output output 1 0 out op->params["1"] = 3; op->params["2"] = captured_params.at("align_corners").b ? 1 : 0; + op->params["3"] = 0; } }; REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_grid_sample, 20) +class F_grid_sample_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_a 0 1 a +pnnx.Input input_b 0 1 b +torch.permute op_0 1 1 b b1 dims=%dims +F.grid_sample op_1 2 1 a b1 out mode=%mode padding_mode=%padding_mode align_corners=%align_corners +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "GridSample"; + } + + const char* name_str() const + { + return "permutegridsample"; + } + + bool match(const std::map& captured_params) const + { + const std::vector& dims = captured_params.at("dims").ai; + + if ((dims == std::vector{1, 2, 0}) || (dims == std::vector{1, 2, 3, 0})) + return true; + if ((dims == std::vector{0, 2, 3, 1}) || (dims == std::vector{0, 2, 3, 4, 1})) + return true; + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& mode = captured_params.at("mode").s; + if (mode == "bilinear") + op->params["0"] = 1; + if (mode == "nearest") + op->params["0"] = 2; + if (mode == "bicubic") + op->params["0"] = 3; + + const std::string& padding_mode = captured_params.at("padding_mode").s; + if (padding_mode == "zeros") + op->params["1"] = 1; + if (padding_mode == "border") + op->params["1"] = 2; + if (padding_mode == "reflection") + op->params["1"] = 3; + + op->params["2"] = captured_params.at("align_corners").b ? 1 : 0; + + const int batch_index = op->inputs[1]->params["__batch_index"].i; + + const std::vector& dims = captured_params.at("dims").ai; + + int input_rank = (int)op->inputs[0]->shape.size(); + + if (input_rank == 0) + { + // assume input is fine + input_rank = (int)dims.size(); + } + + if (batch_index >= 0 && batch_index < input_rank) + input_rank -= 1; + + if (input_rank > 4) + { + fprintf(stderr, "permute %d-rank tensor is not supported yet!\n", input_rank); + return; + } + + // drop permute batch index + std::vector new_dims; + for (int i = 0; i < (int)dims.size(); i++) + { + if (dims[i] == batch_index) + continue; + + int new_dim = dims[i] > batch_index ? dims[i] - 1 : dims[i]; + new_dims.push_back(new_dim); + } + + if (input_rank != (int)new_dims.size()) + { + fprintf(stderr, "permute %d-rank tensor with %d-rank dims is not possible\n", input_rank, (int)new_dims.size()); + return; + } + + if ((input_rank == 3 && new_dims == std::vector{1, 2, 0}) || (input_rank == 4 && new_dims == std::vector{1, 2, 3, 0})) + op->params["3"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_grid_sample_1, 19) + } // namespace ncnn } // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/test_F_grid_sample.py b/tools/pnnx/tests/ncnn/test_F_grid_sample.py index c84d38232b1e..95ca812eb512 100644 --- a/tools/pnnx/tests/ncnn/test_F_grid_sample.py +++ b/tools/pnnx/tests/ncnn/test_F_grid_sample.py @@ -22,46 +22,56 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() - def forward(self, x, xg1, xg2, y, yg1, yg2): + def forward(self, x, xg1, xg2, xgp1, xgp2, y, yg1, yg2, ygp1, ygp2): # norm to -1 ~ 1 xg1 = xg1 * 2 - 1 xg2 = xg2 * 2 - 1 yg1 = yg1 * 2 - 1 yg2 = yg2 * 2 - 1 - x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='zeros', align_corners=False) - x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='border', align_corners=False) - x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='reflection', align_corners=False) - x = F.grid_sample(x, xg2, mode='nearest', padding_mode='zeros', align_corners=False) - x = F.grid_sample(x, xg1, mode='nearest', padding_mode='border', align_corners=False) - x = F.grid_sample(x, xg2, mode='nearest', padding_mode='reflection', align_corners=False) - x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='zeros', align_corners=False) - x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='border', align_corners=False) - x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='reflection', align_corners=False) - x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='zeros', align_corners=True) - x = F.grid_sample(x, xg1, mode='bilinear', padding_mode='border', align_corners=True) - x = F.grid_sample(x, xg2, mode='bilinear', padding_mode='reflection', align_corners=True) - x = F.grid_sample(x, xg1, mode='nearest', padding_mode='zeros', align_corners=True) - x = F.grid_sample(x, xg2, mode='nearest', padding_mode='border', align_corners=True) - x = F.grid_sample(x, xg1, mode='nearest', padding_mode='reflection', align_corners=True) - x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='zeros', align_corners=True) - x = F.grid_sample(x, xg1, mode='bicubic', padding_mode='border', align_corners=True) - x = F.grid_sample(x, xg2, mode='bicubic', padding_mode='reflection', align_corners=True) - - y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=False) - y = F.grid_sample(y, yg2, mode='bilinear', padding_mode='border', align_corners=False) - y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='reflection', align_corners=False) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='zeros', align_corners=False) - y = F.grid_sample(y, yg1, mode='nearest', padding_mode='border', align_corners=False) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='reflection', align_corners=False) - y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=True) - y = F.grid_sample(y, yg2, mode='bilinear', padding_mode='border', align_corners=True) - y = F.grid_sample(y, yg1, mode='bilinear', padding_mode='reflection', align_corners=True) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='zeros', align_corners=True) - y = F.grid_sample(y, yg1, mode='nearest', padding_mode='border', align_corners=True) - y = F.grid_sample(y, yg2, mode='nearest', padding_mode='reflection', align_corners=True) - - return x, y + x0 = F.grid_sample(x, xg1, mode='bilinear', padding_mode='zeros', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='bilinear', padding_mode='border', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='bilinear', padding_mode='reflection', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='nearest', padding_mode='zeros', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='nearest', padding_mode='border', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='nearest', padding_mode='reflection', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='bicubic', padding_mode='zeros', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='bicubic', padding_mode='border', align_corners=False) + x0 = F.grid_sample(x0, xg1, mode='bicubic', padding_mode='reflection', align_corners=False) + x0 = F.grid_sample(x0, xg2, mode='bilinear', padding_mode='zeros', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='bilinear', padding_mode='border', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='bilinear', padding_mode='reflection', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='nearest', padding_mode='zeros', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='nearest', padding_mode='border', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='nearest', padding_mode='reflection', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='bicubic', padding_mode='zeros', align_corners=True) + x0 = F.grid_sample(x0, xg1, mode='bicubic', padding_mode='border', align_corners=True) + x0 = F.grid_sample(x0, xg2, mode='bicubic', padding_mode='reflection', align_corners=True) + + y0 = F.grid_sample(y, yg1, mode='bilinear', padding_mode='zeros', align_corners=False) + y0 = F.grid_sample(y0, yg2, mode='bilinear', padding_mode='border', align_corners=False) + y0 = F.grid_sample(y0, yg1, mode='bilinear', padding_mode='reflection', align_corners=False) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='zeros', align_corners=False) + y0 = F.grid_sample(y0, yg1, mode='nearest', padding_mode='border', align_corners=False) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='reflection', align_corners=False) + y0 = F.grid_sample(y0, yg1, mode='bilinear', padding_mode='zeros', align_corners=True) + y0 = F.grid_sample(y0, yg2, mode='bilinear', padding_mode='border', align_corners=True) + y0 = F.grid_sample(y0, yg1, mode='bilinear', padding_mode='reflection', align_corners=True) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='zeros', align_corners=True) + y0 = F.grid_sample(y0, yg1, mode='nearest', padding_mode='border', align_corners=True) + y0 = F.grid_sample(y0, yg2, mode='nearest', padding_mode='reflection', align_corners=True) + + xgp1 = xgp1.permute(0, 2, 3, 1) + xgp2 = xgp2.permute(0, 2, 3, 1) + ygp1 = ygp1.permute(0, 2, 3, 4, 1) + ygp2 = ygp2.permute(0, 2, 3, 4, 1) + + x1 = F.grid_sample(x, xgp1, mode='bilinear', padding_mode='zeros', align_corners=False) + x1 = F.grid_sample(x1, xgp2, mode='bilinear', padding_mode='border', align_corners=False) + + y1 = F.grid_sample(y, ygp1, mode='bilinear', padding_mode='zeros', align_corners=False) + y1 = F.grid_sample(y1, ygp2, mode='bilinear', padding_mode='border', align_corners=False) + return x0, y0, x1, y1 def test(): net = Model() @@ -71,25 +81,29 @@ def test(): x = torch.rand(1, 3, 12, 16) xg1 = torch.rand(1, 21, 27, 2) xg2 = torch.rand(1, 12, 16, 2) + xgp1 = torch.rand(1, 2, 21, 27) + xgp2 = torch.rand(1, 2, 12, 16) y = torch.rand(1, 5, 10, 12, 16) yg1 = torch.rand(1, 10, 21, 27, 3) yg2 = torch.rand(1, 10, 12, 16, 3) + ygp1 = torch.rand(1, 3, 10, 21, 27) + ygp2 = torch.rand(1, 3, 10, 12, 16) - a0, a1 = net(x, xg1, xg2, y, yg1, yg2) + a0, a1, a2, a3 = net(x, xg1, xg2, xgp1, xgp2, y, yg1, yg2, ygp1, ygp2) # export torchscript - mod = torch.jit.trace(net, (x, xg1, xg2, y, yg1, yg2)) + mod = torch.jit.trace(net, (x, xg1, xg2, xgp1, xgp2, y, yg1, yg2, ygp1, ygp2)) mod.save("test_F_grid_sample.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_grid_sample.pt inputshape=[1,3,12,16],[1,21,27,2],[1,12,16,2],[1,5,10,12,16],[1,10,21,27,3],[1,10,12,16,3]") + os.system("../../src/pnnx test_F_grid_sample.pt inputshape=[1,3,12,16],[1,21,27,2],[1,12,16,2],[1,2,21,27],[1,2,12,16],[1,5,10,12,16],[1,10,21,27,3],[1,10,12,16,3],[1,3,10,21,27],[1,3,10,12,16]") # ncnn inference import test_F_grid_sample_ncnn - b0, b1 = test_F_grid_sample_ncnn.test_inference() + b0, b1, b2, b3 = test_F_grid_sample_ncnn.test_inference() - return torch.allclose(a0, b0, 1e-4, 1e-4) and torch.allclose(a1, b1, 1e-4, 1e-4) + return torch.allclose(a0, b0, 1e-6, 1e-6) and torch.allclose(a1, b1, 1e-6, 1e-6) and torch.allclose(a2, b2, 1e-6, 1e-6) and torch.allclose(a3, b3, 1e-6, 1e-6) if __name__ == "__main__": if test():