From daecbe6ea9f3a19ab9eabfbf49f9944fb40ad74c Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 19 Sep 2024 19:02:03 +0800 Subject: [PATCH 01/55] staging --- src/layer/arm/gemm_arm.cpp | 1062 ++ src/layer/arm/gemm_arm.h | 4 + src/layer/arm/gemm_arm_asimddp.cpp | 114 + src/layer/arm/gemm_arm_i8mm.cpp | 94 + src/layer/arm/gemm_int8.h | 16045 +++++++++++++++++++++++++++ src/layer/arm/gemm_int8_bf16s.h | 9903 +++++++++++++++++ src/layer/gemm.cpp | 451 +- src/layer/gemm.h | 12 + tests/test_gemm_3.cpp | 316 + 9 files changed, 27943 insertions(+), 58 deletions(-) create mode 100644 src/layer/arm/gemm_arm_asimddp.cpp create mode 100644 src/layer/arm/gemm_arm_i8mm.cpp create mode 100644 src/layer/arm/gemm_int8.h create mode 100644 src/layer/arm/gemm_int8_bf16s.h create mode 100644 tests/test_gemm_3.cpp diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index e798680e2afa..c1e0c3e0d697 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -16,6 +16,7 @@ #if __ARM_NEON #include +#include "neon_mathfun.h" #endif // __ARM_NEON #include "arm_usability.h" @@ -29,6 +30,13 @@ namespace ncnn { #include "gemm_bf16s.h" #endif +#if NCNN_INT8 +#include "gemm_int8.h" +#if NCNN_BF16 +#include "gemm_int8_bf16s.h" +#endif +#endif + Gemm_arm::Gemm_arm() { #if __ARM_NEON @@ -50,6 +58,8 @@ void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) const int elempack = A.elempack; const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + NCNN_LOGE("pack_A_tile %d %d %d %d %d %d", i, max_ii, k, max_kk, elempack, A_hstep); + float* pp = AT; int ii = 0; @@ -2461,6 +2471,79 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons int kk = 0; for (; kk < max_kk; kk += 1) { +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%0, #128] \n" + "ld1 {v2.4s}, [%0], #16 \n" + "prfm pldl1keep, [%1, #256] \n" + "ld1 {v0.4s, v1.4s}, [%1], #32 \n" + "fmla %2.4s, v2.4s, v0.s[0] \n" + "fmla %3.4s, v2.4s, v0.s[1] \n" + "fmla %4.4s, v2.4s, v0.s[2] \n" + "fmla %5.4s, v2.4s, v0.s[3] \n" + "fmla %6.4s, v2.4s, v1.s[0] \n" + "fmla %7.4s, v2.4s, v1.s[1] \n" + "fmla %8.4s, v2.4s, v1.s[2] \n" + "fmla %9.4s, v2.4s, v1.s[3] \n" + : "=r"(pA), + "=r"(pB), + "=w"(_sum0), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6), + "=w"(_sum7) + : "0"(pA), + "1"(pB), + "2"(_sum0), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "6"(_sum4), + "7"(_sum5), + "8"(_sum6), + "9"(_sum7) + : "memory", "v0", "v1", "v2", "v3"); +#else + asm volatile( + "pld [%0, #128] \n" + "vld1.f32 {d4-d5}, [%0]! \n" + "pld [%1, #256] \n" + "vld1.f32 {d0-d3}, [%1]! 
\n" + "vmla.f32 %q2, q2, d0[0] \n" + "vmla.f32 %q3, q2, d0[1] \n" + "vmla.f32 %q4, q2, d1[0] \n" + "vmla.f32 %q5, q2, d1[1] \n" + "vmla.f32 %q6, q2, d2[0] \n" + "vmla.f32 %q7, q2, d2[1] \n" + "vmla.f32 %q8, q2, d3[0] \n" + "vmla.f32 %q9, q2, d3[1] \n" + : "=r"(pA), + "=r"(pB), + "=w"(_sum0), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6), + "=w"(_sum7) + : "0"(pA), + "1"(pB), + "2"(_sum0), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3), + "6"(_sum4), + "7"(_sum5), + "8"(_sum6), + "9"(_sum7) + : "memory", "q0", "q1", "q2"); +#endif +#else // NCNN_GNU_INLINE_ASM float32x4_t _pA = vld1q_f32(pA); float32x4_t _pB0 = vld1q_f32(pB); float32x4_t _pB1 = vld1q_f32(pB + 4); @@ -2487,6 +2570,7 @@ static void gemm_transB_packed_tile(const Mat& AT_tile, const Mat& BT_tile, cons pA += 4; pB += 8; +#endif // NCNN_GNU_INLINE_ASM } if (k_end) @@ -4164,6 +4248,17 @@ static int gemm_AT_BT_arm(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_b int Gemm_arm::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + // support_packing = false; + support_fp16_storage = false; + // support_bf16_storage = false; + return create_pipeline_int8(opt); + // return 0; + } +#endif + #if NCNN_ARM82 if (cpu_support_arm_asimdhp() && opt.use_fp16_storage) { @@ -4311,6 +4406,14 @@ int Gemm_arm::create_pipeline(const Option& opt) int Gemm_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + // return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& bottom_blob = constantA ? AT_data : bottom_blobs[0]; int elembits = bottom_blob.elembits(); @@ -5199,4 +5302,963 @@ int Gemm_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + int M; + int N; + if (constantA && constantB) + { + M = constantM; + N = constantN; + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + M = constantM; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = constantN; + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + } + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = CT_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB) + { + C = bottom_blobs.size() == 1 ? bottom_blobs[0] : Mat(); + } + else if (constantA) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else if (constantB) + { + C = bottom_blobs.size() == 2 ? bottom_blobs[1] : Mat(); + } + else + { + C = bottom_blobs.size() == 3 ? 
bottom_blobs[2] : Mat(); + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w * C.elempack == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w * C.elempack == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h * C.elempack == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h * C.elempack == 1) + { + // 1xN + broadcast_type_C = 4; + } + + // pre-multiply C with beta + if (beta != 1.f) + { + Mat CT_data; + CT_data.create_like(C, opt.workspace_allocator); + + const int size = C.total() * C.elempack; + for (int i = 0; i < size; i++) + { + CT_data[i] = C[i] * beta; + } + + C = CT_data; + } + } + } + + int out_elempack = 1; +#if __ARM_NEON + if (opt.use_packing_layout) + { + int outh = output_transpose ? N : M; + out_elempack = outh % 4 == 0 ? 4 : 1; + } +#endif // __ARM_NEON + if (output_elempack) + out_elempack = output_elempack; + size_t out_elemsize = 4u * out_elempack; + + if (opt.use_bf16_storage) + { + out_elemsize = 2u * out_elempack; + } + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(M, N / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + else + top_blob.create(N, M / out_elempack, out_elemsize, out_elempack, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + int _nT = nT ? 
nT : opt.num_threads; + if (nT != 0 && opt.num_threads != nT) + { + // force num_threads the same as in create_pipeline + // so we could use pre-packed A/B from the same tile config + NCNN_LOGE("opt.num_threads %d changed, gemm will use load-time value %d", opt.num_threads, nT); + } + + int ret = 0; + if (constantA && constantB) + { + ret = gemm_AT_BT_arm_int8(AT_data, A_data_int8_scales, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantM, constantN, constantK, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantA) + { + const Mat& B = bottom_blobs[0]; + ret = gemm_AT_arm_int8(AT_data, A_data_int8_scales, B, C, top_blob, broadcast_type_C, constantM, constantK, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else if (constantB) + { + const Mat& A = bottom_blobs[0]; + ret = gemm_BT_arm_int8(A, BT_data, B_data_int8_scale, C, top_blob, broadcast_type_C, constantN, constantK, transA, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + else + { + const Mat& A = bottom_blobs[0]; + const Mat& B = bottom_blobs[1]; + ret = gemm_arm_int8(A, B, C, top_blob, broadcast_type_C, transA, transB, output_transpose, constant_TILE_M, constant_TILE_N, constant_TILE_K, _nT, opt); + } + if (ret != 0) + return ret; + + // multiply top_blob with alpha + if (alpha != 1.f) + { + const int size = top_blob.total() * out_elempack; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < size; i++) + { + top_blob[i] *= alpha; + } + } + + return 0; +} +#endif + } // namespace ncnn diff --git a/src/layer/arm/gemm_arm.h b/src/layer/arm/gemm_arm.h index 0c1eab108baf..f72c4d5fa053 100644 --- a/src/layer/arm/gemm_arm.h +++ b/src/layer/arm/gemm_arm.h @@ -41,6 +41,10 @@ class Gemm_arm : public Gemm int create_pipeline_bf16s(const Option& opt); int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif public: int nT; diff --git a/src/layer/arm/gemm_arm_asimddp.cpp b/src/layer/arm/gemm_arm_asimddp.cpp new file mode 100644 index 000000000000..9d62a50ed8ec --- /dev/null +++ b/src/layer/arm/gemm_arm_asimddp.cpp @@ -0,0 +1,114 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "cpu.h" +#include "mat.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +#if NCNN_BF16 +#include "gemm_int8_bf16s.h" +#endif + +void pack_A_tile_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ + unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); +} + +void transpose_unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ + transpose_unpack_output_tile_int32_to_fp32(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); +} + +void gemm_transB_packed_tile_int8_asimddp(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +#if NCNN_BF16 +void pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_bf16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_bf16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ + unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); +} + +void 
transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ + transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); +} +#endif // NCNN_BF16 + +} // namespace ncnn diff --git a/src/layer/arm/gemm_arm_i8mm.cpp b/src/layer/arm/gemm_arm_i8mm.cpp new file mode 100644 index 000000000000..69c1941f4524 --- /dev/null +++ b/src/layer/arm/gemm_arm_i8mm.cpp @@ -0,0 +1,94 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "cpu.h" +#include "mat.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "gemm_int8.h" + +#if NCNN_BF16 +#include "gemm_int8_bf16s.h" +#endif + +void pack_A_tile_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void transpose_pack_A_tile_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ + transpose_pack_A_tile_int8(A, AT, i, max_ii, k, max_kk); +} + +void pack_B_tile_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void transpose_pack_B_tile_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ + transpose_pack_B_tile_int8(B, BT, j, max_jj, k, max_kk); +} + +void pack_A_tile_fp32_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_fp32_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_fp32_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_fp32_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_fp32_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_fp32_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void gemm_transB_packed_tile_int8_i8mm(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ + gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); +} + +#if NCNN_BF16 +void pack_A_tile_bf16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + pack_A_tile_bf16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void transpose_pack_A_tile_bf16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ + transpose_pack_A_tile_bf16_to_int8(A, AT, i, max_ii, k, max_kk, scales); +} + +void pack_B_tile_bf16_to_int8_i8mm(const 
Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} + +void transpose_pack_B_tile_bf16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ + transpose_pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale); +} +#endif // NCNN_BF16 + +} // namespace ncnn diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h new file mode 100644 index 000000000000..9ccd880ea62c --- /dev/null +++ b/src/layer/arm/gemm_int8.h @@ -0,0 +1,16045 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void gemm_transB_packed_tile_int8_i8mm(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void transpose_pack_A_tile_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk); +void pack_B_tile_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void transpose_pack_B_tile_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk); +void pack_A_tile_fp32_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_fp32_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_fp32_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_fp32_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& 
descales); +void transpose_unpack_output_tile_int32_to_fp32_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales); +void gemm_transB_packed_tile_int8_asimddp(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk); +#endif + +static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_A_tile_int8_i8mm(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_A_tile_int8_asimddp(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("pack_A_tile_int8"); + // assert A.elempack == 1 + // assert A.dims == 2 + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* p0 = A.row(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + const signed char* p2 = A.row(i + ii + 2) + k; + const signed char* p3 = A.row(i + ii + 3) + k; + const signed char* p4 = A.row(i + ii + 4) + k; + const signed char* p5 = A.row(i + ii + 5) + k; + const signed char* p6 = A.row(i + ii + 6) + k; + const signed char* p7 = A.row(i + ii + 7) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); + int8x16_t _p4 = vld1q_s8(p4); + int8x16_t _p5 = vld1q_s8(p5); + int8x16_t _p6 = vld1q_s8(p6); + int8x16_t _p7 = vld1q_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(vget_low_s8(_p0), vget_low_s8(_p1)); + int8x16_t _r1 = vcombine_s8(vget_low_s8(_p2), vget_low_s8(_p3)); + int8x16_t _r2 = vcombine_s8(vget_low_s8(_p4), vget_low_s8(_p5)); + int8x16_t _r3 = vcombine_s8(vget_low_s8(_p6), vget_low_s8(_p7)); + int8x16_t _r4 = vcombine_s8(vget_high_s8(_p0), vget_high_s8(_p1)); + int8x16_t _r5 = vcombine_s8(vget_high_s8(_p2), vget_high_s8(_p3)); + int8x16_t _r6 = vcombine_s8(vget_high_s8(_p4), vget_high_s8(_p5)); + int8x16_t _r7 = vcombine_s8(vget_high_s8(_p6), vget_high_s8(_p7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _p01 = vzipq_s32(vreinterpretq_s32_s8(_p0), vreinterpretq_s32_s8(_p1)); + int32x4x2_t _p23 = vzipq_s32(vreinterpretq_s32_s8(_p2), vreinterpretq_s32_s8(_p3)); + int32x4x2_t _p45 = vzipq_s32(vreinterpretq_s32_s8(_p4), vreinterpretq_s32_s8(_p5)); + int32x4x2_t _p67 = vzipq_s32(vreinterpretq_s32_s8(_p6), vreinterpretq_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[0]), vget_low_s32(_p23.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[0]), vget_low_s32(_p67.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[0]), vget_high_s32(_p23.val[0]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[0]), vget_high_s32(_p67.val[0]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[1]), vget_low_s32(_p23.val[1]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[1]), vget_low_s32(_p67.val[1]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[1]), vget_high_s32(_p23.val[1]))); + int8x16_t _r7 = 
vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[1]), vget_high_s32(_p67.val[1]))); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _p01 = vzipq_s16(vreinterpretq_s16_s8(_p0), vreinterpretq_s16_s8(_p1)); + int16x8x2_t _p23 = vzipq_s16(vreinterpretq_s16_s8(_p2), vreinterpretq_s16_s8(_p3)); + int16x8x2_t _p45 = vzipq_s16(vreinterpretq_s16_s8(_p4), vreinterpretq_s16_s8(_p5)); + int16x8x2_t _p67 = vzipq_s16(vreinterpretq_s16_s8(_p6), vreinterpretq_s16_s8(_p7)); + int32x4x2_t _t0 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[0]), vreinterpretq_s32_s16(_p23.val[0])); + int32x4x2_t _t1 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[1]), vreinterpretq_s32_s16(_p23.val[1])); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[0]), vreinterpretq_s32_s16(_p67.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[1]), vreinterpretq_s32_s16(_p67.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t2.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t2.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t2.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t2.val[1]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r7 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + vst1q_s8(pp + 64, _r4); + vst1q_s8(pp + 80, _r5); + vst1q_s8(pp + 96, _r6); + vst1q_s8(pp + 112, _r7); + pp += 128; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + p4 += 16; + p5 += 16; + p6 += 16; + p7 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); + int8x8_t _p4 = vld1_s8(p4); + int8x8_t _p5 = vld1_s8(p5); + int8x8_t _p6 = vld1_s8(p6); + int8x8_t _p7 = vld1_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(_p0, _p1); + int8x16_t _r1 = vcombine_s8(_p2, _p3); + int8x16_t _r2 = vcombine_s8(_p4, _p5); + int8x16_t _r3 = vcombine_s8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _p01 = vzip_s32(vreinterpret_s32_s8(_p0), vreinterpret_s32_s8(_p1)); + int32x2x2_t _p23 = vzip_s32(vreinterpret_s32_s8(_p2), vreinterpret_s32_s8(_p3)); + int32x2x2_t _p45 = vzip_s32(vreinterpret_s32_s8(_p4), vreinterpret_s32_s8(_p5)); + int32x2x2_t _p67 = vzip_s32(vreinterpret_s32_s8(_p6), vreinterpret_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[0], _p23.val[0])); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[0], _p67.val[0])); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[1], _p23.val[1])); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[1], _p67.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8_t _p04 = vreinterpretq_s16_s8(vcombine_s8(_p0, _p4)); + int16x8_t _p15 = vreinterpretq_s16_s8(vcombine_s8(_p1, _p5)); + int16x8_t _p26 = 
vreinterpretq_s16_s8(vcombine_s8(_p2, _p6)); + int16x8_t _p37 = vreinterpretq_s16_s8(vcombine_s8(_p3, _p7)); + int16x8x2_t _t0 = vzipq_s16(_p04, _p15); + int16x8x2_t _t1 = vzipq_s16(_p26, _p37); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[0]), vreinterpretq_s32_s16(_t1.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[1]), vreinterpretq_s32_s16(_t1.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + pp += 64; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + p4 += 8; + p5 += 8; + p6 += 8; + p7 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; + pp[16] = p4[0]; + pp[17] = p4[1]; + pp[18] = p4[2]; + pp[19] = p4[3]; + pp[20] = p5[0]; + pp[21] = p5[1]; + pp[22] = p5[2]; + pp[23] = p5[3]; + pp[24] = p6[0]; + pp[25] = p6[1]; + pp[26] = p6[2]; + pp[27] = p6[3]; + pp[28] = p7[0]; + pp[29] = p7[1]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp[16] = p0[2]; + pp[17] = p0[3]; + pp[18] = p1[2]; + pp[19] = p1[3]; + pp[20] = p2[2]; + pp[21] = p2[3]; + pp[22] = p3[2]; + pp[23] = p3[3]; + pp[24] = p4[2]; + pp[25] = p4[3]; + pp[26] = p5[2]; + pp[27] = p5[3]; + pp[28] = p6[2]; + pp[29] = p6[3]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp += 16; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + p4 += 2; + p5 += 2; + p6 += 2; + p7 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* p0 = A.row(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + const signed char* p2 = A.row(i + ii + 2) + k; + const signed char* p3 = A.row(i + ii + 3) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if 
__ARM_FEATURE_MATMUL_INT8 + int64x2x4_t _r0123; + _r0123.val[0] = vreinterpretq_s64_s8(_p0); + _r0123.val[1] = vreinterpretq_s64_s8(_p1); + _r0123.val[2] = vreinterpretq_s64_s8(_p2); + _r0123.val[3] = vreinterpretq_s64_s8(_p3); + vst4q_s64((int64_t*)pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x4_t _r0123; + _r0123.val[0] = vreinterpretq_s32_s8(_p0); + _r0123.val[1] = vreinterpretq_s32_s8(_p1); + _r0123.val[2] = vreinterpretq_s32_s8(_p2); + _r0123.val[3] = vreinterpretq_s32_s8(_p3); + vst4q_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x4_t _r0123; + _r0123.val[0] = vreinterpretq_s16_s8(_p0); + _r0123.val[1] = vreinterpretq_s16_s8(_p1); + _r0123.val[2] = vreinterpretq_s16_s8(_p2); + _r0123.val[3] = vreinterpretq_s16_s8(_p3); + vst4q_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 64; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); + vst1q_s8(pp + 16, vcombine_s8(_p2, _p3)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x4_t _r0123; + _r0123.val[0] = vreinterpret_s32_s8(_p0); + _r0123.val[1] = vreinterpret_s32_s8(_p1); + _r0123.val[2] = vreinterpret_s32_s8(_p2); + _r0123.val[3] = vreinterpret_s32_s8(_p3); + vst4_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x4_t _r0123; + _r0123.val[0] = vreinterpret_s16_s8(_p0); + _r0123.val[1] = vreinterpret_s16_s8(_p1); + _r0123.val[2] = vreinterpret_s16_s8(_p2); + _r0123.val[3] = vreinterpret_s16_s8(_p3); + vst4_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p0[2]; + pp[9] = p0[3]; + pp[10] = p1[2]; + pp[11] = p1[3]; + pp[12] = p2[2]; + pp[13] = p2[3]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* p0 = A.row(i + ii) + k; + const signed char* p1 = A.row(i + ii + 1) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int64x2x2_t _r01; + _r01.val[0] = vreinterpretq_s64_s8(_p0); + _r01.val[1] = vreinterpretq_s64_s8(_p1); + vst2q_s64((int64_t*)pp, _r01); +#else // __ARM_FEATURE_MATMUL_INT8 + 
int32x4x2_t _r01; + _r01.val[0] = vreinterpretq_s32_s8(_p0); + _r01.val[1] = vreinterpretq_s32_s8(_p1); + vst2q_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _r01; + _r01.val[0] = vreinterpretq_s16_s8(_p0); + _r01.val[1] = vreinterpretq_s16_s8(_p1); + vst2q_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 16; + p1 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _r01; + _r01.val[0] = vreinterpret_s32_s8(_p0); + _r01.val[1] = vreinterpret_s32_s8(_p1); + vst2_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x2_t _r01; + _r01.val[0] = vreinterpret_s16_s8(_p0); + _r01.val[1] = vreinterpret_s16_s8(_p1); + vst2_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 8; + p1 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p0[2]; + pp[5] = p0[3]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 8; + p0 += 4; + p1 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + for (; ii < max_ii; ii += 1) + { + const signed char* p0 = A.row(i + ii) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + vst1q_s8(pp, vld1q_s8(p0)); + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + vst1_s8(pp, vld1_s8(p0)); + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } +} + +static void transpose_pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_A_tile_int8_i8mm(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_A_tile_int8_asimddp(A, AT, i, max_ii, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_A_tile_int8"); + // assert A.elempack == 1 + // assert A.dims == 2 + + const int A_hstep = A.w; + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _r0 = vld1_s8(p0); + int8x8_t _r1 = vld1_s8(p0 + A_hstep); + int8x8_t _r2 = vld1_s8(p0 + A_hstep * 2); + int8x8_t _r3 = vld1_s8(p0 + A_hstep * 3); + int8x8_t _r4 = vld1_s8(p0 + A_hstep * 4); + int8x8_t _r5 = vld1_s8(p0 + A_hstep * 5); + int8x8_t _r6 = vld1_s8(p0 + A_hstep * 6); + int8x8_t _r7 = vld1_s8(p0 + A_hstep * 7); + // transpose8x8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t 
_r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x8x4_t _r0123; + _r0123.val[0] = _r04.val[0]; + _r0123.val[1] = _r15.val[0]; + _r0123.val[2] = _r26.val[0]; + _r0123.val[3] = _r37.val[0]; + int8x8x4_t _r4567; + _r4567.val[0] = _r04.val[1]; + _r4567.val[1] = _r15.val[1]; + _r4567.val[2] = _r26.val[1]; + _r4567.val[3] = _r37.val[1]; + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); + pp += 64; + p0 += A_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8x4_t _r0123; + _r0123.val[0] = vld1_s8(p0); + _r0123.val[1] = vld1_s8(p0 + A_hstep); + _r0123.val[2] = vld1_s8(p0 + A_hstep * 2); + _r0123.val[3] = vld1_s8(p0 + A_hstep * 3); + vst4_s8(pp, _r0123); + pp += 32; + p0 += A_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + int8x8x2_t _r01; + _r01.val[0] = vld1_s8(p0); + _r01.val[1] = vld1_s8(p0 + A_hstep); + vst2_s8(pp, _r01); + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + vst1_s8(pp, vld1_s8(p0)); + pp += 8; + p0 += A_hstep; + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[A_hstep * 4]; + pp[5] = p0[A_hstep * 5]; + pp[6] = p0[A_hstep * 6]; + pp[7] = p0[A_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[A_hstep + 1]; + pp[10] = p0[A_hstep * 2 + 1]; + pp[11] = p0[A_hstep * 3 + 1]; + pp[12] = p0[A_hstep * 4 + 1]; + pp[13] = p0[A_hstep * 5 + 1]; + pp[14] = p0[A_hstep * 6 + 1]; + pp[15] = p0[A_hstep * 7 + 1]; + pp[16] = p0[2]; + pp[17] = p0[A_hstep + 2]; + pp[18] = p0[A_hstep * 2 + 2]; + pp[19] = p0[A_hstep * 3 + 2]; + pp[20] = p0[A_hstep * 4 + 2]; + pp[21] = p0[A_hstep * 5 + 2]; + pp[22] = p0[A_hstep * 6 + 2]; + pp[23] = p0[A_hstep * 7 + 2]; + pp[24] = p0[3]; + pp[25] = p0[A_hstep + 3]; + pp[26] = p0[A_hstep * 2 + 3]; + pp[27] = p0[A_hstep * 3 + 3]; + pp[28] = p0[A_hstep * 4 + 3]; + pp[29] = p0[A_hstep * 5 + 3]; + pp[30] = p0[A_hstep * 6 + 3]; + pp[31] = p0[A_hstep * 7 + 3]; + pp += 32; + p0 += A_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp[8] = p0[2]; + pp[9] = p0[A_hstep + 2]; + pp[10] = p0[A_hstep * 2 + 2]; + pp[11] = p0[A_hstep * 3 + 2]; + pp[12] = p0[3]; + pp[13] = p0[A_hstep + 3]; + pp[14] = p0[A_hstep * 2 + 3]; + pp[15] = p0[A_hstep * 3 + 3]; + pp += 16; + p0 += A_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[A_hstep + 2]; + pp[6] = p0[3]; + pp[7] = p0[A_hstep + 3]; + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += A_hstep; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[A_hstep * 4]; + pp[5] = p0[A_hstep * 
5]; + pp[6] = p0[A_hstep * 6]; + pp[7] = p0[A_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[A_hstep + 1]; + pp[10] = p0[A_hstep * 2 + 1]; + pp[11] = p0[A_hstep * 3 + 1]; + pp[12] = p0[A_hstep * 4 + 1]; + pp[13] = p0[A_hstep * 5 + 1]; + pp[14] = p0[A_hstep * 6 + 1]; + pp[15] = p0[A_hstep * 7 + 1]; + pp += 16; + p0 += A_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[A_hstep * 2]; + pp[3] = p0[A_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[A_hstep + 1]; + pp[6] = p0[A_hstep * 2 + 1]; + pp[7] = p0[A_hstep * 3 + 1]; + pp += 8; + p0 += A_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[A_hstep]; + pp[2] = p0[1]; + pp[3] = p0[A_hstep + 1]; + pp += 4; + p0 += A_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += A_hstep; + } + } + for (; ii < max_ii; ii += 1) + { + const signed char* p0 = A.row(k) + (i + ii); + + int kk = 0; + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = p0[0]; + // pp[1] = p0[A_hstep]; + // pp += 2; + // p0 += A_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += A_hstep; + } + } +} + +static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_B_tile_int8_i8mm(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_B_tile_int8_asimddp(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("pack_B_tile_int8"); + // assert B.elempack == 1 + // assert B.dims == 2 + + signed char* pp = BT; + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + const signed char* p2 = B.row(j + jj + 2) + k; + const signed char* p3 = B.row(j + jj + 3) + k; + const signed char* p4 = B.row(j + jj + 4) + k; + const signed char* p5 = B.row(j + jj + 5) + k; + const signed char* p6 = B.row(j + jj + 6) + k; + const signed char* p7 = B.row(j + jj + 7) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); + int8x16_t _p4 = vld1q_s8(p4); + int8x16_t _p5 = vld1q_s8(p5); + int8x16_t _p6 = vld1q_s8(p6); + int8x16_t _p7 = vld1q_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(vget_low_s8(_p0), vget_low_s8(_p1)); + int8x16_t _r1 = vcombine_s8(vget_low_s8(_p2), vget_low_s8(_p3)); + int8x16_t _r2 = vcombine_s8(vget_low_s8(_p4), vget_low_s8(_p5)); + int8x16_t _r3 = vcombine_s8(vget_low_s8(_p6), vget_low_s8(_p7)); + int8x16_t _r4 = vcombine_s8(vget_high_s8(_p0), vget_high_s8(_p1)); + int8x16_t _r5 = vcombine_s8(vget_high_s8(_p2), vget_high_s8(_p3)); + int8x16_t _r6 = vcombine_s8(vget_high_s8(_p4), vget_high_s8(_p5)); + int8x16_t _r7 = vcombine_s8(vget_high_s8(_p6), vget_high_s8(_p7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _p01 = vzipq_s32(vreinterpretq_s32_s8(_p0), vreinterpretq_s32_s8(_p1)); + int32x4x2_t _p23 = vzipq_s32(vreinterpretq_s32_s8(_p2), vreinterpretq_s32_s8(_p3)); + int32x4x2_t _p45 = 
vzipq_s32(vreinterpretq_s32_s8(_p4), vreinterpretq_s32_s8(_p5)); + int32x4x2_t _p67 = vzipq_s32(vreinterpretq_s32_s8(_p6), vreinterpretq_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[0]), vget_low_s32(_p23.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[0]), vget_low_s32(_p67.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[0]), vget_high_s32(_p23.val[0]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[0]), vget_high_s32(_p67.val[0]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p01.val[1]), vget_low_s32(_p23.val[1]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_p45.val[1]), vget_low_s32(_p67.val[1]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p01.val[1]), vget_high_s32(_p23.val[1]))); + int8x16_t _r7 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_p45.val[1]), vget_high_s32(_p67.val[1]))); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _p01 = vzipq_s16(vreinterpretq_s16_s8(_p0), vreinterpretq_s16_s8(_p1)); + int16x8x2_t _p23 = vzipq_s16(vreinterpretq_s16_s8(_p2), vreinterpretq_s16_s8(_p3)); + int16x8x2_t _p45 = vzipq_s16(vreinterpretq_s16_s8(_p4), vreinterpretq_s16_s8(_p5)); + int16x8x2_t _p67 = vzipq_s16(vreinterpretq_s16_s8(_p6), vreinterpretq_s16_s8(_p7)); + int32x4x2_t _t0 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[0]), vreinterpretq_s32_s16(_p23.val[0])); + int32x4x2_t _t1 = vzipq_s32(vreinterpretq_s32_s16(_p01.val[1]), vreinterpretq_s32_s16(_p23.val[1])); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[0]), vreinterpretq_s32_s16(_p67.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_p45.val[1]), vreinterpretq_s32_s16(_p67.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t2.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t2.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t2.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t2.val[1]))); + int8x16_t _r4 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r5 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r6 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r7 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + vst1q_s8(pp + 64, _r4); + vst1q_s8(pp + 80, _r5); + vst1q_s8(pp + 96, _r6); + vst1q_s8(pp + 112, _r7); + pp += 128; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + p4 += 16; + p5 += 16; + p6 += 16; + p7 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); + int8x8_t _p4 = vld1_s8(p4); + int8x8_t _p5 = vld1_s8(p5); + int8x8_t _p6 = vld1_s8(p6); + int8x8_t _p7 = vld1_s8(p7); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _r0 = vcombine_s8(_p0, _p1); + int8x16_t _r1 = vcombine_s8(_p2, _p3); + int8x16_t _r2 = vcombine_s8(_p4, _p5); + int8x16_t _r3 = vcombine_s8(_p6, _p7); 
+#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _p01 = vzip_s32(vreinterpret_s32_s8(_p0), vreinterpret_s32_s8(_p1)); + int32x2x2_t _p23 = vzip_s32(vreinterpret_s32_s8(_p2), vreinterpret_s32_s8(_p3)); + int32x2x2_t _p45 = vzip_s32(vreinterpret_s32_s8(_p4), vreinterpret_s32_s8(_p5)); + int32x2x2_t _p67 = vzip_s32(vreinterpret_s32_s8(_p6), vreinterpret_s32_s8(_p7)); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[0], _p23.val[0])); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[0], _p67.val[0])); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[1], _p23.val[1])); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[1], _p67.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8_t _p04 = vreinterpretq_s16_s8(vcombine_s8(_p0, _p4)); + int16x8_t _p15 = vreinterpretq_s16_s8(vcombine_s8(_p1, _p5)); + int16x8_t _p26 = vreinterpretq_s16_s8(vcombine_s8(_p2, _p6)); + int16x8_t _p37 = vreinterpretq_s16_s8(vcombine_s8(_p3, _p7)); + int16x8x2_t _t0 = vzipq_s16(_p04, _p15); + int16x8x2_t _t1 = vzipq_s16(_p26, _p37); + int32x4x2_t _t2 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[0]), vreinterpretq_s32_s16(_t1.val[0])); + int32x4x2_t _t3 = vzipq_s32(vreinterpretq_s32_s16(_t0.val[1]), vreinterpretq_s32_s16(_t1.val[1])); + int8x16_t _r0 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0]))); + int8x16_t _r1 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0]))); + int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1]))); + int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1]))); +#endif // __ARM_FEATURE_DOTPROD + vst1q_s8(pp, _r0); + vst1q_s8(pp + 16, _r1); + vst1q_s8(pp + 32, _r2); + vst1q_s8(pp + 48, _r3); + pp += 64; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + p4 += 8; + p5 += 8; + p6 += 8; + p7 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; + pp[16] = p4[0]; + pp[17] = p4[1]; + pp[18] = p4[2]; + pp[19] = p4[3]; + pp[20] = p5[0]; + pp[21] = p5[1]; + pp[22] = p5[2]; + pp[23] = p5[3]; + pp[24] = p6[0]; + pp[25] = p6[1]; + pp[26] = p6[2]; + pp[27] = p6[3]; + pp[28] = p7[0]; + pp[29] = p7[1]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp[16] = p0[2]; + pp[17] = p0[3]; + pp[18] = p1[2]; + pp[19] = p1[3]; + pp[20] = p2[2]; + pp[21] = p2[3]; + pp[22] = p3[2]; + pp[23] = p3[3]; + pp[24] = p4[2]; + pp[25] = p4[3]; + pp[26] = p5[2]; + pp[27] = p5[3]; + pp[28] = p6[2]; + pp[29] = p6[3]; + pp[30] = p7[2]; + pp[31] = p7[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; + p4 += 4; + p5 += 4; + p6 += 4; + p7 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p4[0]; + pp[9] = p4[1]; + pp[10] = 
p5[0]; + pp[11] = p5[1]; + pp[12] = p6[0]; + pp[13] = p6[1]; + pp[14] = p7[0]; + pp[15] = p7[1]; + pp += 16; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + p4 += 2; + p5 += 2; + p6 += 2; + p7 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp[4] = p4[0]; + pp[5] = p5[0]; + pp[6] = p6[0]; + pp[7] = p7[0]; + pp += 8; + p0++; + p1++; + p2++; + p3++; + p4++; + p5++; + p6++; + p7++; + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + const signed char* p2 = B.row(j + jj + 2) + k; + const signed char* p3 = B.row(j + jj + 3) + k; + + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); + int8x16_t _p2 = vld1q_s8(p2); + int8x16_t _p3 = vld1q_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int64x2x4_t _r0123; + _r0123.val[0] = vreinterpretq_s64_s8(_p0); + _r0123.val[1] = vreinterpretq_s64_s8(_p1); + _r0123.val[2] = vreinterpretq_s64_s8(_p2); + _r0123.val[3] = vreinterpretq_s64_s8(_p3); + vst4q_s64((int64_t*)pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x4_t _r0123; + _r0123.val[0] = vreinterpretq_s32_s8(_p0); + _r0123.val[1] = vreinterpretq_s32_s8(_p1); + _r0123.val[2] = vreinterpretq_s32_s8(_p2); + _r0123.val[3] = vreinterpretq_s32_s8(_p3); + vst4q_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x4_t _r0123; + _r0123.val[0] = vreinterpretq_s16_s8(_p0); + _r0123.val[1] = vreinterpretq_s16_s8(_p1); + _r0123.val[2] = vreinterpretq_s16_s8(_p2); + _r0123.val[3] = vreinterpretq_s16_s8(_p3); + vst4q_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 64; + p0 += 16; + p1 += 16; + p2 += 16; + p3 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); + int8x8_t _p2 = vld1_s8(p2); + int8x8_t _p3 = vld1_s8(p3); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); + vst1q_s8(pp + 16, vcombine_s8(_p2, _p3)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x4_t _r0123; + _r0123.val[0] = vreinterpret_s32_s8(_p0); + _r0123.val[1] = vreinterpret_s32_s8(_p1); + _r0123.val[2] = vreinterpret_s32_s8(_p2); + _r0123.val[3] = vreinterpret_s32_s8(_p3); + vst4_s32((int*)pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x4_t _r0123; + _r0123.val[0] = vreinterpret_s16_s8(_p0); + _r0123.val[1] = vreinterpret_s16_s8(_p1); + _r0123.val[2] = vreinterpret_s16_s8(_p2); + _r0123.val[3] = vreinterpret_s16_s8(_p3); + vst4_s16((short*)pp, _r0123); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 8; + p1 += 8; + p2 += 8; + p3 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; + pp[8] = p2[0]; + pp[9] = p2[1]; + pp[10] = p2[2]; + pp[11] = p2[3]; + pp[12] = p3[0]; + pp[13] = p3[1]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp[8] = p0[2]; + pp[9] = p0[3]; + pp[10] = p1[2]; + pp[11] = p1[3]; + pp[12] = p2[2]; + pp[13] = p2[3]; + pp[14] = p3[2]; + pp[15] = p3[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 4; + p1 += 4; + p2 += 4; + p3 += 4; 
+ } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p2[0]; + pp[5] = p2[1]; + pp[6] = p3[0]; + pp[7] = p3[1]; + pp += 8; + p0 += 2; + p1 += 2; + p2 += 2; + p3 += 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp[2] = p2[0]; + pp[3] = p3[0]; + pp += 4; + p0++; + p1++; + p2++; + p3++; + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* p0 = B.row(j + jj) + k; + const signed char* p1 = B.row(j + jj + 1) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _p0 = vld1q_s8(p0); + int8x16_t _p1 = vld1q_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int64x2x2_t _r01; + _r01.val[0] = vreinterpretq_s64_s8(_p0); + _r01.val[1] = vreinterpretq_s64_s8(_p1); + vst2q_s64((int64_t*)pp, _r01); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _r01; + _r01.val[0] = vreinterpretq_s32_s8(_p0); + _r01.val[1] = vreinterpretq_s32_s8(_p1); + vst2q_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x8x2_t _r01; + _r01.val[0] = vreinterpretq_s16_s8(_p0); + _r01.val[1] = vreinterpretq_s16_s8(_p1); + vst2q_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 32; + p0 += 16; + p1 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _p0 = vld1_s8(p0); + int8x8_t _p1 = vld1_s8(p1); +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + vst1q_s8(pp, vcombine_s8(_p0, _p1)); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _r01; + _r01.val[0] = vreinterpret_s32_s8(_p0); + _r01.val[1] = vreinterpret_s32_s8(_p1); + vst2_s32((int*)pp, _r01); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4x2_t _r01; + _r01.val[0] = vreinterpret_s16_s8(_p0); + _r01.val[1] = vreinterpret_s16_s8(_p1); + vst2_s16((short*)pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + pp += 16; + p0 += 8; + p1 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp[4] = p1[0]; + pp[5] = p1[1]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#else // __ARM_FEATURE_DOTPROD + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp[4] = p0[2]; + pp[5] = p0[3]; + pp[6] = p1[2]; + pp[7] = p1[3]; +#endif // __ARM_FEATURE_DOTPROD + pp += 8; + p0 += 4; + p1 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p1[0]; + pp[3] = p1[1]; + pp += 4; + p0 += 2; + p1 += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p1[0]; + pp += 2; + p0++; + p1++; + } + } + for (; jj < max_jj; jj += 1) + { + const signed char* p0 = B.row(j + jj) + k; + + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + vst1q_s8(pp, vld1q_s8(p0)); + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + vst1_s8(pp, vld1_s8(p0)); + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0++; + } + } +} + +static void transpose_pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_B_tile_int8_i8mm(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + 
transpose_pack_B_tile_int8_asimddp(B, BT, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("transpose_pack_B_tile_int8"); + // assert B.elempack == 1 + // assert B.dims == 2 + + const int B_hstep = B.w; + + signed char* pp = BT; + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _r0 = vld1_s8(p0); + int8x8_t _r1 = vld1_s8(p0 + B_hstep); + int8x8_t _r2 = vld1_s8(p0 + B_hstep * 2); + int8x8_t _r3 = vld1_s8(p0 + B_hstep * 3); + int8x8_t _r4 = vld1_s8(p0 + B_hstep * 4); + int8x8_t _r5 = vld1_s8(p0 + B_hstep * 5); + int8x8_t _r6 = vld1_s8(p0 + B_hstep * 6); + int8x8_t _r7 = vld1_s8(p0 + B_hstep * 7); + // transpose8x8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x8x4_t _r0123; + _r0123.val[0] = _r04.val[0]; + _r0123.val[1] = _r15.val[0]; + _r0123.val[2] = _r26.val[0]; + _r0123.val[3] = _r37.val[0]; + int8x8x4_t _r4567; + _r4567.val[0] = _r04.val[1]; + _r4567.val[1] = _r15.val[1]; + _r4567.val[2] = _r26.val[1]; + _r4567.val[3] = _r37.val[1]; + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); + pp += 64; + p0 += B_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8x4_t _r0123; + _r0123.val[0] = vld1_s8(p0); + _r0123.val[1] = vld1_s8(p0 + B_hstep); + _r0123.val[2] = vld1_s8(p0 + B_hstep * 2); + _r0123.val[3] = vld1_s8(p0 + B_hstep * 3); + vst4_s8(pp, _r0123); + pp += 32; + p0 += B_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + int8x8x2_t _r01; + _r01.val[0] = vld1_s8(p0); + _r01.val[1] = vld1_s8(p0 + B_hstep); + vst2_s8(pp, _r01); + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + vst1_s8(pp, vld1_s8(p0)); + pp += 8; + p0 += B_hstep; + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[B_hstep * 4]; + pp[5] = p0[B_hstep * 5]; + pp[6] = p0[B_hstep * 6]; + pp[7] = p0[B_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[B_hstep + 1]; + pp[10] = p0[B_hstep * 2 + 1]; + pp[11] = p0[B_hstep * 3 + 1]; + pp[12] = p0[B_hstep * 4 + 1]; + pp[13] = p0[B_hstep * 5 + 1]; + pp[14] = p0[B_hstep * 6 + 1]; + pp[15] = p0[B_hstep * 7 + 1]; + pp[16] = p0[2]; + pp[17] = p0[B_hstep + 2]; + pp[18] = p0[B_hstep * 2 + 2]; + pp[19] = p0[B_hstep * 3 + 2]; + pp[20] = p0[B_hstep * 4 + 2]; + pp[21] = p0[B_hstep * 5 + 2]; + pp[22] = p0[B_hstep * 6 + 2]; + pp[23] = p0[B_hstep * 7 + 2]; + pp[24] = p0[3]; + pp[25] = p0[B_hstep + 3]; + pp[26] = p0[B_hstep * 2 + 3]; + pp[27] = p0[B_hstep * 3 + 3]; + pp[28] = p0[B_hstep * 4 + 3]; + pp[29] = p0[B_hstep * 5 + 3]; + pp[30] = p0[B_hstep * 6 + 3]; + pp[31] = p0[B_hstep * 7 + 3]; + pp += 32; + p0 += B_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[B_hstep + 1]; + pp[6] = p0[B_hstep * 2 + 1]; + pp[7] = p0[B_hstep * 3 + 1]; + pp[8] = p0[2]; + pp[9] = p0[B_hstep + 2]; + pp[10] = p0[B_hstep * 2 + 2]; + pp[11] = p0[B_hstep 
* 3 + 2]; + pp[12] = p0[3]; + pp[13] = p0[B_hstep + 3]; + pp[14] = p0[B_hstep * 2 + 3]; + pp[15] = p0[B_hstep * 3 + 3]; + pp += 16; + p0 += B_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp[4] = p0[2]; + pp[5] = p0[B_hstep + 2]; + pp[6] = p0[3]; + pp[7] = p0[B_hstep + 3]; + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp[2] = p0[2]; + pp[3] = p0[3]; + pp += 4; + p0 += B_hstep; + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[B_hstep * 4]; + pp[5] = p0[B_hstep * 5]; + pp[6] = p0[B_hstep * 6]; + pp[7] = p0[B_hstep * 7]; + pp[8] = p0[1]; + pp[9] = p0[B_hstep + 1]; + pp[10] = p0[B_hstep * 2 + 1]; + pp[11] = p0[B_hstep * 3 + 1]; + pp[12] = p0[B_hstep * 4 + 1]; + pp[13] = p0[B_hstep * 5 + 1]; + pp[14] = p0[B_hstep * 6 + 1]; + pp[15] = p0[B_hstep * 7 + 1]; + pp += 16; + p0 += B_hstep * 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[B_hstep * 2]; + pp[3] = p0[B_hstep * 3]; + pp[4] = p0[1]; + pp[5] = p0[B_hstep + 1]; + pp[6] = p0[B_hstep * 2 + 1]; + pp[7] = p0[B_hstep * 3 + 1]; + pp += 8; + p0 += B_hstep * 4; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = p0[0]; + pp[1] = p0[B_hstep]; + pp[2] = p0[1]; + pp[3] = p0[B_hstep + 1]; + pp += 4; + p0 += B_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp[1] = p0[1]; + pp += 2; + p0 += B_hstep; + } + } + for (; jj < max_jj; jj += 1) + { + const signed char* p0 = B.row(k) + (j + jj); + + int kk = 0; + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = p0[0]; + // pp[1] = p0[B_hstep]; + // pp += 2; + // p0 += B_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = p0[0]; + pp += 1; + p0 += B_hstep; + } + } +} + +static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + const int K = A.w; + + // NCNN_LOGE("compute_A_tile_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif + + for (int ii = 0; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); + + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + +#endif + ps += 4; + pods += 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + for (int ii = 0; ii < max_ii; ii++) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep; + + float absmax = 0.f; + int kk = 0; +#if __ARM_NEON + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (; kk + 15 < K; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 7 < K; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + _absmax0 = vmaxq_f32(_absmax0, 
vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif // __ARM_NEON + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabs(p0[0])); + p0++; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_A_tile_fp32_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_A_tile_fp32_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + // NCNN_LOGE("pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + float32x4x4_t _r = vld4q_f32(p0 + A_hstep * 4); + float32x4x4_t _s = vld4q_f32(p0 + A_hstep * 4 + 16); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(_q.val[0], _scale0, 0); + float32x4_t _p5 = vmulq_laneq_f32(_q.val[1], _scale0, 1); + float32x4_t _p6 = vmulq_laneq_f32(_q.val[2], _scale0, 2); + float32x4_t _p7 = vmulq_laneq_f32(_q.val[3], _scale0, 3); + float32x4_t _p8 = vmulq_laneq_f32(_r.val[0], _scale1, 0); + float32x4_t _p9 = vmulq_laneq_f32(_r.val[1], _scale1, 1); + float32x4_t _pa = vmulq_laneq_f32(_r.val[2], _scale1, 2); + float32x4_t _pb = vmulq_laneq_f32(_r.val[3], _scale1, 3); + float32x4_t _pc = vmulq_laneq_f32(_s.val[0], _scale1, 0); + float32x4_t _pd = vmulq_laneq_f32(_s.val[1], _scale1, 1); + float32x4_t _pe = vmulq_laneq_f32(_s.val[2], _scale1, 2); + float32x4_t _pf = vmulq_laneq_f32(_s.val[3], _scale1, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t 
_r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 4 + 28); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale0); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale0); + _p8 = vmulq_f32(_p8, _scale1); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale1); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale1); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale1); + _pf = vmulq_f32(_pf, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + A_hstep * 4); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(_q.val[0], _scale1, 0); + float32x4_t _p5 = vmulq_laneq_f32(_q.val[1], _scale1, 1); + float32x4_t _p6 = vmulq_laneq_f32(_q.val[2], _scale1, 2); + float32x4_t _p7 = vmulq_laneq_f32(_q.val[3], _scale1, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 4 + 12); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale1); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale1); + _p7 = vmulq_f32(_p7, 
_scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p0n = vld1q_f32(p0 + 4); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p1n = vld1q_f32(p0 + A_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p0n = vmulq_f32(_p0n, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p1n = vmulq_f32(_p1n, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p0n, _p1n); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 7 + 4); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 0); + _p2 = vmulq_laneq_f32(_p2, _scale0, 1); + _p3 = vmulq_laneq_f32(_p3, _scale0, 1); + _p4 = vmulq_laneq_f32(_p4, _scale0, 2); + _p5 = vmulq_laneq_f32(_p5, _scale0, 2); + _p6 = vmulq_laneq_f32(_p6, _scale0, 3); + _p7 = vmulq_laneq_f32(_p7, _scale0, 3); + _p8 = vmulq_laneq_f32(_p8, _scale1, 0); + _p9 = vmulq_laneq_f32(_p9, _scale1, 0); + _pa = vmulq_laneq_f32(_pa, _scale1, 1); + _pb = vmulq_laneq_f32(_pb, _scale1, 1); + _pc = vmulq_laneq_f32(_pc, _scale1, 2); + _pd = vmulq_laneq_f32(_pd, _scale1, 2); + _pe = vmulq_laneq_f32(_pe, _scale1, 3); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, 
_pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 7); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + A_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + A_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + A_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + A_hstep * 7); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + float32x4x2_t _scale01 = vzipq_f32(_scale0, _scale0); + float32x4x2_t _scale23 = vzipq_f32(_scale1, _scale1); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + 
_p45 = vmulq_f32(_p45, _scale23.val[0]); + _p67 = vmulq_f32(_p67, _scale23.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[A_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[A_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[A_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[A_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale, 3); + float32x4_t _p4 = vmulq_laneq_f32(_q.val[0], _scale, 0); + float32x4_t _p5 = vmulq_laneq_f32(_q.val[1], _scale, 1); + float32x4_t _p6 = vmulq_laneq_f32(_q.val[2], _scale, 2); + float32x4_t _p7 = vmulq_laneq_f32(_q.val[3], _scale, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + + float32x4_t _p0 = vmulq_laneq_f32(_p.val[0], _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(_p.val[1], _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(_p.val[2], _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(_p.val[3], _scale, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // 
__ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 0); + _p2 = vmulq_laneq_f32(_p2, _scale, 1); + _p3 = vmulq_laneq_f32(_p3, _scale, 1); + _p4 = vmulq_laneq_f32(_p4, _scale, 2); + _p5 = vmulq_laneq_f32(_p5, _scale, 2); + _p6 = vmulq_laneq_f32(_p6, _scale, 3); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else 
// __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + float32x4x2_t _scale01 = vzipq_f32(_scale, _scale); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale1); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = 
float2int8(p0[1] * scale0); + pp[2] = float2int8(p0[A_hstep] * scale1); + pp[3] = float2int8(p0[A_hstep + 1] * scale1); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(p0[0] * scale0); + // pp[1] = float2int8(p0[1] * scale0); + // pp[2] = float2int8(p0[A_hstep] * scale1); + // pp[3] = float2int8(p0[A_hstep + 1] * scale1); + // pp += 4; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[A_hstep] * scale1); + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; + + const float scale = scales[ii]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(p0[0] * scale); + // pp[1] = float2int8(p0[1] * scale); + // pp += 2; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const int K = A.dims == 3 ? 
A.c : A.h; + + // NCNN_LOGE("transpose_compute_A_tile_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + +#if __ARM_NEON +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif +#endif + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { + int ii = 0; + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (int kk = 0; kk < K; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); + float32x2_t _aa2 = vmax_f32(vget_low_f32(_absmax2), vget_high_f32(_absmax2)); + float32x2_t _aa3 = vmax_f32(vget_low_f32(_absmax3), vget_high_f32(_absmax3)); + float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); + float32x2_t _aa23 = vpmax_f32(_aa2, _aa3); + float32x4_t _absmax = vcombine_f32(_aa01, _aa23); + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax); + float32x4_t _out_descale = vdivq_f32(_absmax, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } + for (; ii < max_ii; ii++) + { + const float* p0 = (const float*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 8); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 12); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 8; + } + 
_absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep * 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int ii = 0; +#if __ARM_NEON + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + (i + ii); + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 2; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = vld1q_f32(p0); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + (i + ii); + + float32x2_t _absmax0 = vdup_n_f32(0.f); + float32x2_t _absmax1 = vdup_n_f32(0.f); + float32x2_t _absmax2 = vdup_n_f32(0.f); + float32x2_t _absmax3 = vdup_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + _absmax0 = vmax_f32(_absmax0, vabs_f32(_p0)); + _absmax1 = vmax_f32(_absmax1, vabs_f32(_p1)); + _absmax2 = vmax_f32(_absmax2, vabs_f32(_p2)); + _absmax3 = vmax_f32(_absmax3, vabs_f32(_p3)); + p0 += A_hstep * 4; + } + _absmax0 = vmax_f32(_absmax0, _absmax2); + _absmax1 = vmax_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = 
vld1_f32(p0 + A_hstep); + _absmax0 = vmax_f32(_absmax0, vabs_f32(_p0)); + _absmax1 = vmax_f32(_absmax1, vabs_f32(_p1)); + p0 += A_hstep * 2; + } + _absmax0 = vmax_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x2_t _p = vld1_f32(p0); + _absmax0 = vmax_f32(_absmax0, vabs_f32(_p)); + p0 += A_hstep; + } + +#if __aarch64__ + float32x2_t _scale = vdiv_f32(vget_low_f32(_v127), _absmax0); + float32x2_t _out_descale = vdiv_f32(_absmax0, vget_low_f32(_v127_B_scale)); + + vst1_f32(ps, _scale); + vst1_f32(pods, _out_descale); +#else + float tmp[2]; + vst1_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + + // float32x2_t _recp_absmax = vrecpe_f32(_absmax0); + // _recp_absmax = vmul_f32(vrecps_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmul_f32(vrecps_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmul_f32(vrecps_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x2_t _scale = vmul_f32(vget_low_f32(_v127), _recp_absmax); + // float32x2_t _out_descale = vmul_f32(_absmax0, vget_low_f32(_recp_v127_B_scale)); +#endif + + ps += 2; + pods += 2; + } +#endif // __ARM_NEON + for (; ii < max_ii; ii++) + { + const float* p0 = (const float*)A + (i + ii); + + float absmax = 0.f; + for (int kk = 0; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabs(p0[0])); + p0 += A_hstep; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_A_tile_fp32_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_A_tile_fp32_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + + // NCNN_LOGE("transpose_pack_A_tile_fp32_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 4 + 28); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + _p8 = vmulq_laneq_f32(_p8, _scale0, 0); + _p9 = vmulq_laneq_f32(_p9, _scale0, 1); + _pa = vmulq_laneq_f32(_pa, _scale0, 2); + _pb = vmulq_laneq_f32(_pb, _scale0, 3); + _pc = vmulq_laneq_f32(_pc, _scale1, 0); + _pd = vmulq_laneq_f32(_pd, _scale1, 1); + _pe = vmulq_laneq_f32(_pe, _scale1, 2); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = 
vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + A_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + A_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + A_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + A_hstep * 7 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + _p8 = vmulq_f32(_p8, _scale0); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale0); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale0); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale0); + _pf = vmulq_f32(_pf, _scale1); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = 
vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 
4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 4 + 8); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 4 + 12); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + _p4 = vmulq_laneq_f32(_p4, _scale, 0); + _p5 = vmulq_laneq_f32(_p5, _scale, 1); + _p6 = vmulq_laneq_f32(_p6, _scale, 2); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + A_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 7); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], 
_p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, 
vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii); + + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vzipq_f32(_scale0, _scale1).val[0]; + for (; kk + 7 < max_kk; kk += 8) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + A_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + A_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + A_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + A_hstep * 7); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p46 = vcombine_f32(_p4, _p6); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + float32x4_t _p57 = vcombine_f32(_p5, _p7); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + A_hstep); + float32x2_t _p2 = vld1_f32(p0 + A_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + A_hstep * 3); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = 
float2int8(p0[A_hstep + 0] * scale0); + pp[2] = float2int8(p0[1] * scale1); + pp[3] = float2int8(p0[A_hstep + 1] * scale1); + pp += 4; + p0 += A_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); + pp[1] = float2int8(p0[1] * scale1); + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; + + const float scale = scales[ii]; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 8); + float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + A_hstep * 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + float32x4_t _p2 = float32x4_t(); + float32x4_t _p3 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[A_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[A_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[A_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[A_hstep * 7], _p1, 3); + _p2 = vsetq_lane_f32(p0[A_hstep * 8], _p2, 0); + _p2 = vsetq_lane_f32(p0[A_hstep * 9], _p2, 1); + _p2 = vsetq_lane_f32(p0[A_hstep * 10], _p2, 2); + _p2 = vsetq_lane_f32(p0[A_hstep * 11], _p2, 3); + _p3 = vsetq_lane_f32(p0[A_hstep * 12], _p3, 0); + _p3 = vsetq_lane_f32(p0[A_hstep * 13], _p3, 1); + _p3 = vsetq_lane_f32(p0[A_hstep * 14], _p3, 2); + _p3 = vsetq_lane_f32(p0[A_hstep * 15], _p3, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[A_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[A_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[A_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[A_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[A_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[A_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[A_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, 
_r01); + + pp += 8; + p0 += A_hstep * 8; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(p0[0] * scale); + // pp[1] = float2int8(p0[A_hstep] * scale); + // pp += 2; + // p0 += A_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_B_tile_fp32_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_B_tile_fp32_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("pack_B_tile_fp32_to_int8 %d %d %d", max_jj, max_kk, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + float32x4x4_t _r = vld4q_f32(p0 + B_hstep * 4); + float32x4x4_t _s = vld4q_f32(p0 + B_hstep * 4 + 16); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], _scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + float32x4_t _p4 = vmulq_f32(_q.val[0], _scale); + float32x4_t _p5 = vmulq_f32(_q.val[1], _scale); + float32x4_t _p6 = vmulq_f32(_q.val[2], _scale); + float32x4_t _p7 = vmulq_f32(_q.val[3], _scale); + float32x4_t _p8 = vmulq_f32(_r.val[0], _scale); + float32x4_t _p9 = vmulq_f32(_r.val[1], _scale); + float32x4_t _pa = vmulq_f32(_r.val[2], _scale); + float32x4_t _pb = vmulq_f32(_r.val[3], _scale); + float32x4_t _pc = vmulq_f32(_s.val[0], _scale); + float32x4_t _pd = vmulq_f32(_s.val[1], _scale); + float32x4_t _pe = vmulq_f32(_s.val[2], _scale); + float32x4_t _pf = vmulq_f32(_s.val[3], _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + 
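// non-dotprod path: load the remaining quads of these four columns and the matching quads of the next four at p0 + B_hstep * 4, then quantize and interleave with vst2q_s8 +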
float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _pb = vld1q_f32(p0 + B_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 4 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], _scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + float32x4_t _p4 = vmulq_f32(_q.val[0], _scale); + float32x4_t _p5 = vmulq_f32(_q.val[1], _scale); + float32x4_t _p6 = vmulq_f32(_q.val[2], _scale); + float32x4_t _p7 = vmulq_f32(_q.val[3], _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 4 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, 
_scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep * 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + B_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 7 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = 
vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 7); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + B_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + B_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + B_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + B_hstep * 7); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[B_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[B_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[B_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[B_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0++; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B 
+ (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + float32x4x4_t _q = vld4q_f32(p0 + 16); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], _scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + float32x4_t _p4 = vmulq_f32(_q.val[0], _scale); + float32x4_t _p5 = vmulq_f32(_q.val[1], _scale); + float32x4_t _p6 = vmulq_f32(_q.val[2], _scale); + float32x4_t _p7 = vmulq_f32(_q.val[3], _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + float32x4x4_t _p = vld4q_f32(p0); + + float32x4_t _p0 = vmulq_f32(_p.val[0], _scale); + float32x4_t _p1 = vmulq_f32(_p.val[1], _scale); + float32x4_t _p2 = vmulq_f32(_p.val[2], _scale); + float32x4_t _p3 = vmulq_f32(_p.val[3], _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if 
(elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = 
vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[B_hstep] * scale); + pp[3] = float2int8(p0[B_hstep + 1] * scale); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(p0[0] * scale); + // pp[1] = float2int8(p0[1] * scale); + // pp[2] = float2int8(p0[B_hstep] * scale); + // pp[3] = float2int8(p0[B_hstep + 1] * scale); + // pp += 4; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[B_hstep] * scale); + pp += 2; + p0++; + } + } + } + for (; jj < max_jj; jj += 1) + { + const float* p0 = (const float*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // 
__ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(p0[0] * scale); + // pp[1] = float2int8(p0[1] * scale); + // pp += 2; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_B_tile_fp32_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_B_tile_fp32_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("transpose_pack_B_tile_fp32_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _pb = vld1q_f32(p0 + B_hstep * 4 + 12); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 4 + 16); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 4 + 20); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 4 + 24); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 4 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + 
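// dotprod (no i8mm) layout: each 8-byte group packs the 4-k quads of two adjacent columns, stored in order below +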
vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + 16); + float32x4_t _p5 = vld1q_f32(p0 + 20); + float32x4_t _p6 = vld1q_f32(p0 + 24); + float32x4_t _p7 = vld1q_f32(p0 + 28); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + float32x4_t _p8 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p9 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _pa = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _pb = vld1q_f32(p0 + B_hstep * 5 + 4); + float32x4_t _pc = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _pd = vld1q_f32(p0 + B_hstep * 6 + 4); + float32x4_t _pe = vld1q_f32(p0 + B_hstep * 7); + float32x4_t _pf = vld1q_f32(p0 + B_hstep * 7 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = 
vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 2 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 3 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = vld1q_f32(p0); + 
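// together _p0/_p1 hold the eight j values of this k row; quantize them and advance to the next row +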
float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += B_hstep; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 4 + 4); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 4 + 8); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 4 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + 8); + float32x4_t _p3 = vld1q_f32(p0 + 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + float32x4_t _p4 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p5 = vld1q_f32(p0 + B_hstep * 5); + float32x4_t _p6 = vld1q_f32(p0 + B_hstep * 6); + float32x4_t _p7 = vld1q_f32(p0 + B_hstep * 7); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = 
vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 2); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 4 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t 
_t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + float32x2_t _p4 = vld1_f32(p0 + B_hstep * 4); + float32x2_t _p5 = vld1_f32(p0 + B_hstep * 5); + float32x2_t _p6 = vld1_f32(p0 + B_hstep * 6); + float32x2_t _p7 = vld1_f32(p0 + B_hstep * 7); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + float32x4_t _p45 = vcombine_f32(_p4, _p5); + float32x4_t _p67 = vcombine_f32(_p6, _p7); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p46 = vcombine_f32(_p4, _p6); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + float32x4_t _p57 = vcombine_f32(_p5, _p7); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x2_t _p0 = vld1_f32(p0); + float32x2_t _p1 = vld1_f32(p0 + B_hstep); + float32x2_t _p2 = vld1_f32(p0 + B_hstep * 2); + float32x2_t _p3 = vld1_f32(p0 + B_hstep * 3); + +#if __ARM_FEATURE_DOTPROD + float32x4_t _p01 = vcombine_f32(_p0, _p1); + float32x4_t _p23 = vcombine_f32(_p2, _p3); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _p02 = vcombine_f32(_p0, _p2); + float32x4_t _p13 = vcombine_f32(_p1, _p3); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 
+= B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[B_hstep + 0] * scale); + pp[2] = float2int8(p0[1] * scale); + pp[3] = float2int8(p0[B_hstep + 1] * scale); + pp += 4; + p0 += B_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { + const float* p0 = (const float*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep * 4); + float32x4_t _p2 = vld1q_f32(p0 + B_hstep * 8); + float32x4_t _p3 = vld1q_f32(p0 + B_hstep * 12); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = vld1q_f32(p0); + float32x4_t _p1 = vld1q_f32(p0 + B_hstep * 4); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(p0[0] * scale); + pp[1] = float2int8(p0[1] * scale); + pp[2] = float2int8(p0[2] * scale); + pp[3] = float2int8(p0[3] * scale); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + float32x4_t _p2 = float32x4_t(); + float32x4_t _p3 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[B_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[B_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[B_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[B_hstep * 7], _p1, 3); + _p2 = vsetq_lane_f32(p0[B_hstep * 8], _p2, 0); + _p2 = vsetq_lane_f32(p0[B_hstep * 9], _p2, 1); + _p2 = vsetq_lane_f32(p0[B_hstep * 10], _p2, 2); + _p2 = vsetq_lane_f32(p0[B_hstep * 11], _p2, 3); + _p3 = vsetq_lane_f32(p0[B_hstep * 12], _p3, 0); + _p3 = vsetq_lane_f32(p0[B_hstep * 13], _p3, 1); + _p3 = vsetq_lane_f32(p0[B_hstep * 14], _p3, 2); + _p3 = vsetq_lane_f32(p0[B_hstep * 15], _p3, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = float32x4_t(); + float32x4_t _p1 = float32x4_t(); + _p0 = vsetq_lane_f32(p0[0], _p0, 0); + _p0 = vsetq_lane_f32(p0[B_hstep], _p0, 1); + _p0 = vsetq_lane_f32(p0[B_hstep * 2], _p0, 2); + _p0 = vsetq_lane_f32(p0[B_hstep * 3], _p0, 3); + _p1 = vsetq_lane_f32(p0[B_hstep * 4], _p1, 0); + _p1 = vsetq_lane_f32(p0[B_hstep * 5], _p1, 1); + _p1 = vsetq_lane_f32(p0[B_hstep * 6], _p1, 2); + _p1 = vsetq_lane_f32(p0[B_hstep * 7], _p1, 3); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, 
_p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(p0[0] * scale); + // pp[1] = float2int8(p0[B_hstep] * scale); + // pp += 2; + // p0 += B_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale); + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + unpack_output_tile_int32_to_fp32_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; + const int c_elempack = C.elempack; + const float* pC = C; + + // NCNN_LOGE("unpack_output_tile_int32_to_fp32 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0]); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); 
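+ // dequantize: the int32 accumulators are scaled by the per-row descale factors;
+ // _sum8.._sumb hold columns 4..7 of rows a..d, so they feed _f4.._f7 here,
+ // while _f8.._ff (rows e..h) are scaled by _descale1 below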
+ float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = 
vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4_t _ca = vld1q_f32(pC + c_hstep * 5); + float32x4_t _cb = vld1q_f32(pC + c_hstep * 5 + 4); + float32x4_t _cc = vld1q_f32(pC + c_hstep * 6); + float32x4_t _cd = vld1q_f32(pC + c_hstep * 6 + 4); + float32x4_t _ce = vld1q_f32(pC + c_hstep * 7); + float32x4_t _cf = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + transpose8x4_ps(_c8, _c9, _ca, _cb, _cc, _cd, _ce, _cf); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = 
vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 4 * 2); + float32x4_t _c3 = vld1q_f32(pC + 4 * 3); + float32x4_t _c4 = vld1q_f32(pC + 4 * 4); + float32x4_t _c5 = vld1q_f32(pC + 4 * 5); + float32x4_t _c6 = vld1q_f32(pC + 4 * 6); + float32x4_t _c7 = vld1q_f32(pC + 4 * 7); + float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4_t _ca = vld1q_f32(pC + c_hstep * 4 + 4 * 2); + float32x4_t _cb = vld1q_f32(pC + c_hstep * 4 + 4 * 3); + float32x4_t _cc = vld1q_f32(pC + c_hstep * 4 + 4 * 4); + float32x4_t _cd = vld1q_f32(pC + c_hstep * 4 + 4 * 5); + float32x4_t _ce = vld1q_f32(pC + c_hstep * 4 + 4 * 6); + float32x4_t _cf = vld1q_f32(pC + c_hstep * 4 + 4 * 7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + float32x4_t _c4 = vdupq_n_f32(pC[4]); + float32x4_t _c5 = vdupq_n_f32(pC[5]); + float32x4_t _c6 = vdupq_n_f32(pC[6]); + float32x4_t _c7 = vdupq_n_f32(pC[7]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + 16, _f4); + vst1q_f32(p0 + 20, _f5); + vst1q_f32(p0 + 24, _f6); + vst1q_f32(p0 + 28, _f7); + vst1q_f32(p0 + out_hstep * 4, _f8); + vst1q_f32(p0 + out_hstep * 4 + 4, _f9); + vst1q_f32(p0 + out_hstep * 4 + 8, _fa); + vst1q_f32(p0 + out_hstep * 4 + 12, _fb); + vst1q_f32(p0 + out_hstep * 4 + 16, _fc); + vst1q_f32(p0 + out_hstep * 4 + 20, _fd); + vst1q_f32(p0 + out_hstep * 4 + 24, _fe); + vst1q_f32(p0 + out_hstep * 4 + 28, _ff); + + pp += 64; + p0 += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 
c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + transpose4x4_ps(_c4, _c5, _c6, _c7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 4; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 4 + 8); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 4 + 12); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); 
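+ // _c2/_c3 complete the first pack-4 row group of C; _c4.._c7 (loaded from
+ // pC + c_hstep * 4) cover rows 4..7 and are accumulated below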
+ _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 4 + 4, _f5); + vst1q_f32(p0 + out_hstep * 4 + 8, _f6); + vst1q_f32(p0 + out_hstep * 4 + 12, _f7); + + pp += 32; + p0 += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4_t _c45 = vcombine_f32(_cc4, _cc5); + float32x4_t _c67 = vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc0 = vuzpq_f32(_c01, _c23); + float32x4x2_t _ccc1 = vuzpq_f32(_c45, _c67); + _f0 = vaddq_f32(_f0, _ccc0.val[0]); + _f1 = vaddq_f32(_f1, _ccc0.val[1]); + _f2 = vaddq_f32(_f2, _ccc1.val[0]); + _f3 = vaddq_f32(_f3, _ccc1.val[1]); + pC += 2; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); + 
float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _c1 = vdupq_n_f32(pC[1]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep * 4, _f2); + vst1q_f32(p0 + out_hstep * 4 + 4, _f3); + + pp += 16; + p0 += 8; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + + pp += 8; + p0 += 4; + } + } + if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); + int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); + 
int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); + _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), 
vget_low_s32(_t4.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); + float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); + float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); + float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); + float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); + float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); + float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); + float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); + float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); + float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); + float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); + float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); + float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + _f8 = vaddq_f32(_f8, _cc4); + _f9 = vaddq_f32(_f9, _cc4); + _fa = vaddq_f32(_fa, _cc5); + _fb = vaddq_f32(_fb, _cc5); + _fc = vaddq_f32(_fc, _cc6); + _fd = 
vaddq_f32(_fd, _cc6); + _fe = vaddq_f32(_fe, _cc7); + _ff = vaddq_f32(_ff, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4_t _ca = vld1q_f32(pC + c_hstep * 5); + float32x4_t _cb = vld1q_f32(pC + c_hstep * 5 + 4); + float32x4_t _cc = vld1q_f32(pC + c_hstep * 6); + float32x4_t _cd = vld1q_f32(pC + c_hstep * 6 + 4); + float32x4_t _ce = vld1q_f32(pC + c_hstep * 7); + float32x4_t _cf = vld1q_f32(pC + c_hstep * 7 + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + float32x4x4_t _cc0 = vld4q_f32(pC); + float32x4x4_t _cc1 = vld4q_f32(pC + 16); + float32x4x4_t _cc2 = vld4q_f32(pC + c_hstep * 4); + float32x4x4_t _cc3 = vld4q_f32(pC + c_hstep * 4 + 16); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc1.val[0]); + _f2 = vaddq_f32(_f2, _cc0.val[1]); + _f3 = vaddq_f32(_f3, _cc1.val[1]); + _f4 = vaddq_f32(_f4, _cc0.val[2]); + _f5 = vaddq_f32(_f5, _cc1.val[2]); + _f6 = vaddq_f32(_f6, _cc0.val[3]); + _f7 = vaddq_f32(_f7, _cc1.val[3]); + _f8 = vaddq_f32(_f8, _cc2.val[0]); + _f9 = vaddq_f32(_f9, _cc3.val[0]); + _fa = vaddq_f32(_fa, _cc2.val[1]); + _fb = vaddq_f32(_fb, _cc3.val[1]); + _fc = vaddq_f32(_fc, _cc2.val[2]); + _fd = vaddq_f32(_fd, _cc3.val[2]); + _fe = vaddq_f32(_fe, _cc2.val[3]); + _ff = vaddq_f32(_ff, _cc3.val[3]); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep, _f2); + vst1q_f32(p0 + out_hstep + 4, _f3); + vst1q_f32(p0 + out_hstep * 2, _f4); + vst1q_f32(p0 + out_hstep * 2 + 4, _f5); + vst1q_f32(p0 + out_hstep * 3, _f6); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + vst1q_f32(p0 + out_hstep * 4, _f8); + vst1q_f32(p0 + out_hstep * 4 + 4, _f9); + vst1q_f32(p0 + out_hstep * 5, _fa); + vst1q_f32(p0 + out_hstep * 5 + 4, _fb); + vst1q_f32(p0 + out_hstep * 6, _fc); + vst1q_f32(p0 + out_hstep * 6 + 4, _fd); + vst1q_f32(p0 + out_hstep * 7, _fe); + vst1q_f32(p0 + out_hstep * 7 + 4, _ff); + + pp += 64; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t 
_sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); + 
float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); + float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); + float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); + float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); + float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); + float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); + float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); + float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); + float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + _f4 = vaddq_f32(_f4, _cc4); + _f5 = vaddq_f32(_f5, _cc5); + _f6 = vaddq_f32(_f6, _cc6); + _f7 = vaddq_f32(_f7, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 4; + } + if (c_elempack == 4) + { + float32x4x4_t _cc0 = vld4q_f32(pC); + float32x4x4_t _cc1 = vld4q_f32(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc0.val[1]); + _f2 = vaddq_f32(_f2, _cc0.val[2]); + _f3 = vaddq_f32(_f3, _cc0.val[3]); + _f4 = vaddq_f32(_f4, _cc1.val[0]); + _f5 = vaddq_f32(_f5, _cc1.val[1]); + _f6 = vaddq_f32(_f6, _cc1.val[2]); + _f7 = vaddq_f32(_f7, _cc1.val[3]); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 4, 
_f4); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 7, _f7); + + pp += 32; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // e0 e1 f0 f1 + // g0 g1 h0 h1 + { + int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _sum13 = vzipq_s32(_sum2, _sum3); + _sum0 = _sum02.val[0]; + _sum1 = _sum02.val[1]; + _sum2 = _sum13.val[0]; + _sum3 = _sum13.val[1]; + } +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // e0 e1 f0 f1 + // g0 g1 h0 h1 + { + int32x4x2_t _t0 = vuzpq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vuzpq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_t0.val[0], _t1.val[0]); + int32x4x2_t _t3 = vzipq_s32(_t1.val[1], _t0.val[1]); + _sum0 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4x2_t _descale01 = vzipq_f32(_descale0, _descale0); + float32x4x2_t _descale23 = vzipq_f32(_descale1, _descale1); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale23.val[0]); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale23.val[1]); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); + float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc0.val[1]); + _f2 = vaddq_f32(_f2, _cc1.val[0]); + _f3 = vaddq_f32(_f3, _cc1.val[1]); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + _f0 = vaddq_f32(_f0, vcombine_f32(_cc0, _cc1)); + _f1 = vaddq_f32(_f1, vcombine_f32(_cc2, _cc3)); + _f2 = vaddq_f32(_f2, vcombine_f32(_cc4, _cc5)); + _f3 = vaddq_f32(_f3, vcombine_f32(_cc6, _cc7)); + pC += 2; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4x2_t _c01 = vzipq_f32(_c0, _c1); + float32x4x2_t _c23 = vzipq_f32(_c2, _c3); + _f0 = vaddq_f32(_f0, _c01.val[0]); + _f1 = vaddq_f32(_f1, _c01.val[1]); + _f2 = vaddq_f32(_f2, _c23.val[0]); + _f3 = vaddq_f32(_f3, _c23.val[1]); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + float32x2_t _cc0 = vld1_f32(pC); + _c0 = vcombine_f32(_cc0, _cc0); + _f0 = vaddq_f32(_f0, _c0); + _f1 = 
vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + pC += 2; + } + } + + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f2)); + vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f2)); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f3)); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f3)); + + pp += 16; + p0 += 2; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + + pp += 8; + p0++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0]); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 
c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + 16); + float32x4_t _c5 = vld1q_f32(pC + 20); + float32x4_t _c6 = vld1q_f32(pC + 24); + float32x4_t _c7 = vld1q_f32(pC + 28); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = 
vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + float32x4_t _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + float32x4_t _c4 = vdupq_n_f32(pC[4]); + float32x4_t _c5 = vdupq_n_f32(pC[5]); + float32x4_t _c6 = vdupq_n_f32(pC[6]); + float32x4_t _c7 = vdupq_n_f32(pC[7]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + 16, _f4); + vst1q_f32(p0 + 20, _f5); + vst1q_f32(p0 + 24, _f6); + vst1q_f32(p0 + 28, _f7); + + pp += 32; + p0 += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + c_hstep * 1); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + float32x4_t _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = 
vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + + pp += 16; + p0 += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_c01, _c23); + _f0 = vaddq_f32(_f0, _cc.val[0]); + _f1 = vaddq_f32(_f1, _cc.val[1]); + pC += 2; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + float32x4_t _c1 = vdupq_n_f32(pC[1]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + + pp += 8; + p0 += 8; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + vst1q_f32(p0, _f0); + + pp += 4; + p0 += 4; + } + } + if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + 
// a7 b7 c7 d7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); +#else + float32x4_t _cc0 = 
vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + float32x4x4_t _cc0 = vld4q_f32(pC); + float32x4x4_t _cc1 = vld4q_f32(pC + 16); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc1.val[0]); + _f2 = vaddq_f32(_f2, _cc0.val[1]); + _f3 = vaddq_f32(_f3, _cc1.val[1]); + _f4 = vaddq_f32(_f4, _cc0.val[2]); + _f5 = vaddq_f32(_f5, _cc1.val[2]); + _f6 = vaddq_f32(_f6, _cc0.val[3]); + _f7 = vaddq_f32(_f7, _cc1.val[3]); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep, _f2); + vst1q_f32(p0 + out_hstep + 4, _f3); + vst1q_f32(p0 + out_hstep * 2, _f4); + vst1q_f32(p0 + out_hstep * 2 + 4, _f5); + vst1q_f32(p0 + out_hstep * 3, _f6); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + + pp += 32; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + _sum1 = vextq_s32(_sum1, _sum1, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = 
vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + c_hstep * 1); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + if (c_elempack == 4) + { + float32x4x4_t _c = vld4q_f32(pC); + _f0 = vaddq_f32(_f0, _c.val[0]); + _f1 = vaddq_f32(_f1, _c.val[1]); + _f2 = vaddq_f32(_f2, _c.val[2]); + _f3 = vaddq_f32(_f3, _c.val[3]); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + + pp += 16; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + { + int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + _sum0 = _sum01.val[0]; + _sum1 = _sum01.val[1]; + } +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + { + int32x4_t _t0 = vuzpq_s32(_sum0, _sum1).val[0]; + int32x4_t _t1 = vuzpq_s32(_sum1, _sum0).val[1]; + int32x4x2_t _t3 = vuzpq_s32(_t0, _t1); + _sum0 = _t3.val[0]; + _sum1 = _t3.val[1]; + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4x2_t _descale01 = vzipq_f32(_descale, _descale); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); + + if (pC) + { + if (broadcast_type_C == 
0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc0.val[1]); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + _f0 = vaddq_f32(_f0, vcombine_f32(_cc0, _cc1)); + _f1 = vaddq_f32(_f1, vcombine_f32(_cc2, _cc3)); + pC += 2; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4x2_t _c01 = vzipq_f32(_c0, _c1); + _f0 = vaddq_f32(_f0, _c01.val[0]); + _f1 = vaddq_f32(_f1, _c01.val[1]); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + float32x2_t _cc0 = vld1_f32(pC); + _c0 = vcombine_f32(_cc0, _cc0); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 2; + } + } + + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); + + pp += 8; + p0 += 2; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0]; + c1 = pC[1]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + + // if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = 
vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep, _f2); + vst1q_f32(p0 + out_hstep + 4, _f3); + + pp += 16; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // TODO neon optimize + float f00 = pp[0] * descale0; + float f01 = pp[1] * descale0; + float f10 = pp[2] * descale1; + float f11 = pp[3] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f00 += c0; + f01 += c0; + f10 += c0; + f11 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f00 += c0; + f01 += c0; + f10 += c1; + f11 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f00 += pC[0]; + f01 += pC[1]; + f10 += pC[c_hstep]; + f11 += pC[c_hstep + 1]; + pC += 2; + } + if (broadcast_type_C == 4) + { + f00 += pC[0]; + f01 += pC[1]; + f10 += pC[0]; + f11 += pC[1]; + pC += 2; + } + } + + p0[0] = f00; + p0[1] = f01; + p0[out_hstep] = f10; + p0[out_hstep + 1] = f11; + + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0]; + f1 += pC[c_hstep]; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += pC[0]; + f1 += pC[0]; + pC += 1; + } + } + + p0[0] = f0; + p0[out_hstep] = f1; + + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + // out_elempack == 1 + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; + + const float descale = 
descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + + // if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + + pp += 16; + p0 += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _f0 = vadd_f32(_f0, vld1_f32(pC)); + pC += 2; + } + } + + vst1_f32(p0, _f0); + + pp += 2; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || 
broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + f0 += pC[0]; + pC += 1; + } + } + + p0[0] = f0; + + pp += 1; + p0++; + } + } + } +} + +static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_unpack_output_tile_int32_to_fp32_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? (int)C.cstep : C.w; + const int c_elempack = C.elempack; + const float* pC = C; + + // NCNN_LOGE("transpose_unpack_output_tile_int32_to_fp32 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0]); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); + int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); + int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); + 
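+                // Editorial note (illustrative only, not part of the original kernel):
+                // the vzipq_s32 calls above together with the vcombine_s32 calls below
+                // form the usual NEON 4x4 int32 transpose. Assuming standard NEON
+                // semantics, with
+                //   _sum0 = {a0 b0 c0 d0}, _sum1 = {a1 b1 c1 d1},
+                //   _sum2 = {a2 b2 c2 d2}, _sum3 = {a3 b3 c3 d3}
+                // vzipq_s32(_sum0, _sum1).val[0] = {a0 a1 b0 b1} and
+                // vzipq_s32(_sum2, _sum3).val[0] = {a2 a3 b2 b3}, so combining the two
+                // low halves yields the transposed row {a0 a1 a2 a3}. The 8x8 tile here
+                // is simply handled as four such 4x4 quadrants.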
_sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); + _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); + } +#else // __ARM_FEATURE_DOTPROD + + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), 
vget_low_s32(_t6.val[1])); + _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); + float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); + float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); + float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); + float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); + float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); + float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); + float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); + float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + _f8 = vaddq_f32(_f8, _cc4); + _f9 = vaddq_f32(_f9, _cc4); + _fa = vaddq_f32(_fa, _cc5); + _fb = vaddq_f32(_fb, _cc5); + _fc = vaddq_f32(_fc, _cc6); + _fd = vaddq_f32(_fd, _cc6); + _fe = vaddq_f32(_fe, _cc7); + _ff = vaddq_f32(_ff, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c9 = vld1q_f32(pC + 
c_hstep * 4 + 4); + float32x4_t _ca = vld1q_f32(pC + c_hstep * 5); + float32x4_t _cb = vld1q_f32(pC + c_hstep * 5 + 4); + float32x4_t _cc = vld1q_f32(pC + c_hstep * 6); + float32x4_t _cd = vld1q_f32(pC + c_hstep * 6 + 4); + float32x4_t _ce = vld1q_f32(pC + c_hstep * 7); + float32x4_t _cf = vld1q_f32(pC + c_hstep * 7 + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + float32x4x4_t _cc0 = vld4q_f32(pC); + float32x4x4_t _cc1 = vld4q_f32(pC + 16); + float32x4x4_t _cc2 = vld4q_f32(pC + c_hstep * 4); + float32x4x4_t _cc3 = vld4q_f32(pC + c_hstep * 4 + 16); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc1.val[0]); + _f2 = vaddq_f32(_f2, _cc0.val[1]); + _f3 = vaddq_f32(_f3, _cc1.val[1]); + _f4 = vaddq_f32(_f4, _cc0.val[2]); + _f5 = vaddq_f32(_f5, _cc1.val[2]); + _f6 = vaddq_f32(_f6, _cc0.val[3]); + _f7 = vaddq_f32(_f7, _cc1.val[3]); + _f8 = vaddq_f32(_f8, _cc2.val[0]); + _f9 = vaddq_f32(_f9, _cc3.val[0]); + _fa = vaddq_f32(_fa, _cc2.val[1]); + _fb = vaddq_f32(_fb, _cc3.val[1]); + _fc = vaddq_f32(_fc, _cc2.val[2]); + _fd = vaddq_f32(_fd, _cc3.val[2]); + _fe = vaddq_f32(_fe, _cc2.val[3]); + _ff = vaddq_f32(_ff, _cc3.val[3]); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + 8, _f4); + vst1q_f32(p0 + 12, _f6); + vst1q_f32(p0 + 16, _f8); + vst1q_f32(p0 + 20, _fa); + vst1q_f32(p0 + 24, _fc); + vst1q_f32(p0 + 28, _fe); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 4 + 4, _f3); + vst1q_f32(p0 + out_hstep * 4 + 8, _f5); + vst1q_f32(p0 + out_hstep * 4 + 12, _f7); + vst1q_f32(p0 + out_hstep * 4 + 16, _f9); + vst1q_f32(p0 + out_hstep * 4 + 20, _fb); + vst1q_f32(p0 + out_hstep * 4 + 24, _fd); + vst1q_f32(p0 + out_hstep * 4 + 28, _ff); + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = 
vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else // __ARM_FEATURE_DOTPROD + + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); + float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); + float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); + float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); + float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, 
_c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + + + + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + _f4 = vaddq_f32(_f4, _cc4); + _f5 = vaddq_f32(_f5, _cc5); + _f6 = vaddq_f32(_f6, _cc6); + _f7 = vaddq_f32(_f7, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 4; + } + if (c_elempack == 4) + { + float32x4x4_t _cc0 = vld4q_f32(pC); + float32x4x4_t _cc1 = vld4q_f32(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc0.val[1]); + _f2 = vaddq_f32(_f2, _cc0.val[2]); + _f3 = vaddq_f32(_f3, _cc0.val[3]); + _f4 = vaddq_f32(_f4, _cc1.val[0]); + _f5 = vaddq_f32(_f5, _cc1.val[1]); + _f6 = vaddq_f32(_f6, _cc1.val[2]); + _f7 = vaddq_f32(_f7, _cc1.val[3]); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + 16, _f4); + vst1q_f32(p0 + 20, _f5); + vst1q_f32(p0 + 24, _f6); + vst1q_f32(p0 + 28, _f7); + pp += 32; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // 
e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum9 = 
vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#endif // __ARM_FEATURE_DOTPROD + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4_t _ca = vld1q_f32(pC + c_hstep * 5); + float32x4_t _cb = vld1q_f32(pC + c_hstep * 5 + 4); + float32x4_t _cc = vld1q_f32(pC + c_hstep * 6); + float32x4_t _cd 
= vld1q_f32(pC + c_hstep * 6 + 4); + float32x4_t _ce = vld1q_f32(pC + c_hstep * 7); + float32x4_t _cf = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _ca, _cb, _cc, _cd, _ce, _cf); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c2); + _f2 = vaddq_f32(_f2, _c4); + _f3 = vaddq_f32(_f3, _c6); + _f4 = vaddq_f32(_f4, _c8); + _f5 = vaddq_f32(_f5, _ca); + _f6 = vaddq_f32(_f6, _cc); + _f7 = vaddq_f32(_f7, _ce); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c3); + _fa = vaddq_f32(_fa, _c5); + _fb = vaddq_f32(_fb, _c7); + _fc = vaddq_f32(_fc, _c9); + _fd = vaddq_f32(_fd, _cb); + _fe = vaddq_f32(_fe, _cd); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + 16); + float32x4_t _c5 = vld1q_f32(pC + 20); + float32x4_t _c6 = vld1q_f32(pC + 24); + float32x4_t _c7 = vld1q_f32(pC + 28); + float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4_t _ca = vld1q_f32(pC + c_hstep * 4 + 8); + float32x4_t _cb = vld1q_f32(pC + c_hstep * 4 + 12); + float32x4_t _cc = vld1q_f32(pC + c_hstep * 4 + 16); + float32x4_t _cd = vld1q_f32(pC + c_hstep * 4 + 20); + float32x4_t _ce = vld1q_f32(pC + c_hstep * 4 + 24); + float32x4_t _cf = vld1q_f32(pC + c_hstep * 4 + 28); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + float32x4_t _c4 = vdupq_n_f32(pC[4]); + float32x4_t _c5 = vdupq_n_f32(pC[5]); + float32x4_t _c6 = vdupq_n_f32(pC[6]); + float32x4_t _c7 = vdupq_n_f32(pC[7]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f8); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f9); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _fa); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _fb); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 4 + 4, _fc); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 5 + 4, _fd); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 6 + 4, _fe); + vst1q_f32(p0 + out_hstep * 7, _f7); + vst1q_f32(p0 + out_hstep * 7 + 4, _ff); + + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = 
vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); + transpose4x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c2); + _f2 
= vaddq_f32(_f2, _c4); + _f3 = vaddq_f32(_f3, _c6); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c3); + _f6 = vaddq_f32(_f6, _c5); + _f7 = vaddq_f32(_f7, _c7); + pC += 4; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 4 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 4 + 8); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 4 + 12); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + + pp += 32; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep * 1); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep 
* 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4_t _cc45 = vcombine_f32(_cc4, _cc5); + float32x4_t _cc67 = vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc0 = vuzpq_f32(_cc01, _cc23); + float32x4x2_t _ccc1 = vuzpq_f32(_cc45, _cc67); + _f0 = vaddq_f32(_f0, _ccc0.val[0]); + _f1 = vaddq_f32(_f1, _ccc0.val[1]); + _f2 = vaddq_f32(_f2, _ccc1.val[0]); + _f3 = vaddq_f32(_f3, _ccc1.val[1]); + pC += 2; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _c1 = vdupq_n_f32(pC[1]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f3); + + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + pp += 8; + p0 += out_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0]); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + } + if (broadcast_type_C == 3) + { + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + 
// a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else // __ARM_FEATURE_DOTPROD + + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = 
vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + float32x4x4_t _cc0 = vld4q_f32(pC); + float32x4x4_t _cc1 = vld4q_f32(pC + 16); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc1.val[0]); + _f2 = vaddq_f32(_f2, _cc0.val[1]); + _f3 = vaddq_f32(_f3, _cc1.val[1]); + _f4 = vaddq_f32(_f4, _cc0.val[2]); + _f5 = vaddq_f32(_f5, _cc1.val[2]); + _f6 = vaddq_f32(_f6, _cc0.val[3]); + _f7 = vaddq_f32(_f7, _cc1.val[3]); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + 8, _f4); + vst1q_f32(p0 + 12, _f6); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 4 + 4, _f3); + vst1q_f32(p0 + out_hstep * 4 + 8, _f5); + vst1q_f32(p0 + out_hstep * 4 + 12, _f7); + + pp += 32; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + _sum1 = vextq_s32(_sum1, _sum1, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + 
_sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + if (c_elempack == 4) + { + float32x4x4_t _c = vld4q_f32(pC); + _f0 = vaddq_f32(_f0, _c.val[0]); + _f1 = vaddq_f32(_f1, _c.val[1]); + _f2 = vaddq_f32(_f2, _c.val[2]); + _f3 = vaddq_f32(_f3, _c.val[3]); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + + pp += 16; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = 
vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + 16); + float32x4_t _c5 = vld1q_f32(pC + 20); + float32x4_t _c6 = vld1q_f32(pC + 24); + float32x4_t _c7 = vld1q_f32(pC + 28); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + float32x4_t _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + float32x4_t _c4 = vdupq_n_f32(pC[4]); + float32x4_t _c5 = vdupq_n_f32(pC[5]); + float32x4_t _c6 = vdupq_n_f32(pC[6]); + float32x4_t _c7 = vdupq_n_f32(pC[7]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + 
_f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 7, _f7); + + pp += 32; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + float32x4_t _c1 = vdupq_n_f32(pC[1]); + float32x4_t _c2 = vdupq_n_f32(pC[2]); + float32x4_t _c3 = vdupq_n_f32(pC[3]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if 
__ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_cc01, _cc23); + _f0 = vaddq_f32(_f0, _cc.val[0]); + _f1 = vaddq_f32(_f1, _cc.val[1]); + pC += 2; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + float32x4_t _c1 = vdupq_n_f32(pC[1]); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0]); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + vst1q_f32(p0, _f0); + pp += 4; + p0 += out_hstep; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale01 = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0]; + c1 = pC[1]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + +#if __ARM_NEON + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < 
max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 4 + 4, _f3); + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + + pp += 8; + p0 += out_hstep * 4; + } + } +#endif // __ARM_NEON + if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _sum13 = vzipq_s32(_sum1, _sum3); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum02.val[0]), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum02.val[1]), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum13.val[0]), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum13.val[1]), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = 
vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + _f1 = vaddq_f32(_f1, _cc); + _f2 = vaddq_f32(_f2, _cc); + _f3 = vaddq_f32(_f3, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4x2_t _c02 = vzipq_f32(_c0, _c2); + float32x4x2_t _c13 = vzipq_f32(_c1, _c3); + _f0 = vaddq_f32(_f0, _c02.val[0]); + _f1 = vaddq_f32(_f1, _c02.val[1]); + _f2 = vaddq_f32(_f2, _c13.val[0]); + _f3 = vaddq_f32(_f3, _c13.val[1]); + pC += 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); + float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc0.val[1]); + _f2 = vaddq_f32(_f2, _cc1.val[0]); + _f3 = vaddq_f32(_f3, _cc1.val[1]); + pC += 8; + } + } + + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f2)); + vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f2)); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f3)); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f3)); + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum01.val[0]), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum01.val[1]), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + _f1 = vaddq_f32(_f1, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4x2_t _c01 = vzipq_f32(_c0, _c1); + _f0 = vaddq_f32(_f0, _c01.val[0]); + _f1 = vaddq_f32(_f1, _c01.val[1]); + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + float32x4x2_t _cc = vzipq_f32(_c0, _c0); + _f0 = vaddq_f32(_f0, _cc.val[0]); + _f1 = vaddq_f32(_f1, _cc.val[1]); + pC += 4; + } + } + + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); + + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // TODO neon optimize + // a0 a1 b0 b1 + + float f00 = pp[0] * descale0; + float f01 = pp[2] * descale1; + float f10 = pp[1] * descale0; + float f11 = pp[3] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f00 += c0; + f01 += c0; + f10 += c0; + f11 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f00 += c0; + f01 += c1; + f10 += c0; + f11 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f00 += pC[0]; + f01 += pC[c_hstep]; + f10 += pC[1]; + f11 += pC[c_hstep + 1]; + pC += 2; + } + if (broadcast_type_C == 4) + { + f00 += pC[0]; + f01 += pC[0]; + f10 += pC[1]; + f11 += pC[1]; + pC += 2; + } + } + + p0[0] = f00; + p0[1] = f01; + p0[out_hstep] = f10; + p0[out_hstep + 1] = f11; + + pp += 4; + p0 += out_hstep * 2; + 
} +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0]; + f1 += pC[c_hstep]; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += pC[0]; + f1 += pC[0]; + pC += 1; + } + } + + p0[0] = f0; + p0[1] = f1; + + pp += 2; + p0 += out_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + c0 = pC[0]; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + +#if __ARM_NEON + if (out_elempack == 4) + { + int jj = 0; + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 8, _f2); + vst1q_f32(p0 + out_hstep * 12, _f3); + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + 
// c_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + + vst1q_f32(p0, _f0); + pp += 4; + p0 += out_hstep * 4; + } + } +#endif // __ARM_NEON + if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0); + p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1); + p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2); + p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3); + p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0); + p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1); + p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); + p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); + + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = 
vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _f0 = vadd_f32(_f0, vld1_f32(pC)); + pC += 2; + } + } + + p0[0] = vget_lane_f32(_f0, 0); + p0[out_hstep] = vget_lane_f32(_f0, 1); + + pp += 2; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += pC[0]; + pC += 1; + } + } + + p0[0] = f0; + + pp += 1; + p0 += out_hstep; + } + } + } +} + +static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int i, int max_ii, int j, int max_jj, int k, int max_kk) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + gemm_transB_packed_tile_int8_i8mm(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + gemm_transB_packed_tile_int8_asimddp(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk); + return; + } +#endif + + // NCNN_LOGE("gemm_transB_packed_tile_int8 %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + + const signed char* pAT = AT_tile; + const signed char* pBT = BT_tile; + + int* outptr = topT_tile; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const signed char* pB = pBT; + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( +#if !__ARM_FEATURE_MATMUL_INT8 + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + + "1: \n" +#endif // !__ARM_FEATURE_MATMUL_INT8 + +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v0.16b, v0.16b, v0.16b \n" + "eor v1.16b, v1.16b, v1.16b \n" + "eor v2.16b, v2.16b, v2.16b \n" + "eor v3.16b, v3.16b, v3.16b \n" + "eor v4.16b, v4.16b, v4.16b \n" + "eor v5.16b, v5.16b, v5.16b \n" + "eor v6.16b, v6.16b, v6.16b \n" + "eor v7.16b, v7.16b, v7.16b \n" + "eor v8.16b, v8.16b, v8.16b 
\n" + "eor v9.16b, v9.16b, v9.16b \n" + "eor v10.16b, v10.16b, v10.16b \n" + "eor v11.16b, v11.16b, v11.16b \n" + "eor v12.16b, v12.16b, v12.16b \n" + "eor v13.16b, v13.16b, v13.16b \n" + "eor v14.16b, v14.16b, v14.16b \n" + "eor v15.16b, v15.16b, v15.16b \n" + + "2: \n" + "ld1 {v16.16b, v17.16b, v18.16b, v19.16b}, [%1], #64 \n" + "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%2], #64 \n" + "smmla v0.4s, v16.16b, v20.16b \n" + "smmla v1.4s, v17.16b, v20.16b \n" + "smmla v2.4s, v16.16b, v21.16b \n" + "smmla v3.4s, v17.16b, v21.16b \n" + "smmla v4.4s, v18.16b, v20.16b \n" + "smmla v5.4s, v19.16b, v20.16b \n" + "smmla v6.4s, v18.16b, v21.16b \n" + "smmla v7.4s, v19.16b, v21.16b \n" + "subs w4, w4, #1 \n" + "smmla v8.4s, v16.16b, v22.16b \n" + "smmla v9.4s, v17.16b, v22.16b \n" + "smmla v10.4s, v16.16b, v23.16b \n" + "smmla v11.4s, v17.16b, v23.16b \n" + "smmla v12.4s, v18.16b, v22.16b \n" + "smmla v13.4s, v19.16b, v22.16b \n" + "smmla v14.4s, v18.16b, v23.16b \n" + "smmla v15.4s, v19.16b, v23.16b \n" + "bne 2b \n" + + "uzp1 v16.4s, v0.4s, v1.4s \n" + "uzp2 v17.4s, v0.4s, v1.4s \n" + "uzp1 v18.4s, v2.4s, v3.4s \n" + "uzp2 v19.4s, v2.4s, v3.4s \n" + "uzp1 v20.4s, v4.4s, v5.4s \n" + "uzp2 v21.4s, v4.4s, v5.4s \n" + "uzp1 v22.4s, v6.4s, v7.4s \n" + "uzp2 v23.4s, v6.4s, v7.4s \n" + "uzp1 v24.4s, v8.4s, v9.4s \n" + "uzp2 v25.4s, v8.4s, v9.4s \n" + "uzp1 v26.4s, v10.4s, v11.4s \n" + "uzp2 v27.4s, v10.4s, v11.4s \n" + "uzp1 v28.4s, v12.4s, v13.4s \n" + "uzp2 v29.4s, v12.4s, v13.4s \n" + "uzp1 v30.4s, v14.4s, v15.4s \n" + "uzp2 v31.4s, v14.4s, v15.4s \n" + + "cmp %w7, #0 \n" + "beq 1f \n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0], #64 \n" + "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0], #64 \n" + "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" + "add v20.4s, v20.4s, v4.4s \n" + "add v21.4s, v21.4s, v5.4s \n" + "add v22.4s, v22.4s, v6.4s \n" + "add v23.4s, v23.4s, v7.4s \n" + "add v24.4s, v24.4s, v8.4s \n" + "add v25.4s, v25.4s, v9.4s \n" + "add v26.4s, v26.4s, v10.4s \n" + "add v27.4s, v27.4s, v11.4s \n" + "add v28.4s, v28.4s, v12.4s \n" + "add v29.4s, v29.4s, v13.4s \n" + "add v30.4s, v30.4s, v14.4s \n" + "add v31.4s, v31.4s, v15.4s \n" + "b 1f \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "2: \n" + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%1], #64 \n" + "ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [%2], #64 \n" + "sdot v16.4s, v0.16b, v4.4b[0] \n" + "sdot v17.4s, v0.16b, v4.4b[1] \n" + "sdot v18.4s, v0.16b, v4.4b[2] \n" + "sdot v19.4s, v0.16b, v4.4b[3] \n" + "sdot v20.4s, v1.16b, v4.4b[0] \n" + "sdot v21.4s, v1.16b, v4.4b[1] \n" + "sdot v22.4s, v1.16b, v4.4b[2] \n" + "sdot v23.4s, v1.16b, v4.4b[3] \n" + "sdot v24.4s, v0.16b, v5.4b[0] \n" + "sdot v25.4s, v0.16b, v5.4b[1] \n" + "sdot v26.4s, v0.16b, v5.4b[2] \n" + "sdot v27.4s, v0.16b, v5.4b[3] \n" + "sdot v28.4s, v1.16b, v5.4b[0] \n" + "sdot v29.4s, v1.16b, v5.4b[1] \n" + "sdot v30.4s, v1.16b, v5.4b[2] \n" + "sdot v31.4s, v1.16b, v5.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v2.16b, v6.4b[0] \n" + "sdot v17.4s, v2.16b, v6.4b[1] \n" + "sdot v18.4s, v2.16b, v6.4b[2] \n" + "sdot v19.4s, v2.16b, v6.4b[3] \n" + "sdot v20.4s, v3.16b, v6.4b[0] \n" + "sdot v21.4s, v3.16b, v6.4b[1] \n" + "sdot v22.4s, v3.16b, v6.4b[2] \n" + "sdot v23.4s, v3.16b, v6.4b[3] \n" + "sdot v24.4s, v2.16b, v7.4b[0] \n" + "sdot v25.4s, v2.16b, v7.4b[1] \n" + "sdot v26.4s, v2.16b, v7.4b[2] 
\n" + "sdot v27.4s, v2.16b, v7.4b[3] \n" + "sdot v28.4s, v3.16b, v7.4b[0] \n" + "sdot v29.4s, v3.16b, v7.4b[1] \n" + "sdot v30.4s, v3.16b, v7.4b[2] \n" + "sdot v31.4s, v3.16b, v7.4b[3] \n" + "bne 2b \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" +#if __ARM_FEATURE_MATMUL_INT8 + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0] \n" + "sub %0, %0, #192 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" + "1: \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v2.16b, v3.16b}, [%2], #32 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v1.16b, v2.4b[0] \n" + "sdot v21.4s, v1.16b, v2.4b[1] \n" + "sdot v22.4s, v1.16b, v2.4b[2] \n" + "sdot v23.4s, v1.16b, v2.4b[3] \n" + "sdot v24.4s, v0.16b, v3.4b[0] \n" + "sdot v25.4s, v0.16b, v3.4b[1] \n" + "sdot v26.4s, v0.16b, v3.4b[2] \n" + "sdot v27.4s, v0.16b, v3.4b[3] \n" + "sdot v28.4s, v1.16b, v3.4b[0] \n" + "sdot v29.4s, v1.16b, v3.4b[1] \n" + "sdot v30.4s, v1.16b, v3.4b[2] \n" + "sdot v31.4s, v1.16b, v3.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull2 v9.8h, v0.16b, v4.16b \n" + "rev64 v2.4s, v0.4s \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull2 v11.8h, v2.16b, v4.16b \n" + "rev64 v6.8h, v4.8h \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull2 v13.8h, v0.16b, v6.16b \n" + "rev64 v3.4s, v1.4s \n" + "smull v14.8h, v2.8b, v6.8b \n" + "smull2 v15.8h, v2.16b, v6.16b \n" + "rev64 v7.8h, v5.8h \n" + "smlal v8.8h, v1.8b, v5.8b \n" + "smlal2 v9.8h, v1.16b, v5.16b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v11.8h, v3.16b, v5.16b \n" + "smlal v12.8h, v1.8b, v7.8b \n" + "smlal2 v13.8h, v1.16b, v7.16b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v15.8h, v3.16b, v7.16b \n" + "ext v0.16b, v0.16b, v0.16b, #8 \n" + "ext v2.16b, v2.16b, v2.16b, #8 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v20.4s, v10.8h \n" + "sadalp v21.4s, v11.8h \n" + "ext v1.16b, v1.16b, v1.16b, #8 \n" + "ext v3.16b, v3.16b, v3.16b, #8 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull2 v9.8h, v0.16b, v4.16b \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull2 v11.8h, v2.16b, v4.16b \n" + "sadalp v24.4s, v12.8h \n" + "sadalp v25.4s, v13.8h \n" + "sadalp v28.4s, v14.8h \n" + "sadalp v29.4s, v15.8h \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull2 v13.8h, v0.16b, v6.16b \n" + "smull v14.8h, v2.8b, v6.8b \n" + "smull2 v15.8h, v2.16b, v6.16b \n" + "smlal v8.8h, v1.8b, v5.8b 
\n" + "smlal2 v9.8h, v1.16b, v5.16b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v11.8h, v3.16b, v5.16b \n" + "smlal v12.8h, v1.8b, v7.8b \n" + "smlal2 v13.8h, v1.16b, v7.16b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v15.8h, v3.16b, v7.16b \n" + "subs w4, w4, #1 \n" + "sadalp v18.4s, v8.8h \n" + "sadalp v19.4s, v9.8h \n" + "sadalp v22.4s, v10.8h \n" + "sadalp v23.4s, v11.8h \n" + "sadalp v26.4s, v12.8h \n" + "sadalp v27.4s, v13.8h \n" + "sadalp v30.4s, v14.8h \n" + "sadalp v31.4s, v15.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 \n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v1.16b}, [%2], #16 \n" + "dup v4.8h, v1.h[0] \n" + "dup v5.8h, v1.h[1] \n" + "dup v6.8h, v1.h[2] \n" + "dup v7.8h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "smull2 v12.8h, v0.16b, v4.16b \n" + "smull2 v13.8h, v0.16b, v5.16b \n" + "smull2 v14.8h, v0.16b, v6.16b \n" + "smull2 v15.8h, v0.16b, v7.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" + "dup v4.8h, v1.h[4] \n" + "dup v5.8h, v1.h[5] \n" + "dup v6.8h, v1.h[6] \n" + "dup v7.8h, v1.h[7] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "smull2 v12.8h, v0.16b, v4.16b \n" + "smull2 v13.8h, v0.16b, v5.16b \n" + "smull2 v14.8h, v0.16b, v6.16b \n" + "smull2 v15.8h, v0.16b, v7.16b \n" + "sadalp v24.4s, v8.8h \n" + "sadalp v25.4s, v9.8h \n" + "sadalp v26.4s, v10.8h \n" + "sadalp v27.4s, v11.8h \n" + "sadalp v28.4s, v12.8h \n" + "sadalp v29.4s, v13.8h \n" + "sadalp v30.4s, v14.8h \n" + "sadalp v31.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "rev64 v1.4s, v0.4s \n" + "rev64 v3.8h, v2.8h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, v3.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v20.4s, v10.8h \n" + "sadalp v21.4s, v11.8h \n" + "sadalp v24.4s, v12.8h \n" + "sadalp v25.4s, v13.8h \n" + "sadalp v28.4s, v14.8h \n" + "sadalp v29.4s, v15.8h \n" + "ext v0.16b, v0.16b, v0.16b, #8 \n" + "ext v1.16b, v1.16b, v1.16b, #8 \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, v3.16b \n" + "sadalp v18.4s, v8.8h \n" + "sadalp v19.4s, v9.8h \n" + "sadalp v22.4s, v10.8h \n" + "sadalp v23.4s, v11.8h \n" + "sadalp v26.4s, v12.8h \n" + "sadalp v27.4s, v13.8h \n" + "sadalp v30.4s, v14.8h \n" + "sadalp v31.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.8b}, [%2], #8 \n" + "dup v8.8b, v1.b[0] \n" + "dup v9.8b, v1.b[1] \n" + "dup v10.8b, v1.b[2] \n" + "dup 
v11.8b, v1.b[3] \n" + "dup v12.8b, v1.b[4] \n" + "dup v13.8b, v1.b[5] \n" + "dup v14.8b, v1.b[6] \n" + "dup v15.8b, v1.b[7] \n" + "smull v8.8h, v0.8b, v8.8b \n" + "smull v9.8h, v0.8b, v9.8b \n" + "smull v10.8h, v0.8b, v10.8b \n" + "smull v11.8h, v0.8b, v11.8b \n" + "smull v12.8h, v0.8b, v12.8b \n" + "smull v13.8h, v0.8b, v13.8b \n" + "smull v14.8h, v0.8b, v14.8b \n" + "smull v15.8h, v0.8b, v15.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw v17.4s, v17.4s, v9.4h \n" + "saddw v18.4s, v18.4s, v10.4h \n" + "saddw v19.4s, v19.4s, v11.4h \n" + "saddw2 v20.4s, v20.4s, v8.8h \n" + "saddw2 v21.4s, v21.4s, v9.8h \n" + "saddw2 v22.4s, v22.4s, v10.8h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" + "saddw v24.4s, v24.4s, v12.4h \n" + "saddw v25.4s, v25.4s, v13.4h \n" + "saddw v26.4s, v26.4s, v14.4h \n" + "saddw v27.4s, v27.4s, v15.4h \n" + "saddw2 v28.4s, v28.4s, v12.8h \n" + "saddw2 v29.4s, v29.4s, v13.8h \n" + "saddw2 v30.4s, v30.4s, v14.8h \n" + "saddw2 v31.4s, v31.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v4.8b}, [%2], #8 \n" + "ext v1.8b, v0.8b, v0.8b, #4 \n" + "rev32 v2.4h, v0.4h \n" + "rev64 v3.4h, v0.4h \n" + "rev32 v5.8b, v4.8b \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v1.8b, v4.8b \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull v11.8h, v3.8b, v4.8b \n" + "smull v12.8h, v0.8b, v5.8b \n" + "smull v13.8h, v1.8b, v5.8b \n" + "smull v14.8h, v2.8b, v5.8b \n" + "smull v15.8h, v3.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw2 v21.4s, v21.4s, v10.8h \n" + "saddw v22.4s, v22.4s, v11.4h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" + "saddw v24.4s, v24.4s, v12.4h \n" + "saddw2 v25.4s, v25.4s, v12.8h \n" + "saddw v26.4s, v26.4s, v13.4h \n" + "saddw2 v27.4s, v27.4s, v13.8h \n" + "saddw v28.4s, v28.4s, v14.4h \n" + "saddw2 v29.4s, v29.4s, v14.8h \n" + "saddw v30.4s, v30.4s, v15.4h \n" + "saddw2 v31.4s, v31.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + int32x4_t _sum8; + int32x4_t _sum9; + int32x4_t _suma; + int32x4_t _sumb; + int32x4_t _sumc; + int32x4_t _sumd; + int32x4_t _sume; + int32x4_t _sumf; + +#if __ARM_FEATURE_MATMUL_INT8 + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + } +#else // __ARM_FEATURE_MATMUL_INT8 + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + 
_sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + _sum8 = vdupq_n_s32(0); + _sum9 = vdupq_n_s32(0); + _suma = vdupq_n_s32(0); + _sumb = vdupq_n_s32(0); + _sumc = vdupq_n_s32(0); + _sumd = vdupq_n_s32(0); + _sume = vdupq_n_s32(0); + _sumf = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + } +#endif // __ARM_FEATURE_MATMUL_INT8 + + int kk = 0; +#if __ARM_FEATURE_MATMUL_INT8 + { + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + _sum0 = vmmlaq_s32(_sum0, _pA0, _pB0); + _sum1 = vmmlaq_s32(_sum1, _pA1, _pB0); + _sum2 = vmmlaq_s32(_sum2, _pA0, _pB1); + _sum3 = vmmlaq_s32(_sum3, _pA1, _pB1); + _sum4 = vmmlaq_s32(_sum4, _pA2, _pB0); + _sum5 = vmmlaq_s32(_sum5, _pA3, _pB0); + _sum6 = vmmlaq_s32(_sum6, _pA2, _pB1); + _sum7 = vmmlaq_s32(_sum7, _pA3, _pB1); + _sum8 = vmmlaq_s32(_sum8, _pA0, _pB2); + _sum9 = vmmlaq_s32(_sum9, _pA1, _pB2); + _suma = vmmlaq_s32(_suma, _pA0, _pB3); + _sumb = vmmlaq_s32(_sumb, _pA1, _pB3); + _sumc = vmmlaq_s32(_sumc, _pA2, _pB2); + _sumd = vmmlaq_s32(_sumd, _pA3, _pB2); + _sume = vmmlaq_s32(_sume, _pA2, _pB3); + _sumf = vmmlaq_s32(_sumf, _pA3, _pB3); + + pA += 64; + pB += 64; + } + + int32x4x2_t _ss0 = vuzpq_s32(_sum0, _sum1); + int32x4x2_t _ss1 = vuzpq_s32(_sum2, _sum3); + int32x4x2_t _ss2 = vuzpq_s32(_sum4, _sum5); + int32x4x2_t _ss3 = vuzpq_s32(_sum6, _sum7); + int32x4x2_t _ss4 = vuzpq_s32(_sum8, _sum9); + int32x4x2_t _ss5 = vuzpq_s32(_suma, _sumb); + int32x4x2_t _ss6 = vuzpq_s32(_sumc, _sumd); + int32x4x2_t _ss7 = vuzpq_s32(_sume, _sumf); + + if (k == 0) + { + _sum0 = _ss0.val[0]; + _sum1 = _ss0.val[1]; + _sum2 = _ss1.val[0]; + _sum3 = _ss1.val[1]; + _sum4 = _ss2.val[0]; + _sum5 = _ss2.val[1]; + _sum6 = _ss3.val[0]; + _sum7 = _ss3.val[1]; + _sum8 = _ss4.val[0]; + _sum9 = _ss4.val[1]; + _suma = _ss5.val[0]; + _sumb = _ss5.val[1]; + _sumc = _ss6.val[0]; + _sumd = _ss6.val[1]; + _sume = _ss7.val[0]; + _sumf = _ss7.val[1]; + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + _sum8 = vld1q_s32(outptr + 32); + _sum9 = vld1q_s32(outptr + 36); + _suma = vld1q_s32(outptr + 40); + _sumb = vld1q_s32(outptr + 44); + _sumc = vld1q_s32(outptr + 48); + _sumd = vld1q_s32(outptr + 52); + _sume = vld1q_s32(outptr + 56); + _sumf = vld1q_s32(outptr + 60); + + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + _sum4 = 
vaddq_s32(_sum4, _ss2.val[0]); + _sum5 = vaddq_s32(_sum5, _ss2.val[1]); + _sum6 = vaddq_s32(_sum6, _ss3.val[0]); + _sum7 = vaddq_s32(_sum7, _ss3.val[1]); + _sum8 = vaddq_s32(_sum8, _ss4.val[0]); + _sum9 = vaddq_s32(_sum9, _ss4.val[1]); + _suma = vaddq_s32(_suma, _ss5.val[0]); + _sumb = vaddq_s32(_sumb, _ss5.val[1]); + _sumc = vaddq_s32(_sumc, _ss6.val[0]); + _sumd = vaddq_s32(_sumd, _ss6.val[1]); + _sume = vaddq_s32(_sume, _ss7.val[0]); + _sumf = vaddq_s32(_sumf, _ss7.val[1]); + } + } +#elif __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 2222 3333 4444 5555 6666 7777 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB0, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB0, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB0, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB0, 3); + _sum8 = vdotq_laneq_s32(_sum8, _pA0, _pB1, 0); + _sum9 = vdotq_laneq_s32(_sum9, _pA0, _pB1, 1); + _suma = vdotq_laneq_s32(_suma, _pA0, _pB1, 2); + _sumb = vdotq_laneq_s32(_sumb, _pA0, _pB1, 3); + _sumc = vdotq_laneq_s32(_sumc, _pA1, _pB1, 0); + _sumd = vdotq_laneq_s32(_sumd, _pA1, _pB1, 1); + _sume = vdotq_laneq_s32(_sume, _pA1, _pB1, 2); + _sumf = vdotq_laneq_s32(_sumf, _pA1, _pB1, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA2, _pB2, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA2, _pB2, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA2, _pB2, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA2, _pB2, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA3, _pB2, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA3, _pB2, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA3, _pB2, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA3, _pB2, 3); + _sum8 = vdotq_laneq_s32(_sum8, _pA2, _pB3, 0); + _sum9 = vdotq_laneq_s32(_sum9, _pA2, _pB3, 1); + _suma = vdotq_laneq_s32(_suma, _pA2, _pB3, 2); + _sumb = vdotq_laneq_s32(_sumb, _pA2, _pB3, 3); + _sumc = vdotq_laneq_s32(_sumc, _pA3, _pB3, 0); + _sumd = vdotq_laneq_s32(_sumd, _pA3, _pB3, 1); + _sume = vdotq_laneq_s32(_sume, _pA3, _pB3, 2); + _sumf = vdotq_laneq_s32(_sumf, _pA3, _pB3, 3); + + pA += 64; + pB += 64; + } +#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 2222 3333 4444 5555 6666 7777 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB0, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB0, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB0, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB0, 3); + _sum8 = vdotq_laneq_s32(_sum8, _pA0, _pB1, 0); + _sum9 = vdotq_laneq_s32(_sum9, _pA0, _pB1, 1); + _suma = vdotq_laneq_s32(_suma, _pA0, _pB1, 2); + _sumb = vdotq_laneq_s32(_sumb, _pA0, _pB1, 3); + _sumc = vdotq_laneq_s32(_sumc, _pA1, _pB1, 0); + _sumd 
= vdotq_laneq_s32(_sumd, _pA1, _pB1, 1); + _sume = vdotq_laneq_s32(_sume, _pA1, _pB1, 2); + _sumf = vdotq_laneq_s32(_sumf, _pA1, _pB1, 3); + +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB2 = vld1q_s8(pB + 16); + + // aabbccdd eeffgghh + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + + // aabbccdd eeffgghh + // ccddaabb gghheeff + + int8x16_t _pA3 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA2))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB3 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB0)); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB0)); + int16x8_t _s7 = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s8 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB1)); + int16x8_t _s9 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sa = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB1)); + int16x8_t _sb = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sc = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sd = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB1)); + int16x8_t _se = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sf = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB1)); + + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), vget_low_s8(_pB2)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), vget_high_s8(_pB2)); + _s2 = vmlal_s8(_s2, vget_high_s8(_pA2), vget_low_s8(_pB2)); + _s3 = vmlal_s8(_s3, vget_low_s8(_pA2), vget_high_s8(_pB2)); + _s4 = vmlal_s8(_s4, vget_low_s8(_pA3), vget_low_s8(_pB2)); + _s5 = vmlal_s8(_s5, vget_high_s8(_pA3), vget_high_s8(_pB2)); + _s6 = vmlal_s8(_s6, vget_high_s8(_pA3), vget_low_s8(_pB2)); + _s7 = vmlal_s8(_s7, vget_low_s8(_pA3), vget_high_s8(_pB2)); + _s8 = vmlal_s8(_s8, vget_low_s8(_pA2), vget_low_s8(_pB3)); + _s9 = vmlal_s8(_s9, vget_high_s8(_pA2), vget_high_s8(_pB3)); + _sa = vmlal_s8(_sa, vget_high_s8(_pA2), vget_low_s8(_pB3)); + _sb = vmlal_s8(_sb, vget_low_s8(_pA2), vget_high_s8(_pB3)); + _sc = vmlal_s8(_sc, vget_low_s8(_pA3), vget_low_s8(_pB3)); + _sd = vmlal_s8(_sd, vget_high_s8(_pA3), vget_high_s8(_pB3)); + _se = vmlal_s8(_se, vget_high_s8(_pA3), vget_low_s8(_pB3)); + _sf = vmlal_s8(_sf, vget_low_s8(_pA3), vget_high_s8(_pB3)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + _sum8 = vpadalq_s16(_sum8, _s8); + _sum9 = vpadalq_s16(_sum9, _s9); + _suma = vpadalq_s16(_suma, _sa); + _sumb = vpadalq_s16(_sumb, _sb); + _sumc = vpadalq_s16(_sumc, _sc); + _sumd = vpadalq_s16(_sumd, _sd); + _sume = vpadalq_s16(_sume, _se); + _sumf = vpadalq_s16(_sumf, _sf); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 32; + } + for (; kk + 1 < 
max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + // aabbccdd eeffgghh + + // 00112233 44556677 + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 0))); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 1))); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 2))); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 3))); + int16x8_t _s4 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 0))); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 1))); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 2))); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB)), 3))); + int16x8_t _s8 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 0))); + int16x8_t _s9 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 1))); + int16x8_t _sa = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 2))); + int16x8_t _sb = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 3))); + int16x8_t _sc = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 0))); + int16x8_t _sd = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 1))); + int16x8_t _se = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 2))); + int16x8_t _sf = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB)), 3))); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + _sum8 = vpadalq_s16(_sum8, _s8); + _sum9 = vpadalq_s16(_sum9, _s9); + _suma = vpadalq_s16(_suma, _sa); + _sumb = vpadalq_s16(_sumb, _sb); + _sumc = vpadalq_s16(_sumc, _sc); + _sumd = vpadalq_s16(_sumd, _sd); + _sume = vpadalq_s16(_sume, _se); + _sumf = vpadalq_s16(_sumf, _sf); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + + // aabbccdd eeffgghh + + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + + // 00112233 44556677 + + // 33221100 77665544 + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB0)); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA1), 
vget_low_s8(_pB0)); + int16x8_t _s7 = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB0)); + int16x8_t _s8 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB1)); + int16x8_t _s9 = vmull_s8(vget_high_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sa = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB1)); + int16x8_t _sb = vmull_s8(vget_low_s8(_pA0), vget_high_s8(_pB1)); + int16x8_t _sc = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sd = vmull_s8(vget_high_s8(_pA1), vget_high_s8(_pB1)); + int16x8_t _se = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB1)); + int16x8_t _sf = vmull_s8(vget_low_s8(_pA1), vget_high_s8(_pB1)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + _sum8 = vpadalq_s16(_sum8, _s8); + _sum9 = vpadalq_s16(_sum9, _s9); + _suma = vpadalq_s16(_suma, _sa); + _sumb = vpadalq_s16(_sumb, _sb); + _sumc = vpadalq_s16(_sumc, _sc); + _sumd = vpadalq_s16(_sumd, _sd); + _sume = vpadalq_s16(_sume, _se); + _sumf = vpadalq_s16(_sumf, _sf); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + // int8x8_t _pB0 = vld1_s8(pB); + + // abcd efgh + // 0123 4567 + + int16x8_t _s01 = vmull_s8(_pA, vdup_n_s8(pB[0])); + int16x8_t _s23 = vmull_s8(_pA, vdup_n_s8(pB[1])); + int16x8_t _s45 = vmull_s8(_pA, vdup_n_s8(pB[2])); + int16x8_t _s67 = vmull_s8(_pA, vdup_n_s8(pB[3])); + int16x8_t _s89 = vmull_s8(_pA, vdup_n_s8(pB[4])); + int16x8_t _sab = vmull_s8(_pA, vdup_n_s8(pB[5])); + int16x8_t _scd = vmull_s8(_pA, vdup_n_s8(pB[6])); + int16x8_t _sef = vmull_s8(_pA, vdup_n_s8(pB[7])); + + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s23)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s45)); + _sum3 = vaddw_s16(_sum3, vget_low_s16(_s67)); + _sum4 = vaddw_s16(_sum4, vget_high_s16(_s01)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s23)); + _sum6 = vaddw_s16(_sum6, vget_high_s16(_s45)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); + _sum8 = vaddw_s16(_sum8, vget_low_s16(_s89)); + _sum9 = vaddw_s16(_sum9, vget_low_s16(_sab)); + _suma = vaddw_s16(_suma, vget_low_s16(_scd)); + _sumb = vaddw_s16(_sumb, vget_low_s16(_sef)); + _sumc = vaddw_s16(_sumc, vget_high_s16(_s89)); + _sumd = vaddw_s16(_sumd, vget_high_s16(_sab)); + _sume = vaddw_s16(_sume, vget_high_s16(_scd)); + _sumf = vaddw_s16(_sumf, vget_high_s16(_sef)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vld1_s8(pB); + + // abcd efgh + // efgh abcd + // cdab ghef + // ghef cdab + + // 0123 4567 + // 3210 7654 + + // abcdefgh -> ghefcdab -> cdabghef + + int8x8_t _pA1 = vext_s8(_pA0, _pA0, 4); + int8x8_t _pA2 = vreinterpret_s8_s16(vrev32_s16(vreinterpret_s16_s8(_pA0))); + int8x8_t _pA3 = vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pA0))); + + // 01234567 -> 32107654 + + int8x8_t _pB1 = vrev32_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA0, _pB0); + int16x8_t _s23 = vmull_s8(_pA1, _pB0); + int16x8_t _s45 = vmull_s8(_pA2, _pB0); + int16x8_t _s67 = vmull_s8(_pA3, _pB0); + int16x8_t _s89 = vmull_s8(_pA0, _pB1); + int16x8_t _sab = vmull_s8(_pA1, _pB1); + int16x8_t _scd = vmull_s8(_pA2, _pB1); + int16x8_t _sef = vmull_s8(_pA3, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, 
vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s45)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s45)); + _sum6 = vaddw_s16(_sum6, vget_low_s16(_s67)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); + _sum8 = vaddw_s16(_sum8, vget_low_s16(_s89)); + _sum9 = vaddw_s16(_sum9, vget_high_s16(_s89)); + _suma = vaddw_s16(_suma, vget_low_s16(_sab)); + _sumb = vaddw_s16(_sumb, vget_high_s16(_sab)); + _sumc = vaddw_s16(_sumc, vget_low_s16(_scd)); + _sumd = vaddw_s16(_sumd, vget_high_s16(_scd)); + _sume = vaddw_s16(_sume, vget_low_s16(_sef)); + _sumf = vaddw_s16(_sumf, vget_high_s16(_sef)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + vst1q_s32(outptr + 32, _sum8); + vst1q_s32(outptr + 36, _sum9); + vst1q_s32(outptr + 40, _suma); + vst1q_s32(outptr + 44, _sumb); + vst1q_s32(outptr + 48, _sumc); + vst1q_s32(outptr + 52, _sumd); + vst1q_s32(outptr + 56, _sume); + vst1q_s32(outptr + 60, _sumf); + + outptr += 64; +#endif // NCNN_GNU_INLINE_ASM + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0] \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + + "1: \n" +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "2: \n" + "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%1], #64 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "smmla v24.4s, v0.16b, v4.16b \n" + "smmla v25.4s, v1.16b, v4.16b \n" + "smmla v26.4s, v0.16b, v5.16b \n" + "smmla v27.4s, v1.16b, v5.16b \n" + "subs w4, w4, #1 \n" + "smmla v28.4s, v2.16b, v4.16b \n" + "smmla v29.4s, v3.16b, v4.16b \n" + "smmla v30.4s, v2.16b, v5.16b \n" + "smmla v31.4s, v3.16b, v5.16b \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "sdot v16.4s, v0.16b, v4.4b[0] \n" + "sdot v17.4s, v0.16b, v4.4b[1] \n" + "sdot v18.4s, v0.16b, v4.4b[2] \n" + "sdot v19.4s, v0.16b, v4.4b[3] \n" + "sdot v20.4s, v1.16b, v4.4b[0] \n" + "sdot v21.4s, v1.16b, v4.4b[1] \n" + "sdot v22.4s, v1.16b, v4.4b[2] \n" + "sdot v23.4s, v1.16b, v4.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v2.16b, v5.4b[0] \n" + "sdot v17.4s, v2.16b, v5.4b[1] \n" + "sdot v18.4s, v2.16b, v5.4b[2] \n" + "sdot v19.4s, v2.16b, v5.4b[3] \n" + "sdot v20.4s, v3.16b, v5.4b[0] \n" + "sdot v21.4s, v3.16b, v5.4b[1] \n" + "sdot v22.4s, v3.16b, v5.4b[2] \n" + "sdot v23.4s, v3.16b, v5.4b[3] \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + "bne 2b 
\n" + +#if __ARM_FEATURE_MATMUL_INT8 + "uzp1 v0.4s, v24.4s, v25.4s \n" + "uzp2 v1.4s, v24.4s, v25.4s \n" + "uzp1 v2.4s, v26.4s, v27.4s \n" + "uzp2 v3.4s, v26.4s, v27.4s \n" + "uzp1 v4.4s, v28.4s, v29.4s \n" + "uzp2 v5.4s, v28.4s, v29.4s \n" + "uzp1 v6.4s, v30.4s, v31.4s \n" + "uzp2 v7.4s, v30.4s, v31.4s \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" + "add v20.4s, v20.4s, v4.4s \n" + "add v21.4s, v21.4s, v5.4s \n" + "add v22.4s, v22.4s, v6.4s \n" + "add v23.4s, v23.4s, v7.4s \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v1.16b, v2.4b[0] \n" + "sdot v21.4s, v1.16b, v2.4b[1] \n" + "sdot v22.4s, v1.16b, v2.4b[2] \n" + "sdot v23.4s, v1.16b, v2.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v4.16b}, [%2], #16 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "rev64 v2.4s, v0.4s \n" + "smull v10.8h, v2.8b, v4.8b \n" + "ext v5.16b, v4.16b, v4.16b, #8 \n" + "smull2 v9.8h, v0.16b, v5.16b \n" + "rev64 v6.8h, v4.8h \n" + "smull2 v11.8h, v2.16b, v5.16b \n" + "ext v7.16b, v6.16b, v6.16b, #8 \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull v14.8h, v2.8b, v6.8b \n" + "rev64 v3.4s, v1.4s \n" + "smull2 v13.8h, v0.16b, v7.16b \n" + "smull2 v15.8h, v2.16b, v7.16b \n" + "smlal v8.8h, v1.8b, v5.8b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v9.8h, v1.16b, v4.16b \n" + "smlal2 v11.8h, v3.16b, v4.16b \n" + "smlal v12.8h, v1.8b, v7.8b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v13.8h, v1.16b, v6.16b \n" + "smlal2 v15.8h, v3.16b, v6.16b \n" + "subs w4, w4, #1 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v23.4s, v15.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 \n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v1.8b}, [%2], #8 \n" + "dup v4.8h, v1.h[0] \n" + "dup v5.8h, v1.h[1] \n" + "dup v6.8h, v1.h[2] \n" + "dup v7.8h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "smull2 v12.8h, v0.16b, v4.16b \n" + "smull2 v13.8h, v0.16b, v5.16b \n" + "smull2 v14.8h, v0.16b, v6.16b \n" + "smull2 v15.8h, v0.16b, v7.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.16b}, [%1], #16 \n" + "ld1r {v2.2d}, [%2] \n" + "add %2, %2, #8 \n" + "rev64 v1.4s, v0.4s \n" + "rev64 v3.8h, v2.8h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, 
v3.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.8b}, [%2] \n" + "add %2, %2, #4 \n" + "dup v8.8b, v1.b[0] \n" + "dup v9.8b, v1.b[1] \n" + "dup v10.8b, v1.b[2] \n" + "dup v11.8b, v1.b[3] \n" + "smull v8.8h, v0.8b, v8.8b \n" + "smull v9.8h, v0.8b, v9.8b \n" + "smull v10.8h, v0.8b, v10.8b \n" + "smull v11.8h, v0.8b, v11.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw v17.4s, v17.4s, v9.4h \n" + "saddw v18.4s, v18.4s, v10.4h \n" + "saddw v19.4s, v19.4s, v11.4h \n" + "saddw2 v20.4s, v20.4s, v8.8h \n" + "saddw2 v21.4s, v21.4s, v9.8h \n" + "saddw2 v22.4s, v22.4s, v10.8h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1r {v4.2s}, [%2] \n" + "add %2, %2, #4 \n" + "rev32 v1.4h, v0.4h \n" + "rev64 v5.8b, v4.8b \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v1.8b, v4.8b \n" + "smull v10.8h, v0.8b, v5.8b \n" + "smull v11.8h, v1.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw2 v21.4s, v21.4s, v10.8h \n" + "saddw v22.4s, v22.4s, v11.4h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0!, {d16-d23} \n" + "vldm %0, {d24-d31} \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + "veor q12, q12 \n" + "veor q13, q13 \n" + "veor q14, q14 \n" + "veor q15, q15 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + ".align 4 \n" + "2: \n" + "pld [%1, #256] \n" + "vld1.s8 {d0-d3}, [%1 :64]! \n" + "pld [%2, #128] \n" + "vld1.s8 {d4-d5}, [%2]! 
\n" + "vmull.s8 q4, d0, d4 \n" + "vrev64.32 q3, q0 \n" + "vmull.s8 q5, d1, d4 \n" + "vmull.s8 q6, d6, d4 \n" + "vmull.s8 q7, d7, d4 \n" + "vrev64.32 q0, q1 \n" + "vmlal.s8 q4, d2, d5 \n" + "vmlal.s8 q5, d3, d5 \n" + "vmlal.s8 q6, d0, d5 \n" + "vmlal.s8 q7, d1, d5 \n" + "vrev64.16 q2, q2 \n" + "vpadal.s16 q8, q4 \n" + "vrev64.32 q1, q3 \n" + "vpadal.s16 q9, q5 \n" + "vmull.s8 q4, d6, d4 \n" + "vpadal.s16 q10, q6 \n" + "vmull.s8 q5, d7, d4 \n" + "vpadal.s16 q11, q7 \n" + "vmull.s8 q6, d2, d4 \n" + "vmull.s8 q7, d3, d4 \n" + "vrev64.32 q3, q0 \n" + "vmlal.s8 q4, d0, d5 \n" + "vmlal.s8 q5, d1, d5 \n" + "vmlal.s8 q6, d6, d5 \n" + "vmlal.s8 q7, d7, d5 \n" + "subs r4, r4, #1 \n" + "vpadal.s16 q14, q4 \n" + "vpadal.s16 q15, q5 \n" + "vpadal.s16 q12, q6 \n" + "vpadal.s16 q13, q7 \n" + "bne 2b \n" + + "3: \n" + "and r4, %6, #2 \n" // r4 = remain = max_kk & 2 + "cmp r4, #0 \n" + "beq 4f \n" + + // kk += 2 part + "vld1.s8 {d0-d1}, [%1 :64]! \n" + "vld1.s8 {d4}, [%2]! \n" + "vrev64.32 q1, q0 \n" + "vrev64.16 d5, d4 \n" + "vmull.s8 q4, d0, d4 \n" + "vmull.s8 q5, d1, d4 \n" + "vmull.s8 q6, d2, d4 \n" + "vmull.s8 q7, d3, d4 \n" + "vpadal.s16 q8, q4 \n" + "vpadal.s16 q9, q5 \n" + "vpadal.s16 q10, q6 \n" + "vpadal.s16 q11, q7 \n" + "vmull.s8 q4, d0, d5 \n" + "vmull.s8 q5, d1, d5 \n" + "vmull.s8 q6, d2, d5 \n" + "vmull.s8 q7, d3, d5 \n" + "vpadal.s16 q12, q4 \n" + "vpadal.s16 q13, q5 \n" + "vpadal.s16 q14, q6 \n" + "vpadal.s16 q15, q7 \n" + + "4: \n" + "and r4, %6, #1 \n" // r4 = remain = max_kk & 1 + "cmp r4, #0 \n" + "beq 5f \n" + + // kk += 1 part + "vld1.s8 {d0}, [%1 :64]! \n" + "vld1.s32 {d2[]}, [%2]! \n" + "vrev64.16 d1, d0 \n" + "vrev64.8 d3, d2 \n" + "vext.s8 d1, d1, #4 \n" + "vmull.s8 q4, d0, d2 \n" + "vmull.s8 q5, d1, d2 \n" + "vmull.s8 q6, d0, d3 \n" + "vmull.s8 q7, d1, d3 \n" + "vaddw.s16 q8, d8 \n" + "vaddw.s16 q9, d9 \n" + "vaddw.s16 q10, d10 \n" + "vaddw.s16 q11, d11 \n" + "vaddw.s16 q12, d12 \n" + "vaddw.s16 q13, d13 \n" + "vaddw.s16 q14, d14 \n" + "vaddw.s16 q15, d15 \n" + + "5: \n" + "vstm %0!, {d16-d23} \n" + "vstm %0!, {d24-d31} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _s0 = vdupq_n_s32(0); + int32x4_t _s1 = vdupq_n_s32(0); + int32x4_t _s2 = vdupq_n_s32(0); + int32x4_t _s3 = vdupq_n_s32(0); + int32x4_t _s4 = vdupq_n_s32(0); + int32x4_t _s5 = vdupq_n_s32(0); + int32x4_t _s6 = vdupq_n_s32(0); + int32x4_t _s7 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + 
int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb ..... hhhhhhhh + // 00000000 11111111 22222222 33333333 + + _s0 = vmmlaq_s32(_s0, _pA0, _pB0); + _s1 = vmmlaq_s32(_s1, _pA1, _pB0); + _s2 = vmmlaq_s32(_s2, _pA0, _pB1); + _s3 = vmmlaq_s32(_s3, _pA1, _pB1); + _s4 = vmmlaq_s32(_s4, _pA2, _pB0); + _s5 = vmmlaq_s32(_s5, _pA3, _pB0); + _s6 = vmmlaq_s32(_s6, _pA2, _pB1); + _s7 = vmmlaq_s32(_s7, _pA3, _pB1); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB0, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB0, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB0, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB0, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA2, _pB1, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA2, _pB1, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA2, _pB1, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA2, _pB1, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA3, _pB1, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA3, _pB1, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA3, _pB1, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA3, _pB1, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 64; + pB += 32; + } +#if __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _ss0 = vuzpq_s32(_s0, _s1); + int32x4x2_t _ss1 = vuzpq_s32(_s2, _s3); + int32x4x2_t _ss2 = vuzpq_s32(_s4, _s5); + int32x4x2_t _ss3 = vuzpq_s32(_s6, _s7); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + _sum4 = vaddq_s32(_sum4, _ss2.val[0]); + _sum5 = vaddq_s32(_sum5, _ss2.val[1]); + _sum6 = vaddq_s32(_sum6, _ss3.val[0]); + _sum7 = vaddq_s32(_sum7, _ss3.val[1]); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB = vld1q_s8(pB); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 2222 3333 + + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB, 3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x16_t _pB02 = vld1q_s8(pB); + + // aabbccdd eeffgghh + + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + int8x16_t _pA3 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA2))); + + // 00112233 44556677 + + // 33221100 77665544 + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB02)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB02)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB13)); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA0), 
vget_low_s8(_pB13)); + int16x8_t _s6 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB13)); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA1), vget_low_s8(_pB13)); + + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), vget_high_s8(_pB02)); + _s2 = vmlal_s8(_s2, vget_low_s8(_pA3), vget_high_s8(_pB02)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA3), vget_high_s8(_pB02)); + _s4 = vmlal_s8(_s4, vget_low_s8(_pA2), vget_high_s8(_pB13)); + _s5 = vmlal_s8(_s5, vget_high_s8(_pA2), vget_high_s8(_pB13)); + _s6 = vmlal_s8(_s6, vget_low_s8(_pA3), vget_high_s8(_pB13)); + _s7 = vmlal_s8(_s7, vget_high_s8(_pA3), vget_high_s8(_pB13)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + // aabbccdd eeffgghh + + // 00112233 + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2))); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3))); + int16x8_t _s4 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + int16x8_t _s6 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2))); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA), vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3))); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x8_t _pB0 = vld1_s8(pB); + + // aabbccdd eeffgghh + + // ccddaabb gghheeff + + int8x16_t _pA1 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA0))); + + // 00112233 + + // 33221100 + + int8x8_t _pB1 = vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), _pB0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), _pB0); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA1), _pB0); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA1), _pB0); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA0), _pB1); + int16x8_t _s5 = vmull_s8(vget_high_s8(_pA0), _pB1); + int16x8_t _s6 = vmull_s8(vget_low_s8(_pA1), _pB1); + int16x8_t _s7 = vmull_s8(vget_high_s8(_pA1), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = 
vld1_s8(pA); + // int8x8_t _pB0 = vreinterpret_s32_s8(vld1_dup_s32(pB)); + + // abcdefgh + + // 0123 + + int16x8_t _s01 = vmull_s8(_pA0, vdup_n_s8(pB[0])); + int16x8_t _s23 = vmull_s8(_pA0, vdup_n_s8(pB[1])); + int16x8_t _s45 = vmull_s8(_pA0, vdup_n_s8(pB[2])); + int16x8_t _s67 = vmull_s8(_pA0, vdup_n_s8(pB[3])); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s23)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s45)); + _sum3 = vaddw_s16(_sum3, vget_low_s16(_s67)); + _sum4 = vaddw_s16(_sum4, vget_high_s16(_s01)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s23)); + _sum6 = vaddw_s16(_sum6, vget_high_s16(_s45)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + // int8x8_t _pB0 = vld1_s8(pB); + // _pB0 = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pB0), vreinterpret_s32_s8(_pB0)).val[0]); + + // abcdefgh -> cdabghef + int8x8_t _pA1 = vreinterpret_s8_s16(vrev32_s16(vreinterpret_s16_s8(_pA0))); + + // 01230123 -> 32103210 + int8x8_t _pB1 = vrev64_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA0, _pB0); + int16x8_t _s23 = vmull_s8(_pA1, _pB0); + int16x8_t _s45 = vmull_s8(_pA0, _pB1); + int16x8_t _s67 = vmull_s8(_pA1, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s45)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s45)); + _sum6 = vaddw_s16(_sum6, vget_low_s16(_s67)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + + outptr += 32; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _s0 = vdupq_n_s32(0); + int32x4_t _s1 = vdupq_n_s32(0); + int32x4_t _s2 = vdupq_n_s32(0); + int32x4_t _s3 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb ..... 
hhhhhhhh + // 00000000 11111111 + + _s0 = vmmlaq_s32(_s0, _pA0, _pB); + _s1 = vmmlaq_s32(_s1, _pA1, _pB); + _s2 = vmmlaq_s32(_s2, _pA2, _pB); + _s3 = vmmlaq_s32(_s3, _pA3, _pB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA1, _pB, 0); + _sum3 = vdotq_laneq_s32(_sum3, _pA1, _pB, 1); + + _sum0 = vdotq_laneq_s32(_sum0, _pA2, _pB, 2); + _sum1 = vdotq_laneq_s32(_sum1, _pA2, _pB, 3); + _sum2 = vdotq_laneq_s32(_sum2, _pA3, _pB, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA3, _pB, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 64; + pB += 16; + } +#if __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _ss0 = vuzpq_s32(_s0, _s1); + int32x4x2_t _ss1 = vuzpq_s32(_s2, _s3); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x8_t _pB = vld1_s8(pB); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 1111 + + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA0, _pB, 1); + _sum2 = vdotq_lane_s32(_sum2, _pA1, _pB, 0); + _sum3 = vdotq_lane_s32(_sum3, _pA1, _pB, 1); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x8_t _pB = vld1_s8(pB); + + // aabbccdd eeffgghh aabbccdd eeffgghh + + // 00112233 -> 00110011 22332233 + + // 11001100 33223322 + + int32x2x2_t _pBB = vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)); + int8x16_t _pB02 = vreinterpretq_s8_s32(vcombine_s32(_pBB.val[0], _pBB.val[1])); + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB02)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB13)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA0), vget_low_s8(_pB13)); + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), vget_high_s8(_pB02)); + _s2 = vmlal_s8(_s2, vget_low_s8(_pA2), vget_high_s8(_pB13)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA2), vget_high_s8(_pB13)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 8; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int16x4_t _pB = vreinterpret_s16_s32(vld1_dup_s32((const int*)pB)); + + int16x4x2_t _pB01 = vuzp_s16(_pB, _pB); + int8x8_t _pB0 = vreinterpret_s8_s16(_pB01.val[0]); + int8x8_t _pB1 = vreinterpret_s8_s16(_pB01.val[1]); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB0); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), _pB1); + int16x8_t _s2 = vmull_s8(vget_high_s8(_pA), _pB0); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // aabbccdd eeffgghh + + // 00110011 + // 11001100 + + int8x8_t _pB1 = 
vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA), _pB0); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA), _pB1); + int16x8_t _s3 = vmull_s8(vget_high_s8(_pA), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + int8x8x2_t _pB01 = vuzp_s8(_pB, _pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB01.val[0]); + int16x8_t _s1 = vmull_s8(_pA, _pB01.val[1]); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s1)); + _sum2 = vaddw_s16(_sum2, vget_high_s16(_s0)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s1)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + // abcdefgh + + // 01010101 + // 10101010 + int8x8_t _pB1 = vext_s8(_pB0, _pB0, 1); + + int16x8_t _s0 = vmull_s8(_pA, _pB0); + int16x8_t _s1 = vmull_s8(_pA, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s1)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s1)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _s0 = vdupq_n_s32(0); + int32x4_t _s1 = vdupq_n_s32(0); + int32x4_t _s2 = vdupq_n_s32(0); + int32x4_t _s3 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pA2 = vld1q_s8(pA + 32); + int8x16_t _pA3 = vld1q_s8(pA + 48); + + int8x8_t _pB = vld1_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb ..... 
hhhhhhhh + // 00000000 + int8x16_t _pBB = vcombine_s8(_pB, _pB); + + _s0 = vdotq_s32(_s0, _pA0, _pBB); + _s1 = vdotq_s32(_s1, _pA1, _pBB); + _s2 = vdotq_s32(_s2, _pA2, _pBB); + _s3 = vdotq_s32(_s3, _pA3, _pBB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA1, _pB, 0); + _sum0 = vdotq_lane_s32(_sum0, _pA2, _pB, 1); + _sum1 = vdotq_lane_s32(_sum1, _pA3, _pB, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 64; + pB += 8; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_s0, _s1)); + _sum1 = vaddq_s32(_sum1, vpaddq_s32(_s2, _s3)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + + int8x8_t _pB = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // aaaa bbbb cccc dddd eeee ffff gggg hhhh + + // 0000 0000 + + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA1, _pB, 0); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA2 = vld1q_s8(pA + 16); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + int8x8_t _pB1 = vreinterpret_s8_s16(vld1_dup_s16((const short*)(pB + 2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), _pB0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA0), _pB0); + _s0 = vmlal_s8(_s0, vget_low_s8(_pA2), _pB1); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA2), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB); + int16x8_t _s1 = vmull_s8(vget_high_s8(_pA), _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 16; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_dup_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + + pA += 8; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } + + pAT += max_kk * 8; + } + for (; ii + 3 < max_ii; ii += 4) + { + const signed char* pB = pBT; + + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM + asm volatile( + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0] \n" + "sub %0, %0, #64 \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + "eor v20.16b, v20.16b, v20.16b \n" + "eor v21.16b, v21.16b, v21.16b \n" + "eor v22.16b, v22.16b, v22.16b \n" + "eor v23.16b, v23.16b, v23.16b \n" + + "1: \n" +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" + "eor v28.16b, v28.16b, v28.16b \n" + "eor v29.16b, v29.16b, v29.16b \n" + "eor v30.16b, v30.16b, v30.16b \n" + "eor v31.16b, v31.16b, v31.16b \n" +#endif // 
__ARM_FEATURE_MATMUL_INT8 + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [%2], #64 \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "smmla v24.4s, v0.16b, v2.16b \n" + "smmla v25.4s, v1.16b, v2.16b \n" + "smmla v26.4s, v0.16b, v3.16b \n" + "smmla v27.4s, v1.16b, v3.16b \n" + "subs w4, w4, #1 \n" + "smmla v28.4s, v0.16b, v4.16b \n" + "smmla v29.4s, v1.16b, v4.16b \n" + "smmla v30.4s, v0.16b, v5.16b \n" + "smmla v31.4s, v1.16b, v5.16b \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v0.16b, v3.4b[0] \n" + "sdot v21.4s, v0.16b, v3.4b[1] \n" + "sdot v22.4s, v0.16b, v3.4b[2] \n" + "sdot v23.4s, v0.16b, v3.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v1.16b, v4.4b[0] \n" + "sdot v17.4s, v1.16b, v4.4b[1] \n" + "sdot v18.4s, v1.16b, v4.4b[2] \n" + "sdot v19.4s, v1.16b, v4.4b[3] \n" + "sdot v20.4s, v1.16b, v5.4b[0] \n" + "sdot v21.4s, v1.16b, v5.4b[1] \n" + "sdot v22.4s, v1.16b, v5.4b[2] \n" + "sdot v23.4s, v1.16b, v5.4b[3] \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + "bne 2b \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "uzp1 v0.4s, v24.4s, v25.4s \n" + "uzp2 v1.4s, v24.4s, v25.4s \n" + "uzp1 v2.4s, v26.4s, v27.4s \n" + "uzp2 v3.4s, v26.4s, v27.4s \n" + "uzp1 v4.4s, v28.4s, v29.4s \n" + "uzp2 v5.4s, v28.4s, v29.4s \n" + "uzp1 v6.4s, v30.4s, v31.4s \n" + "uzp2 v7.4s, v30.4s, v31.4s \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" + "add v20.4s, v20.4s, v4.4s \n" + "add v21.4s, v21.4s, v5.4s \n" + "add v22.4s, v22.4s, v6.4s \n" + "add v23.4s, v23.4s, v7.4s \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v2.16b, v3.16b}, [%2], #32 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" + "sdot v20.4s, v0.16b, v3.4b[0] \n" + "sdot v21.4s, v0.16b, v3.4b[1] \n" + "sdot v22.4s, v0.16b, v3.4b[2] \n" + "sdot v23.4s, v0.16b, v3.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull2 v9.8h, v0.16b, v5.16b \n" + "rev64 v2.4s, v0.4s \n" + "smull v10.8h, v2.8b, v4.8b \n" + "smull2 v11.8h, v2.16b, v5.16b \n" + "rev64 v6.8h, v4.8h \n" + "smull v12.8h, v0.8b, v6.8b \n" + "smull v14.8h, v2.8b, v6.8b \n" + "rev64 v7.8h, v5.8h \n" + "smull2 v13.8h, v0.16b, v7.16b \n" + "smull2 v15.8h, v2.16b, v7.16b \n" + "ext v1.16b, v0.16b, v0.16b, #8 \n" + "ext v3.16b, v2.16b, v2.16b, #8 \n" + "smlal v8.8h, v1.8b, v5.8b \n" + "smlal2 v9.8h, v1.16b, v4.16b \n" + "smlal v10.8h, v3.8b, v5.8b \n" + "smlal2 v11.8h, v3.16b, v4.16b \n" + "smlal v12.8h, v1.8b, v7.8b \n" + "smlal v14.8h, v3.8b, v7.8b \n" + "smlal2 v13.8h, v1.16b, v6.16b \n" + "smlal2 v15.8h, v3.16b, v6.16b \n" + "subs w4, w4, #1 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v23.4s, v15.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 
\n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.16b}, [%2], #16 \n" + "dup v4.8h, v1.h[0] \n" + "dup v5.8h, v1.h[1] \n" + "dup v6.8h, v1.h[2] \n" + "dup v7.8h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "dup v4.8h, v1.h[4] \n" + "dup v5.8h, v1.h[5] \n" + "dup v6.8h, v1.h[6] \n" + "dup v7.8h, v1.h[7] \n" + "smull v12.8h, v0.8b, v4.8b \n" + "smull v13.8h, v0.8b, v5.8b \n" + "smull v14.8h, v0.8b, v6.8b \n" + "smull v15.8h, v0.8b, v7.8b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1r {v0.2d}, [%1] \n" + "add %1, %1, #8 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "rev64 v1.4s, v0.4s \n" + "rev64 v3.8h, v2.8h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull2 v9.8h, v0.16b, v2.16b \n" + "smull v10.8h, v1.8b, v2.8b \n" + "smull2 v11.8h, v1.16b, v2.16b \n" + "smull v12.8h, v0.8b, v3.8b \n" + "smull2 v13.8h, v0.16b, v3.16b \n" + "smull v14.8h, v1.8b, v3.8b \n" + "smull2 v15.8h, v1.16b, v3.16b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "sadalp v20.4s, v12.8h \n" + "sadalp v21.4s, v13.8h \n" + "sadalp v22.4s, v14.8h \n" + "sadalp v23.4s, v15.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1r {v0.2s}, [%1] \n" + "ld1 {v1.8b}, [%2], #8 \n" + "add %1, %1, #4 \n" + "dup v8.8h, v1.h[0] \n" + "dup v9.8h, v1.h[1] \n" + "dup v10.8h, v1.h[2] \n" + "dup v11.8h, v1.h[3] \n" + "uzp1 v2.8b, v8.8b, v9.8b \n" + "uzp2 v3.8b, v8.8b, v9.8b \n" + "uzp1 v4.8b, v10.8b, v11.8b \n" + "uzp2 v5.8b, v10.8b, v11.8b \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v0.8b, v3.8b \n" + "smull v10.8h, v0.8b, v4.8b \n" + "smull v11.8h, v0.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw v17.4s, v17.4s, v9.4h \n" + "saddw2 v18.4s, v18.4s, v8.8h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw v21.4s, v21.4s, v11.4h \n" + "saddw2 v22.4s, v22.4s, v10.8h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1r {v0.2s}, [%1] \n" + "ld1 {v2.8b}, [%2], #8 \n" + "add %1, %1, #4 \n" + "ext v1.8b, v0.8b, v0.8b, #2 \n" + "rev32 v3.8b, v2.8b \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v1.8b, v2.8b \n" + "smull v10.8h, v0.8b, v3.8b \n" + "smull v11.8h, v1.8b, v3.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" + "saddw v20.4s, v20.4s, v10.4h \n" + "saddw2 v21.4s, v21.4s, v10.8h \n" + "saddw v22.4s, v22.4s, v11.4h \n" + "saddw2 v23.4s, v23.4s, v11.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // 
NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + int32x4_t _sum4; + int32x4_t _sum5; + int32x4_t _sum6; + int32x4_t _sum7; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + _sum4 = vdupq_n_s32(0); + _sum5 = vdupq_n_s32(0); + _sum6 = vdupq_n_s32(0); + _sum7 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + _sum4 = vld1q_s32(outptr + 16); + _sum5 = vld1q_s32(outptr + 20); + _sum6 = vld1q_s32(outptr + 24); + _sum7 = vld1q_s32(outptr + 28); + } + + int kk = 0; +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum10 = vdupq_n_s32(0); + int32x4_t _sum11 = vdupq_n_s32(0); + int32x4_t _sum20 = vdupq_n_s32(0); + int32x4_t _sum21 = vdupq_n_s32(0); + int32x4_t _sum30 = vdupq_n_s32(0); + int32x4_t _sum31 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 11111111 22222222 33333333 + // 44444444 55555555 66666666 77777777 + + _sum00 = vmmlaq_s32(_sum00, _pA0, _pB0); + _sum01 = vmmlaq_s32(_sum01, _pA1, _pB0); + _sum10 = vmmlaq_s32(_sum10, _pA0, _pB1); + _sum11 = vmmlaq_s32(_sum11, _pA1, _pB1); + _sum20 = vmmlaq_s32(_sum20, _pA0, _pB2); + _sum21 = vmmlaq_s32(_sum21, _pA1, _pB2); + _sum30 = vmmlaq_s32(_sum30, _pA0, _pB3); + _sum31 = vmmlaq_s32(_sum31, _pA1, _pB3); + + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // a2 a3 b2 b3 + // c2 c3 d2 d3 + // a4 a5 b4 b5 + // c4 c5 d4 d5 + // a6 a7 b6 b7 + // c6 c7 d6 d7 + + pA += 32; + pB += 64; + } + int32x4x2_t _ss0 = vuzpq_s32(_sum00, _sum01); + int32x4x2_t _ss1 = vuzpq_s32(_sum10, _sum11); + int32x4x2_t _ss2 = vuzpq_s32(_sum20, _sum21); + int32x4x2_t _ss3 = vuzpq_s32(_sum30, _sum31); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + _sum4 = vaddq_s32(_sum4, _ss2.val[0]); + _sum5 = vaddq_s32(_sum5, _ss2.val[1]); + _sum6 = vaddq_s32(_sum6, _ss3.val[0]); + _sum7 = vaddq_s32(_sum7, _ss3.val[1]); + } +#elif __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA0, _pB1, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA0, _pB1, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA0, _pB1, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA0, _pB1, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA1, _pB2, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA1, _pB2, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA1, _pB2, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA1, _pB2, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA1, _pB3, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA1, _pB3, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA1, _pB3, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA1, _pB3, 3); + + pA += 32; + pB += 64; + 
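+                // note: this loop consumes 8 values along K per iteration: 32 bytes of packed A (4 rows x 8) and 64 bytes of packed B (8 columns x 8).
+                // Each vdotq_laneq_s32 adds a 4-term dot product of the four A rows against one B column (selected by the lane),
+                // so _sum0.._sum7 accumulate columns 0..7 of the 4x8 output tile, four rows each.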
} +#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + _sum0 = vdotq_laneq_s32(_sum0, _pA, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA, _pB0, 3); + _sum4 = vdotq_laneq_s32(_sum4, _pA, _pB1, 0); + _sum5 = vdotq_laneq_s32(_sum5, _pA, _pB1, 1); + _sum6 = vdotq_laneq_s32(_sum6, _pA, _pB1, 2); + _sum7 = vdotq_laneq_s32(_sum7, _pA, _pB1, 3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA02 = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB2 = vld1q_s8(pB + 16); + + int8x16_t _pA13 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA02))); + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + int8x16_t _pB3 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA02), vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(vget_low_s8(_pA13), vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB1)); + int16x8_t _s5 = vmull_s8(vget_low_s8(_pA02), vget_high_s8(_pB1)); + int16x8_t _s6 = vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB1)); + int16x8_t _s7 = vmull_s8(vget_low_s8(_pA13), vget_high_s8(_pB1)); + + _s0 = vmlal_s8(_s0, vget_high_s8(_pA02), vget_low_s8(_pB2)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA02), vget_high_s8(_pB2)); + _s2 = vmlal_s8(_s2, vget_high_s8(_pA13), vget_low_s8(_pB2)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA13), vget_high_s8(_pB2)); + _s4 = vmlal_s8(_s4, vget_high_s8(_pA02), vget_low_s8(_pB3)); + _s5 = vmlal_s8(_s5, vget_high_s8(_pA02), vget_high_s8(_pB3)); + _s6 = vmlal_s8(_s6, vget_high_s8(_pA13), vget_low_s8(_pB3)); + _s7 = vmlal_s8(_s7, vget_high_s8(_pA13), vget_high_s8(_pB3)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 32; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x16_t _pB01 = vld1q_s8(pB); + + // aabbccdd + + // 00112233 44556677 + + int16x8_t _s0 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 0))); + int16x8_t _s1 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 1))); + int16x8_t _s2 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 2))); + int16x8_t _s3 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pB01)), 3))); + int16x8_t _s4 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 0))); + int16x8_t _s5 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 1))); + int16x8_t _s6 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 2))); + int16x8_t _s7 = vmull_s8(_pA0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pB01)), 3))); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = 
vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + + // aabbccdd + // ccddaabb + + int8x8_t _pA1 = vreinterpret_s8_s32(vrev64_s32(vreinterpret_s32_s8(_pA0))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB1 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA1, vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(_pA1, vget_high_s8(_pB0)); + int16x8_t _s4 = vmull_s8(_pA0, vget_low_s8(_pB1)); + int16x8_t _s5 = vmull_s8(_pA0, vget_high_s8(_pB1)); + int16x8_t _s6 = vmull_s8(_pA1, vget_low_s8(_pB1)); + int16x8_t _s7 = vmull_s8(_pA1, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pAA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vld1_s8(pB); + + // abcdabcd + // 01234567 -> 01010101 23232323 45454545 67676767 + int8x8_t _pB0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0)); + int8x8_t _pB2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1)); + int8x8_t _pB4 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2)); + int8x8_t _pB6 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3)); + + int8x8x2_t _pB0123 = vuzp_s8(_pB0, _pB2); + int8x8x2_t _pB4567 = vuzp_s8(_pB4, _pB6); + + int16x8_t _s02 = vmull_s8(_pAA, _pB0123.val[0]); + int16x8_t _s13 = vmull_s8(_pAA, _pB0123.val[1]); + int16x8_t _s46 = vmull_s8(_pAA, _pB4567.val[0]); + int16x8_t _s57 = vmull_s8(_pAA, _pB4567.val[1]); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s02)); + _sum1 = vaddw_s16(_sum1, vget_low_s16(_s13)); + _sum2 = vaddw_s16(_sum2, vget_high_s16(_s02)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s13)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s46)); + _sum5 = vaddw_s16(_sum5, vget_low_s16(_s57)); + _sum6 = vaddw_s16(_sum6, vget_high_s16(_s46)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s57)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB0 = vld1_s8(pB); + + // abcd abcd + // cdab cdab + + int8x8_t _pA1 = vext_s8(_pA0, _pA0, 2); + + // 0123 4567 + // 3210 7654 + + int8x8_t _pB1 = vrev32_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA0, _pB0); + int16x8_t _s23 = vmull_s8(_pA1, _pB0); + int16x8_t _s45 = vmull_s8(_pA0, _pB1); + int16x8_t _s67 = vmull_s8(_pA1, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); + _sum4 = vaddw_s16(_sum4, vget_low_s16(_s45)); + _sum5 = vaddw_s16(_sum5, vget_high_s16(_s45)); + _sum6 = vaddw_s16(_sum6, vget_low_s16(_s67)); + _sum7 = vaddw_s16(_sum7, vget_high_s16(_s67)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + 
vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + vst1q_s32(outptr + 16, _sum4); + vst1q_s32(outptr + 20, _sum5); + vst1q_s32(outptr + 24, _sum6); + vst1q_s32(outptr + 28, _sum7); + + outptr += 32; +#endif // NCNN_GNU_INLINE_ASM + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const signed char* pA = pAT; + +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "cmp %w7, #0 \n" + "beq 0f \n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0] \n" + "b 1f \n" + + "0: \n" + "eor v16.16b, v16.16b, v16.16b \n" + "eor v17.16b, v17.16b, v17.16b \n" + "eor v18.16b, v18.16b, v18.16b \n" + "eor v19.16b, v19.16b, v19.16b \n" + + "1: \n" +#if __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #3 \n" // w4 = max_kk >> 3 + "cmp w4, #0 \n" + "beq 101f \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "eor v24.16b, v24.16b, v24.16b \n" + "eor v25.16b, v25.16b, v25.16b \n" + "eor v26.16b, v26.16b, v26.16b \n" + "eor v27.16b, v27.16b, v27.16b \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "2: \n" + "ld1 {v0.16b, v1.16b}, [%1], #32 \n" + "ld1 {v4.16b, v5.16b}, [%2], #32 \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "smmla v24.4s, v0.16b, v4.16b \n" + "smmla v25.4s, v1.16b, v4.16b \n" + "subs w4, w4, #1 \n" + "smmla v26.4s, v0.16b, v5.16b \n" + "smmla v27.4s, v1.16b, v5.16b \n" +#else // __ARM_FEATURE_MATMUL_INT8 + "sdot v16.4s, v0.16b, v4.4b[0] \n" + "sdot v17.4s, v0.16b, v4.4b[1] \n" + "sdot v18.4s, v0.16b, v4.4b[2] \n" + "sdot v19.4s, v0.16b, v4.4b[3] \n" + "subs w4, w4, #1 \n" + "sdot v16.4s, v1.16b, v5.4b[0] \n" + "sdot v17.4s, v1.16b, v5.4b[1] \n" + "sdot v18.4s, v1.16b, v5.4b[2] \n" + "sdot v19.4s, v1.16b, v5.4b[3] \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + "bne 2b \n" + +#if __ARM_FEATURE_MATMUL_INT8 + "uzp1 v0.4s, v24.4s, v25.4s \n" + "uzp2 v1.4s, v24.4s, v25.4s \n" + "uzp1 v2.4s, v26.4s, v27.4s \n" + "uzp2 v3.4s, v26.4s, v27.4s \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "add v18.4s, v18.4s, v2.4s \n" + "add v19.4s, v19.4s, v3.4s \n" +#endif // __ARM_FEATURE_MATMUL_INT8 + + "101: \n" + "and w4, %w6, #4 \n" // w4 = remain = max_kk & 4 + "cmp w4, #0 \n" + "beq 3f \n" + + // kk += 4 part + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v2.16b}, [%2], #16 \n" + "sdot v16.4s, v0.16b, v2.4b[0] \n" + "sdot v17.4s, v0.16b, v2.4b[1] \n" + "sdot v18.4s, v0.16b, v2.4b[2] \n" + "sdot v19.4s, v0.16b, v2.4b[3] \n" +#else // __ARM_FEATURE_DOTPROD + "lsr w4, %w6, #2 \n" // w4 = max_kk >> 2 + "cmp w4, #0 \n" + "beq 3f \n" + + "2: \n" + "ld1 {v0.16b}, [%1], #16 \n" + "ld1 {v4.16b}, [%2], #16 \n" + "smull v8.8h, v0.8b, v4.8b \n" + "rev64 v1.4s, v0.4s \n" + "smull v9.8h, v1.8b, v4.8b \n" + "rev64 v5.8h, v4.8h \n" + "smull v10.8h, v0.8b, v5.8b \n" + "smull v11.8h, v1.8b, v5.8b \n" + "smlal2 v8.8h, v0.16b, v4.16b \n" + "smlal2 v9.8h, v1.16b, v4.16b \n" + "smlal2 v10.8h, v0.16b, v5.16b \n" + "smlal2 v11.8h, v1.16b, v5.16b \n" + "subs w4, w4, #1 \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" + "bne 2b \n" +#endif // __ARM_FEATURE_DOTPROD + + "3: \n" + "and w4, %w6, #2 \n" // w4 = remain = max_kk & 2 + "cmp w4, #0 \n" + "beq 4f \n" + + // kk += 2 part +#if __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v1.8b}, [%2], #8 \n" + "dup v4.4h, v1.h[0] \n" + "dup v5.4h, v1.h[1] \n" + "dup v6.4h, v1.h[2] \n" + "dup v7.4h, v1.h[3] \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "smull v10.8h, v0.8b, v6.8b \n" + "smull v11.8h, v0.8b, v7.8b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + 
"sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1], #8 \n" + "ld1 {v2.8b}, [%2], #8 \n" + "ext v1.8b, v0.8b, v0.8b, #4 \n" + "rev64 v3.4h, v2.4h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v1.8b, v2.8b \n" + "smull v10.8h, v0.8b, v3.8b \n" + "smull v11.8h, v1.8b, v3.8b \n" + "sadalp v16.4s, v8.8h \n" + "sadalp v17.4s, v9.8h \n" + "sadalp v18.4s, v10.8h \n" + "sadalp v19.4s, v11.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "4: \n" + "and w4, %w6, #1 \n" // w4 = remain = max_kk & 1 + "cmp w4, #0 \n" + "beq 5f \n" + + // kk += 1 part +#if __ARM_FEATURE_DOTPROD + "ld1r {v0.2s}, [%1] \n" + "ld1r {v1.2s}, [%2] \n" + "add %1, %1, #4 \n" + "add %2, %2, #4 \n" + "zip1 v1.8b, v1.8b, v1.8b \n" + "zip1 v2.4h, v1.4h, v1.4h \n" + "zip2 v3.4h, v1.4h, v1.4h \n" + "smull v8.8h, v0.8b, v2.8b \n" + "smull v9.8h, v0.8b, v3.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" +#else // __ARM_FEATURE_DOTPROD + "ld1 {v0.8b}, [%1] \n" + "ld1r {v4.2s}, [%2] \n" + "add %1, %1, #4 \n" + "add %2, %2, #4 \n" + "rev32 v1.4h, v0.4h \n" + "zip1 v0.2s, v0.2s, v1.2s \n" + "rev32 v5.8b, v4.8b \n" + "smull v8.8h, v0.8b, v4.8b \n" + "smull v9.8h, v0.8b, v5.8b \n" + "saddw v16.4s, v16.4s, v8.4h \n" + "saddw2 v17.4s, v17.4s, v8.8h \n" + "saddw v18.4s, v18.4s, v9.4h \n" + "saddw2 v19.4s, v19.4s, v9.8h \n" +#endif // __ARM_FEATURE_DOTPROD + + "5: \n" + "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile( + "cmp %7, #0 \n" + "beq 0f \n" + + "vldm %0, {d16-d23} \n" + "b 1f \n" + + "0: \n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + + "1: \n" + "lsr r4, %6, #2 \n" // r4 = max_kk >> 2 + "cmp r4, #0 \n" + "beq 3f \n" + + ".align 4 \n" + "2: \n" + "pld [%1, #256] \n" + "vld1.s8 {d0-d1}, [%1 :64]! \n" + "pld [%2, #128] \n" + "vld1.s8 {d4-d5}, [%2]! \n" + "vrev64.32 q1, q0 \n" + "vmull.s8 q4, d0, d4 \n" + "vrev64.16 q3, q2 \n" + "vmull.s8 q5, d2, d4 \n" + "vmull.s8 q6, d0, d6 \n" + "vmull.s8 q7, d2, d6 \n" + "vmlal.s8 q4, d1, d5 \n" + "vmlal.s8 q5, d3, d5 \n" + "vmlal.s8 q6, d1, d7 \n" + "vmlal.s8 q7, d3, d7 \n" + "subs r4, r4, #1 \n" + "vpadal.s16 q8, q4 \n" + "vpadal.s16 q9, q5 \n" + "vpadal.s16 q10, q6 \n" + "vpadal.s16 q11, q7 \n" + "bne 2b \n" + + "3: \n" + "and r4, %6, #2 \n" // r4 = remain = max_kk & 2 + "cmp r4, #0 \n" + "beq 4f \n" + + // kk += 2 part + "vld1.s8 {d0}, [%1 :64]! \n" + "vld1.s8 {d4}, [%2]! \n" + "vext.8 d1, d0, d0, #4 \n" + "vrev64.16 d5, d4 \n" + "vmull.s8 q4, d0, d4 \n" + "vmull.s8 q5, d1, d4 \n" + "vmull.s8 q6, d0, d5 \n" + "vmull.s8 q7, d1, d5 \n" + "vpadal.s16 q8, q4 \n" + "vpadal.s16 q9, q5 \n" + "vpadal.s16 q10, q6 \n" + "vpadal.s16 q11, q7 \n" + + "4: \n" + "and r4, %6, #1 \n" // r4 = remain = max_kk & 1 + "cmp r4, #0 \n" + "beq 5f \n" + + // kk += 1 part + "vld1.s32 {d0[0]}, [%1]! \n" + "vld1.s32 {d2[]}, [%2]! 
\n" + "vrev32.16 d1, d0 \n" + "vrev32.s8 d3, d2 \n" + "vzip.32 d0, d1 \n" + "vmull.s8 q4, d0, d2 \n" + "vmull.s8 q5, d0, d3 \n" + "vaddw.s16 q8, d8 \n" + "vaddw.s16 q9, d9 \n" + "vaddw.s16 q10, d10 \n" + "vaddw.s16 q11, d11 \n" + + "5: \n" + "vstm %0!, {d16-d23} \n" + + : "=r"(outptr), // %0 + "=r"(pA), // %1 + "=r"(pB) // %2 + : "0"(outptr), + "1"(pA), + "2"(pB), + "r"(max_kk), // %6 + "r"(k) // %7 + : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + int kk = 0; +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum10 = vdupq_n_s32(0); + int32x4_t _sum11 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 11111111 22222222 33333333 + + _sum00 = vmmlaq_s32(_sum00, _pA0, _pB0); + _sum01 = vmmlaq_s32(_sum01, _pA1, _pB0); + _sum10 = vmmlaq_s32(_sum10, _pA0, _pB1); + _sum11 = vmmlaq_s32(_sum11, _pA1, _pB1); + + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // a2 a3 b2 b3 + // c2 c3 d2 d3 + + pA += 32; + pB += 32; + } + int32x4x2_t _ss0 = vuzpq_s32(_sum00, _sum01); + int32x4x2_t _ss1 = vuzpq_s32(_sum10, _sum11); + _sum0 = vaddq_s32(_sum0, _ss0.val[0]); + _sum1 = vaddq_s32(_sum1, _ss0.val[1]); + _sum2 = vaddq_s32(_sum2, _ss1.val[0]); + _sum3 = vaddq_s32(_sum3, _ss1.val[1]); + } +#elif __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB0, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB0, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA0, _pB0, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA0, _pB0, 3); + + _sum0 = vdotq_laneq_s32(_sum0, _pA1, _pB1, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA1, _pB1, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA1, _pB1, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA1, _pB1, 3); + + pA += 32; + pB += 32; + } +#endif // __ARM_FEATURE_MATMUL_INT8 || __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + _sum0 = vdotq_laneq_s32(_sum0, _pA, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA, _pB, 1); + _sum2 = vdotq_laneq_s32(_sum2, _pA, _pB, 2); + _sum3 = vdotq_laneq_s32(_sum3, _pA, _pB, 3); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA02 = vld1q_s8(pA); + int8x16_t _pB02 = vld1q_s8(pB); + + // aabbccdd eeffgghh + // ccddaabb gghheeff + + int8x16_t _pA13 = vreinterpretq_s8_s32(vrev64q_s32(vreinterpretq_s32_s8(_pA02))); + + // 00112233 44556677 + // 33221100 77665544 + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB02)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_pA02), vget_low_s8(_pB13)); + int16x8_t _s3 = 
vmull_s8(vget_low_s8(_pA13), vget_low_s8(_pB13)); + + _s0 = vmlal_s8(_s0, vget_high_s8(_pA02), vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA13), vget_high_s8(_pB02)); + _s2 = vmlal_s8(_s2, vget_high_s8(_pA02), vget_high_s8(_pB13)); + _s3 = vmlal_s8(_s3, vget_high_s8(_pA13), vget_high_s8(_pB13)); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s1 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + int16x8_t _s2 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 2))); + int16x8_t _s3 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 3))); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vld1_s8(pB); + + // aabbccdd + // ccddaabb + + int8x8_t _pA1 = vext_s8(_pA0, _pA0, 4); + + // 00112233 + // 33221100 + + int8x8_t _pB1 = vreinterpret_s8_s16(vrev64_s16(vreinterpret_s16_s8(_pB0))); + + int16x8_t _s0 = vmull_s8(_pA0, _pB0); + int16x8_t _s1 = vmull_s8(_pA1, _pB0); + int16x8_t _s2 = vmull_s8(_pA0, _pB1); + int16x8_t _s3 = vmull_s8(_pA1, _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + _pB = vzip_s8(_pB, _pB).val[0]; + int16x4x2_t _pB0123 = vzip_s16(vreinterpret_s16_s8(_pB), vreinterpret_s16_s8(_pB)); + + int16x8_t _s01 = vmull_s8(_pA, vreinterpret_s8_s16(_pB0123.val[0])); + int16x8_t _s23 = vmull_s8(_pA, vreinterpret_s8_s16(_pB0123.val[1])); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); +#else // __ARM_FEATURE_DOTPROD + + int8x8_t _pA0 = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // abcd.... -> cdab.... 
-> abcdcdab + int8x8_t _pA1 = vreinterpret_s8_s16(vrev32_s16(vreinterpret_s16_s8(_pA0))); + int8x8_t _pA01 = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pA0), vreinterpret_s32_s8(_pA1)).val[0]); + + // 01230123 -> 32103210 + int8x8_t _pB1 = vrev32_s8(_pB0); + + int16x8_t _s01 = vmull_s8(_pA01, _pB0); + int16x8_t _s23 = vmull_s8(_pA01, _pB1); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s01)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s01)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s23)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s23)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + + outptr += 16; +#endif // NCNN_GNU_INLINE_ASM + } + for (; jj + 1 < max_jj; jj += 2) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 11111111 + + _sum00 = vmmlaq_s32(_sum00, _pA0, _pB); + _sum01 = vmmlaq_s32(_sum01, _pA1, _pB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_laneq_s32(_sum0, _pA0, _pB, 0); + _sum1 = vdotq_laneq_s32(_sum1, _pA0, _pB, 1); + _sum0 = vdotq_laneq_s32(_sum0, _pA1, _pB, 2); + _sum1 = vdotq_laneq_s32(_sum1, _pA1, _pB, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 32; + pB += 16; + } +#if __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _ss = vuzpq_s32(_sum00, _sum01); + _sum0 = vaddq_s32(_sum0, _ss.val[0]); + _sum1 = vaddq_s32(_sum1, _ss.val[1]); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _sum0 = vdotq_lane_s32(_sum0, _pA, _pB, 0); + _sum1 = vdotq_lane_s32(_sum1, _pA, _pB, 1); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + // aabbccdd eeffgghh + + // 00112233 -> 00110011 22332233 + // 11001100 33223322 + + int32x2x2_t _pBB = vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)); + int8x16_t _pB02 = vreinterpretq_s8_s32(vcombine_s32(_pBB.val[0], _pBB.val[1])); + + int8x16_t _pB13 = vreinterpretq_s8_s16(vrev64q_s16(vreinterpretq_s16_s8(_pB02))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vget_low_s8(_pB02)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA), vget_low_s8(_pB13)); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA), vget_high_s8(_pB02)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA), vget_high_s8(_pB13)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 8; + } + for (; kk + 1 < max_kk; kk += 2) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + // aabbccdd + // 0011.... 
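+ // a sketch of the accumulation below: dup each 16-bit column pair of _pB, multiply against _pA (a0a1 b0b1 c0c1 d0d1),
+ // then vpadalq_s16 folds the k0/k1 products pairwise into the four per-row int32 sums of that column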
+ int16x8_t _s0 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 0))); + int16x8_t _s1 = vmull_s8(_pA, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pB), 1))); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + // aabbccdd + + // 00110011 + // 11001100 + int8x8_t _pB1 = vext_s8(_pB0, _pB0, 2); + + int16x8_t _s0 = vmull_s8(_pA, _pB0); + int16x8_t _s1 = vmull_s8(_pA, _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + // abcdabcd + + // 01010101 -> 00001111 + _pB = vuzp_s8(_pB, vext_s8(_pB, _pB, 1)).val[0]; + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + // abcd abcd + + // 0101 0101 -> 0101 1010 + + int8x8_t _pB1 = vext_s8(_pB0, _pB0, 1); + int8x8_t _pB = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pB0), vreinterpret_s32_s8(_pB1)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 2; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + const signed char* pA = pAT; + + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x8_t _pB = vld1_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + // aaaaaaaa bbbbbbbb cccccccc dddddddd + + // 00000000 + + int8x16_t _pBB = vcombine_s8(_pB, _pB); + + _sum01 = vdotq_s32(_sum01, _pA0, _pBB); + _sum23 = vdotq_s32(_sum23, _pA1, _pBB); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_lane_s32(_sum0, _pA0, _pB, 0); + _sum0 = vdotq_lane_s32(_sum0, _pA1, _pB, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 32; + pB += 8; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_sum01, _sum23)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _sum0 = vdotq_lane_s32(_sum0, _pA, _pB, 0); +#else // __ARM_FEATURE_DOTPROD + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB0 = vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + int8x8_t _pB1 = vreinterpret_s8_s16(vld1_dup_s16((const short*)(pB + 2))); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), _pB0); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA), _pB1); + _sum0 = vpadalq_s16(_sum0, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = 
vreinterpret_s8_s16(vld1_dup_s16((const short*)pB)); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 8; + pB += 2; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vld1_dup_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + + pA += 4; + pB += 1; + } + + vst1q_s32(outptr, _sum0); + + outptr += 4; + } + + pAT += max_kk * 4; + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const signed char* pB = pBT; + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0; + int32x4_t _sum1; + int32x4_t _sum2; + int32x4_t _sum3; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + _sum2 = vld1q_s32(outptr + 8); + _sum3 = vld1q_s32(outptr + 12); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); + int32x4_t _sum45 = vdupq_n_s32(0); + int32x4_t _sum67 = vdupq_n_s32(0); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); + int32x2_t _sum20 = vdup_n_s32(0); + int32x2_t _sum21 = vdup_n_s32(0); + int32x2_t _sum30 = vdup_n_s32(0); + int32x2_t _sum31 = vdup_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + +#if __ARM_FEATURE_MATMUL_INT8 + _sum01 = vmmlaq_s32(_sum01, _pA, _pB0); + _sum23 = vmmlaq_s32(_sum23, _pA, _pB1); + _sum45 = vmmlaq_s32(_sum45, _pA, _pB2); + _sum67 = vmmlaq_s32(_sum67, _pA, _pB3); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum00 = vdot_laneq_s32(_sum00, vget_low_s8(_pA), _pB0, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_low_s8(_pA), _pB0, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_low_s8(_pA), _pB0, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_low_s8(_pA), _pB0, 3); + _sum20 = vdot_laneq_s32(_sum20, vget_low_s8(_pA), _pB1, 0); + _sum21 = vdot_laneq_s32(_sum21, vget_low_s8(_pA), _pB1, 1); + _sum30 = vdot_laneq_s32(_sum30, vget_low_s8(_pA), _pB1, 2); + _sum31 = vdot_laneq_s32(_sum31, vget_low_s8(_pA), _pB1, 3); + _sum00 = vdot_laneq_s32(_sum00, vget_high_s8(_pA), _pB2, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_high_s8(_pA), _pB2, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_high_s8(_pA), _pB2, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_high_s8(_pA), _pB2, 3); + _sum20 = vdot_laneq_s32(_sum20, vget_high_s8(_pA), _pB3, 0); + _sum21 = vdot_laneq_s32(_sum21, vget_high_s8(_pA), _pB3, 1); + _sum30 = vdot_laneq_s32(_sum30, vget_high_s8(_pA), _pB3, 2); + _sum31 = vdot_laneq_s32(_sum31, vget_high_s8(_pA), _pB3, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 16; + pB += 64; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vcombine_s32(vget_low_s32(_sum01), vget_low_s32(_sum23))); + _sum1 = vaddq_s32(_sum1, vcombine_s32(vget_low_s32(_sum45), vget_low_s32(_sum67))); + _sum2 = vaddq_s32(_sum2, vcombine_s32(vget_high_s32(_sum01), vget_high_s32(_sum23))); + _sum3 = vaddq_s32(_sum3, vcombine_s32(vget_high_s32(_sum45), vget_high_s32(_sum67))); +#else // 
__ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + int32x2x2_t _sum2x = vzip_s32(_sum20, _sum21); + int32x2x2_t _sum3x = vzip_s32(_sum30, _sum31); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum2x.val[0], _sum3x.val[0])); + _sum2 = vaddq_s32(_sum2, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); + _sum3 = vaddq_s32(_sum3, vcombine_s32(_sum2x.val[1], _sum3x.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_DOTPROD + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); + int32x2_t _sum20 = vdup_n_s32(0); + int32x2_t _sum21 = vdup_n_s32(0); + int32x2_t _sum30 = vdup_n_s32(0); + int32x2_t _sum31 = vdup_n_s32(0); +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_DOTPROD + _sum00 = vdot_laneq_s32(_sum00, _pA, _pB0, 0); + _sum01 = vdot_laneq_s32(_sum01, _pA, _pB0, 1); + _sum10 = vdot_laneq_s32(_sum10, _pA, _pB0, 2); + _sum11 = vdot_laneq_s32(_sum11, _pA, _pB0, 3); + _sum20 = vdot_laneq_s32(_sum20, _pA, _pB1, 0); + _sum21 = vdot_laneq_s32(_sum21, _pA, _pB1, 1); + _sum30 = vdot_laneq_s32(_sum30, _pA, _pB1, 2); + _sum31 = vdot_laneq_s32(_sum31, _pA, _pB1, 3); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA1, vget_low_s8(_pB0)); + int16x8_t _s3 = vmull_s8(_pA1, vget_high_s8(_pB0)); + _s0 = vmlal_s8(_s0, _pA2, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA2, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA3, vget_low_s8(_pB1)); + _s3 = vmlal_s8(_s3, _pA3, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 32; + } +#if __ARM_FEATURE_DOTPROD + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + int32x2x2_t _sum2x = vzip_s32(_sum20, _sum21); + int32x2x2_t _sum3x = vzip_s32(_sum30, _sum31); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum2x.val[0], _sum3x.val[0])); + _sum2 = vaddq_s32(_sum2, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); + _sum3 = vaddq_s32(_sum3, vcombine_s32(_sum2x.val[1], _sum3x.val[1])); +#endif // __ARM_FEATURE_DOTPROD + } + for (; kk + 1 < max_kk; kk += 2) + { + int16x4_t _pA = vreinterpret_s16_s32(vld1_dup_s32((const int*)pA)); + int8x16_t _pB = vld1q_s8(pB); + + int16x4x2_t _pA01 = vuzp_s16(_pA, _pA); + int8x8_t _pA0 = vreinterpret_s8_s16(_pA01.val[0]); + int8x8_t _pA1 = vreinterpret_s8_s16(_pA01.val[1]); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB)); + int16x8_t _s2 = vmull_s8(_pA1, vget_low_s8(_pB)); + int16x8_t _s3 = vmull_s8(_pA1, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, 
_s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + pA += 4; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x8_t _pB = vld1_s8(pB); + + int8x8x2_t _pA01 = vuzp_s8(_pA, _pA); + + int16x8_t _s0 = vmull_s8(_pA01.val[0], _pB); + int16x8_t _s1 = vmull_s8(_pA01.val[1], _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + _sum2 = vaddw_s16(_sum2, vget_low_s16(_s1)); + _sum3 = vaddw_s16(_sum3, vget_high_s16(_s1)); + + pA += 2; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + vst1q_s32(outptr + 8, _sum2); + vst1q_s32(outptr + 12, _sum3); + + outptr += 16; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum23 = vdupq_n_s32(0); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_MATMUL_INT8 + _sum01 = vmmlaq_s32(_sum01, _pA, _pB0); + _sum23 = vmmlaq_s32(_sum23, _pA, _pB1); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum00 = vdot_laneq_s32(_sum00, vget_low_s8(_pA), _pB0, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_low_s8(_pA), _pB0, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_low_s8(_pA), _pB0, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_low_s8(_pA), _pB0, 3); + _sum00 = vdot_laneq_s32(_sum00, vget_high_s8(_pA), _pB1, 0); + _sum01 = vdot_laneq_s32(_sum01, vget_high_s8(_pA), _pB1, 1); + _sum10 = vdot_laneq_s32(_sum10, vget_high_s8(_pA), _pB1, 2); + _sum11 = vdot_laneq_s32(_sum11, vget_high_s8(_pA), _pB1, 3); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 16; + pB += 32; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vcombine_s32(vget_low_s32(_sum01), vget_low_s32(_sum23))); + _sum1 = vaddq_s32(_sum1, vcombine_s32(vget_high_s32(_sum01), vget_high_s32(_sum23))); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#endif // __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_DOTPROD + int32x2_t _sum00 = vdup_n_s32(0); + int32x2_t _sum01 = vdup_n_s32(0); + int32x2_t _sum10 = vdup_n_s32(0); + int32x2_t _sum11 = vdup_n_s32(0); +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_DOTPROD + _sum00 = vdot_laneq_s32(_sum00, _pA, _pB, 0); + _sum01 = vdot_laneq_s32(_sum01, _pA, _pB, 1); + _sum10 = vdot_laneq_s32(_sum10, _pA, _pB, 2); + _sum11 = vdot_laneq_s32(_sum11, _pA, _pB, 3); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + int16x8_t _s1 = vmull_s8(_pA1, vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, _pA2, vget_high_s8(_pB)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 16; + } +#if __ARM_FEATURE_DOTPROD + int32x2x2_t _sum0x = vzip_s32(_sum00, _sum01); + int32x2x2_t _sum1x = vzip_s32(_sum10, _sum11); + _sum0 = vaddq_s32(_sum0, vcombine_s32(_sum0x.val[0], _sum1x.val[0])); + _sum1 = vaddq_s32(_sum1, vcombine_s32(_sum0x.val[1], _sum1x.val[1])); +#endif // __ARM_FEATURE_DOTPROD + } + for (; kk + 1 < max_kk; kk += 2) + { + int16x4_t _pA = vreinterpret_s16_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pA)), 0)); + int8x8_t _pB = vld1_s8(pB); + + int16x4x2_t _pA01 = vuzp_s16(_pA, _pA); + int8x8_t _pA0 = vreinterpret_s8_s16(_pA01.val[0]); + int8x8_t _pA1 = vreinterpret_s8_s16(_pA01.val[1]); + + int16x8_t _s0 = vmull_s8(_pA0, _pB); + int16x8_t _s1 = vmull_s8(_pA1, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 4; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x8_t _pB = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pB)), 0)); + + _pA = vzip_s8(_pA, _pA).val[0]; + _pA = vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + + pA += 2; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { +#if __ARM_NEON + int32x4_t _sum; + + if (k == 0) + { + _sum = vdupq_n_s32(0); + } + else + { + _sum = vld1q_s32(outptr); + } + + const signed char* pA = pAT; + int kk = 0; + +#if __ARM_FEATURE_DOTPROD + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_MATMUL_INT8 + _sum = vmmlaq_s32(_sum, _pA, _pB); +#else // __ARM_FEATURE_MATMUL_INT8 + int32x4x2_t _pAA = vzipq_s32(vreinterpretq_s32_s8(_pA), vreinterpretq_s32_s8(_pA)); + int8x16_t _pA01 = vreinterpretq_s8_s32(_pAA.val[0]); + int8x16_t _pA23 = vreinterpretq_s8_s32(_pAA.val[1]); + int8x16_t _pB01 = vcombine_s8(vget_low_s8(_pB), vget_low_s8(_pB)); + int8x16_t _pB23 = vcombine_s8(vget_high_s8(_pB), vget_high_s8(_pB)); + + _sum = vdotq_s32(_sum, _pA01, _pB01); + _sum = vdotq_s32(_sum, _pA23, _pB23); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 16; + pB += 16; + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + +#if __ARM_FEATURE_DOTPROD + int32x2x2_t _pAA = vzip_s32(vreinterpret_s32_s8(_pA), vreinterpret_s32_s8(_pA)); + int8x16_t _pA01 = vreinterpretq_s8_s32(vcombine_s32(_pAA.val[0], _pAA.val[1])); + + int8x16_t _pB01 = vcombine_s8(_pB, _pB); + + _sum = vdotq_s32(_sum, _pA01, _pB01); +#else // __ARM_FEATURE_DOTPROD + int16x4x2_t _pA01 = vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)); + int32x2x2_t _pB01 = vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)); + + int16x8_t _s0 = 
vmull_s8(vreinterpret_s8_s16(_pA01.val[0]), vreinterpret_s8_s32(_pB01.val[0])); + _s0 = vmlal_s8(_s0, vreinterpret_s8_s16(_pA01.val[1]), vreinterpret_s8_s32(_pB01.val[1])); + _sum = vpadalq_s16(_sum, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 8; + pB += 8; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _pA = vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)).val[0]); + _pB = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_pB), vreinterpret_s32_s8(_pB)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum = vpadalq_s16(_sum, _s0); + + // A0 A1 A2 A3 + // B0 B1 B2 B3 + + // A0 A1 A0 A1 A2 A3 A2 A3 + // B0 B1 B2 B3 B0 B1 B2 B3 + + pA += 4; + pB += 4; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x8_t _pB = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(pB)), 0)); + + _pA = vzip_s8(_pA, _pA).val[0]; + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum = vaddw_s16(_sum, vget_low_s16(_s0)); + + // A0 A1 A0 A1 + // B0 B1 B0 B1 + + // A0 A0 A1 A1 + + pA += 2; + pB += 2; + } + + vst1q_s32(outptr, _sum); + + outptr += 4; +#else // __ARM_NEON + int sum00; + int sum10; + int sum01; + int sum11; + + if (k == 0) + { + sum00 = 0; + sum10 = 0; + sum01 = 0; + sum11 = 0; + } + else + { + sum00 = outptr[0]; + sum10 = outptr[1]; + sum01 = outptr[2]; + sum11 = outptr[3]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk + 1 < max_kk; kk += 2) + { + // fomit-frame-pointer implied in optimized flag spare one register + // let us stay away from error: ‘asm’ operand has impossible constraints --- nihui +#if __OPTIMIZE__ + asm volatile( + "ldr r2, [%0], #4 \n" // int8x4_t _pA = *((int8x4_t*)pA); pA += 4; + "ldr r4, [%1], #4 \n" // int8x4_t _pB = *((int8x4_t*)pB); pB += 4; + "ror r3, r2, #8 \n" // int8x4_t _pA_r8 = __ror(_pA, 8); + "ror r5, r4, #8 \n" // int8x4_t _pB_r8 = __ror(_pB, 8); + "sxtb16 r2, r2 \n" // int16x2_t _pA0 = __sxtb16(_pA); + "sxtb16 r4, r4 \n" // int16x2_t _pA1 = __sxtb16(_pA_r8); + "sxtb16 r3, r3 \n" // int16x2_t _pB0 = __sxtb16(_pB); + "sxtb16 r5, r5 \n" // int16x2_t _pB1 = __sxtb16(_pB_r8); + "smlad %2, r2, r4, %2 \n" // sum00 = __smlad(_pA0, _pB0, sum00); + "smlad %3, r3, r4, %3 \n" // sum10 = __smlad(_pA1, _pB0, sum10); + "smlad %4, r2, r5, %4 \n" // sum01 = __smlad(_pA0, _pB1, sum01); + "smlad %5, r3, r5, %5 \n" // sum11 = __smlad(_pA1, _pB1, sum11); + : "=r"(pA), + "=r"(pB), + "=r"(sum00), + "=r"(sum10), + "=r"(sum01), + "=r"(sum11) + : "0"(pA), + "1"(pB), + "2"(sum00), + "3"(sum10), + "4"(sum01), + "5"(sum11) + : "memory", "r2", "r3", "r4", "r5"); +#else + int _pA0 = *((int*)pA); + int _pB0 = *((int*)pB); + int _pA1; + int _pB1; + asm volatile("ror %0, %1, #8" + : "=r"(_pA1) + : "r"(_pA0) + :); + asm volatile("ror %0, %1, #8" + : "=r"(_pB1) + : "r"(_pB0) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pA0) + : "0"(_pA0) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pA1) + : "0"(_pA1) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pB0) + : "0"(_pB0) + :); + asm volatile("sxtb16 %0, %0" + : "=r"(_pB1) + : "0"(_pB1) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum00) + : "0"(sum00), "r"(_pA0), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum10) + : "0"(sum10), "r"(_pA1), "r"(_pB0) + :); + asm volatile("smlad %0, %2, %3, %0" + : "=r"(sum01) + : "0"(sum01), "r"(_pA0), "r"(_pB1) + :); + asm volatile("smlad %0, %2, %3, %0" + : 
"=r"(sum11) + : "0"(sum11), "r"(_pA1), "r"(_pB1) + :); + pA += 4; + pB += 4; +#endif + } +#endif // __ARM_FEATURE_SIMD32 && NCNN_GNU_INLINE_ASM + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum10 += pA[1] * pB[0]; + sum01 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + outptr[0] = sum00; + outptr[1] = sum10; + outptr[2] = sum01; + outptr[3] = sum11; + + outptr += 4; +#endif // __ARM_NEON + } + for (; jj < max_jj; jj += 1) + { +#if __ARM_NEON + int32x2_t _sum; + + if (k == 0) + { + _sum = vdup_n_s32(0); + } + else + { + _sum = vld1_s32(outptr); + } +#else // __ARM_NEON + int sum0; + int sum1; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } +#endif // __ARM_NEON + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum0 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int8x16_t _pBB = vcombine_s8(_pB, _pB); + + _sum0 = vdotq_s32(_sum0, _pA, _pBB); + + pA += 16; + pB += 8; + } + int32x2_t _ss = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#else // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x16_t _pA = vld1q_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + _sum = vdot_lane_s32(_sum, vget_low_s8(_pA), _pB, 0); + _sum = vdot_lane_s32(_sum, vget_high_s8(_pA), _pB, 1); + + pA += 16; + pB += 8; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vreinterpret_s8_s32(vld1_dup_s32((const int*)pB)); + + _sum = vdot_s32(_sum, _pA, _pB); + + pA += 8; + pB += 4; + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum0 = vdupq_n_s32(0); + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pB)), 0)); + + _pB = vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pB), vreinterpret_s16_s8(_pB)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 8; + pB += 4; + } + int32x2_t _ss = vadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#endif // __ARM_FEATURE_DOTPROD + int sum0 = vget_lane_s32(_sum, 0); + int sum1 = vget_lane_s32(_sum, 1); + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[2] * pB[0]; + sum1 += pA[3] * pB[1]; + pA += 4; + pB += 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + outptr[0] = sum0; + outptr[1] = sum1; + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + const signed char* pB = pBT; + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0; + int32x4_t _sum1; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + _sum1 = vld1q_s32(outptr + 4); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); + int32x4_t _sum10 = vdupq_n_s32(0); + int32x4_t _sum11 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + 
int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _pAA = vcombine_s8(_pA, _pA); + _sum00 = vdotq_s32(_sum00, _pAA, _pB0); + _sum01 = vdotq_s32(_sum01, _pAA, _pB1); + _sum10 = vdotq_s32(_sum10, _pAA, _pB2); + _sum11 = vdotq_s32(_sum11, _pAA, _pB3); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = vdotq_lane_s32(_sum0, _pB0, _pA, 0); + _sum1 = vdotq_lane_s32(_sum1, _pB1, _pA, 0); + _sum0 = vdotq_lane_s32(_sum0, _pB2, _pA, 1); + _sum1 = vdotq_lane_s32(_sum1, _pB3, _pA, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 8; + pB += 64; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_sum00, _sum01)); + _sum1 = vaddq_s32(_sum1, vpaddq_s32(_sum10, _sum11)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + int32x4_t _sum4 = vdupq_n_s32(0); + int32x4_t _sum5 = vdupq_n_s32(0); + int32x4_t _sum6 = vdupq_n_s32(0); + int32x4_t _sum7 = vdupq_n_s32(0); + for (; kk + 15 < max_kk; kk += 16) + { + // TODO + // __builtin_prefetch(pA + 16); + // __builtin_prefetch(pB + 128); + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + int8x16_t _pB4 = vld1q_s8(pB + 64); + int8x16_t _pB5 = vld1q_s8(pB + 80); + int8x16_t _pB6 = vld1q_s8(pB + 96); + int8x16_t _pB7 = vld1q_s8(pB + 112); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 3)); + int8x8_t _pA4 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 0)); + int8x8_t _pA5 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 1)); + int8x8_t _pA6 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 2)); + int8x8_t _pA7 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA2, vget_low_s8(_pB2)); + int16x8_t _s3 = vmull_s8(_pA2, vget_high_s8(_pB2)); + int16x8_t _s4 = vmull_s8(_pA4, vget_low_s8(_pB4)); + int16x8_t _s5 = vmull_s8(_pA4, vget_high_s8(_pB4)); + int16x8_t _s6 = vmull_s8(_pA6, vget_low_s8(_pB6)); + int16x8_t _s7 = vmull_s8(_pA6, vget_high_s8(_pB6)); + _s0 = vmlal_s8(_s0, _pA1, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA1, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA3, vget_low_s8(_pB3)); + _s3 = vmlal_s8(_s3, _pA3, vget_high_s8(_pB3)); + _s4 = vmlal_s8(_s4, _pA5, vget_low_s8(_pB5)); + _s5 = vmlal_s8(_s5, _pA5, vget_high_s8(_pB5)); + _s6 = vmlal_s8(_s6, _pA7, vget_low_s8(_pB7)); + _s7 = vmlal_s8(_s7, _pA7, vget_high_s8(_pB7)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + _sum4 = vpadalq_s16(_sum4, _s4); + _sum5 = vpadalq_s16(_sum5, _s5); + _sum6 = vpadalq_s16(_sum6, _s6); + _sum7 = vpadalq_s16(_sum7, _s7); + + pA += 16; + pB += 128; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = 
vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + int16x8_t _s2 = vmull_s8(_pA2, vget_low_s8(_pB2)); + int16x8_t _s3 = vmull_s8(_pA2, vget_high_s8(_pB2)); + _s0 = vmlal_s8(_s0, _pA1, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA1, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA3, vget_low_s8(_pB3)); + _s3 = vmlal_s8(_s3, _pA3, vget_high_s8(_pB3)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + pA += 8; + pB += 64; + } + _sum0 = vaddq_s32(_sum0, _sum2); + _sum1 = vaddq_s32(_sum1, _sum3); + _sum0 = vaddq_s32(_sum0, _sum4); + _sum1 = vaddq_s32(_sum1, _sum5); + _sum0 = vaddq_s32(_sum0, _sum6); + _sum1 = vaddq_s32(_sum1, _sum7); + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pA)), 0)); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_DOTPROD + _sum0 = vdotq_lane_s32(_sum0, _pB0, _pA, 0); + _sum1 = vdotq_lane_s32(_sum1, _pB1, _pA, 0); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA0, vget_high_s8(_pB0)); + _s0 = vmlal_s8(_s0, _pA1, vget_low_s8(_pB1)); + _s1 = vmlal_s8(_s1, _pA1, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 32; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vreinterpret_s8_s16(vld1_dup_s16((const short*)pA)); + int8x16_t _pB = vld1q_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, vget_low_s8(_pB)); + int16x8_t _s1 = vmull_s8(_pA, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 2; + pB += 16; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vld1_dup_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + _sum1 = vaddw_s16(_sum1, vget_high_s16(_s0)); + + pA += 1; + pB += 8; + } + + vst1q_s32(outptr, _sum0); + vst1q_s32(outptr + 4, _sum1); + + outptr += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0; + + if (k == 0) + { + _sum0 = vdupq_n_s32(0); + } + else + { + _sum0 = vld1q_s32(outptr); + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_FEATURE_DOTPROD + { +#if __ARM_FEATURE_MATMUL_INT8 + int32x4_t _sum00 = vdupq_n_s32(0); + int32x4_t _sum01 = vdupq_n_s32(0); +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x16_t _pAA = vcombine_s8(_pA, _pA); + _sum00 = vdotq_s32(_sum00, _pAA, _pB0); + _sum01 = vdotq_s32(_sum01, _pAA, _pB1); +#else // __ARM_FEATURE_MATMUL_INT8 + _sum0 = 
vdotq_lane_s32(_sum0, _pB0, _pA, 0); + _sum0 = vdotq_lane_s32(_sum0, _pB1, _pA, 1); +#endif // __ARM_FEATURE_MATMUL_INT8 + + pA += 8; + pB += 32; + } +#if __ARM_FEATURE_MATMUL_INT8 + _sum0 = vaddq_s32(_sum0, vpaddq_s32(_sum00, _sum01)); +#endif // __ARM_FEATURE_MATMUL_INT8 + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; kk + 15 < max_kk; kk += 16) + { + // TODO + // __builtin_prefetch(pA + 16); + // __builtin_prefetch(pB + 64); + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + int8x16_t _pB2 = vld1q_s8(pB + 32); + int8x16_t _pB3 = vld1q_s8(pB + 48); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_low_s8(_pA)), 3)); + int8x8_t _pA4 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 0)); + int8x8_t _pA5 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 1)); + int8x8_t _pA6 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 2)); + int8x8_t _pA7 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vget_high_s8(_pA)), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA2, vget_low_s8(_pB1)); + int16x8_t _s2 = vmull_s8(_pA4, vget_low_s8(_pB2)); + int16x8_t _s3 = vmull_s8(_pA6, vget_low_s8(_pB3)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB1)); + _s2 = vmlal_s8(_s2, _pA5, vget_high_s8(_pB2)); + _s3 = vmlal_s8(_s3, _pA7, vget_high_s8(_pB3)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + pA += 16; + pB += 64; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int8x8_t _pA2 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 2)); + int8x8_t _pA3 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 3)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA2, vget_low_s8(_pB1)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 8; + pB += 32; + } + _sum0 = vaddq_s32(_sum0, _sum1); + _sum0 = vaddq_s32(_sum0, _sum2); + _sum0 = vaddq_s32(_sum0, _sum3); + } +#endif // __ARM_FEATURE_DOTPROD + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_DOTPROD + _sum0 = vdotq_lane_s32(_sum0, _pB, _pA, 0); +#else // __ARM_FEATURE_DOTPROD + int8x8_t _pA0 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 0)); + int8x8_t _pA1 = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(_pA), 1)); + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 4; + pB += 
16; + } + for (; kk + 1 < max_kk; kk += 2) + { + int8x8_t _pA = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(pA)), 0)); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 2; + pB += 8; + } + for (; kk < max_kk; kk += 1) + { + int8x8_t _pA = vld1_dup_s8(pA); + int8x8_t _pB = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pB)), 0)); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vaddw_s16(_sum0, vget_low_s16(_s0)); + + pA += 1; + pB += 4; + } + + vst1q_s32(outptr, _sum0); + + outptr += 4; + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { +#if __ARM_NEON + int32x2_t _sum; + + if (k == 0) + { + _sum = vdup_n_s32(0); + } + else + { + _sum = vld1_s32(outptr); + } +#else // __ARM_NEON + int sum0; + int sum1; + + if (k == 0) + { + sum0 = 0; + sum1 = 0; + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } +#endif // __ARM_NEON + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + { + int32x4_t _sum0 = vdupq_n_s32(0); + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + int8x16_t _pAA = vcombine_s8(_pA, _pA); + + _sum0 = vdotq_s32(_sum0, _pAA, _pB); + + pA += 8; + pB += 16; + } + int32x2_t _ss = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#else // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + _sum = vdot_lane_s32(_sum, vget_low_s8(_pB), _pA, 0); + _sum = vdot_lane_s32(_sum, vget_high_s8(_pB), _pA, 1); + + pA += 8; + pB += 16; + } +#endif // __ARM_FEATURE_MATMUL_INT8 + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vreinterpret_s8_s32(vld1_dup_s32((const int*)pA)); + int8x8_t _pB = vld1_s8(pB); + + _sum = vdot_s32(_sum, _pA, _pB); + + pA += 4; + pB += 8; + } +#else // __ARM_FEATURE_DOTPROD + { + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + + int16x8x2_t _pAA = vzipq_s16(vreinterpretq_s16_s8(_pA), vreinterpretq_s16_s8(_pA)); + + int8x8_t _pA0 = vreinterpret_s8_s16(vget_low_s16(_pAA.val[0])); + int8x8_t _pA1 = vreinterpret_s8_s16(vget_high_s16(_pAA.val[0])); + int8x8_t _pA2 = vreinterpret_s8_s16(vget_low_s16(_pAA.val[1])); + int8x8_t _pA3 = vreinterpret_s8_s16(vget_high_s16(_pAA.val[1])); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(_pA2, vget_low_s8(_pB1)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, _pA3, vget_high_s8(_pB1)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + + pA += 16; + pB += 32; + } + _sum0 = vaddq_s32(_sum0, _sum1); + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + + int16x4x2_t _pAA = vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)); + + int8x8_t _pA0 = vreinterpret_s8_s16(_pAA.val[0]); + int8x8_t _pA1 = vreinterpret_s8_s16(_pAA.val[1]); + + int16x8_t _s0 = vmull_s8(_pA0, vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, _pA1, vget_high_s8(_pB)); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 8; + pB += 16; + } + for (; kk + 3 < max_kk; kk += 4) + { + int8x8_t _pA = vreinterpret_s8_s32(vdup_lane_s32(vreinterpret_s32_s8(vld1_s8(pA)), 0)); + int8x8_t _pB = vld1_s8(pB); + + _pA = 
vreinterpret_s8_s16(vzip_s16(vreinterpret_s16_s8(_pA), vreinterpret_s16_s8(_pA)).val[0]); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum0 = vpadalq_s16(_sum0, _s0); + + pA += 4; + pB += 8; + } + int32x2_t _ss = vadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + _sum = vadd_s32(_sum, _ss); + } +#endif // __ARM_FEATURE_DOTPROD + int sum0 = vget_lane_s32(_sum, 0); + int sum1 = vget_lane_s32(_sum, 1); + for (; kk + 1 < max_kk; kk += 2) + { + sum0 += pA[0] * pB[0]; + sum0 += pA[1] * pB[1]; + sum1 += pA[0] * pB[2]; + sum1 += pA[1] * pB[3]; + pA += 2; + pB += 4; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + pA += 1; + pB += 2; + } + + outptr[0] = sum0; + outptr[1] = sum1; + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + int sum; + + if (k == 0) + { + sum = 0; + } + else + { + sum = outptr[0]; + } + + const signed char* pA = pAT; + int kk = 0; +#if __ARM_NEON + int32x4_t _sum = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + for (; kk + 31 < max_kk; kk += 32) + { + int8x16_t _pA0 = vld1q_s8(pA); + int8x16_t _pA1 = vld1q_s8(pA + 16); + int8x16_t _pB0 = vld1q_s8(pB); + int8x16_t _pB1 = vld1q_s8(pB + 16); + +#if __ARM_FEATURE_DOTPROD + _sum = vdotq_s32(_sum, _pA0, _pB0); + _sum1 = vdotq_s32(_sum1, _pA1, _pB1); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA0), vget_low_s8(_pB0)); + int16x8_t _s1 = vmull_s8(vget_low_s8(_pA1), vget_low_s8(_pB1)); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA0), vget_high_s8(_pB0)); + _s1 = vmlal_s8(_s1, vget_high_s8(_pA1), vget_high_s8(_pB1)); + _sum = vpadalq_s16(_sum, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); +#endif // __ARM_FEATURE_DOTPROD + + pA += 32; + pB += 32; + } + _sum = vaddq_s32(_sum, _sum1); + for (; kk + 15 < max_kk; kk += 16) + { + int8x16_t _pA = vld1q_s8(pA); + int8x16_t _pB = vld1q_s8(pB); + +#if __ARM_FEATURE_DOTPROD + _sum = vdotq_s32(_sum, _pA, _pB); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _s0 = vmull_s8(vget_low_s8(_pA), vget_low_s8(_pB)); + _s0 = vmlal_s8(_s0, vget_high_s8(_pA), vget_high_s8(_pB)); + _sum = vpadalq_s16(_sum, _s0); +#endif // __ARM_FEATURE_DOTPROD + + pA += 16; + pB += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + int8x8_t _pA = vld1_s8(pA); + int8x8_t _pB = vld1_s8(pB); + + int16x8_t _s0 = vmull_s8(_pA, _pB); + _sum = vpadalq_s16(_sum, _s0); + + pA += 8; + pB += 8; + } +#if __aarch64__ + sum += vaddvq_s32(_sum); +#else + int32x2_t _ss = vadd_s32(vget_low_s32(_sum), vget_high_s32(_sum)); + _ss = vpadd_s32(_ss, _ss); + sum += vget_lane_s32(_ss, 0); +#endif +#endif // __ARM_NEON + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + pA += 1; + pB += 1; + } + + outptr[0] = sum; + + outptr += 1; + } + + pAT += max_kk; + } +} + +static void get_optimal_tile_mnk_int8(int M, int N, int K, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) +{ + // resolve optimal tile size from cache size + const size_t l2_cache_size = get_cpu_level2_cache_size(); + + if (nT == 0) + nT = get_physical_big_cpu_count(); + + int tile_size = (int)sqrtf((float)l2_cache_size / (2 * sizeof(signed char) + sizeof(int))); + + TILE_M = std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + TILE_K = std::max(8, tile_size / 8 * 8); + + if (K > 0) + { + int nn_K = (K + TILE_K - 1) / TILE_K; + TILE_K = std::min(TILE_K, ((K + nn_K - 1) / nn_K + 7) / 8 * 8); + + if (nn_K == 1) + { + tile_size = (int)((float)l2_cache_size / 2 / sizeof(signed char) / TILE_K); + + TILE_M = 
std::max(8, tile_size / 8 * 8); + TILE_N = std::max(4, tile_size / 4 * 4); + } + } + + TILE_M *= std::min(nT, get_physical_cpu_count()); + + if (M > 0) + { + int nn_M = (M + TILE_M - 1) / TILE_M; + TILE_M = std::min(TILE_M, ((M + nn_M - 1) / nn_M + 7) / 8 * 8); + } + + if (N > 0) + { + int nn_N = (N + TILE_N - 1) / TILE_N; + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); + } + + if (nT > 1) + { + TILE_M = std::min(TILE_M, (std::max(1, TILE_M / nT) + 7) / 8 * 8); + } + + // always take constant TILE_M/N/K value when provided + if (constant_TILE_M > 0) + { + TILE_M = (constant_TILE_M + 7) / 8 * 8; + } + + if (constant_TILE_N > 0) + { + TILE_N = (constant_TILE_N + 3) / 4 * 4; + } + + if (constant_TILE_K > 0) + { + TILE_K = (constant_TILE_K + 7) / 8 * 8; + } +} diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h new file mode 100644 index 000000000000..dac2a901a5d1 --- /dev/null +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -0,0 +1,9903 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_bf16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_bf16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_bf16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_bf16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 +void pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void transpose_pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); +void pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); +void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales); +void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales); +#endif + +static void compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + const int K = A.w; + + // NCNN_LOGE("compute_A_tile_bf16_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif + + for (int ii = 0; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); + + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + +#endif + ps += 4; + pods += 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + for (int ii = 0; ii < max_ii; ii++) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; + + float absmax = 0.f; + int kk = 0; +#if __ARM_NEON + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (; kk + 15 < K; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, 
vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 7 < K; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); +#endif // __ARM_NEON + for (; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(p0[0]))); + p0++; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_A_tile_bf16_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_A_tile_bf16_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + // NCNN_LOGE("pack_A_tile_bf16_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + uint16x8x4_t _q = vld4q_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale0, 0); + float32x4_t _p5 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale0, 1); + float32x4_t _p6 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale0, 2); + float32x4_t _p7 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale0, 3); + float32x4_t _p8 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[0])), _scale1, 0); + float32x4_t _p9 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[1])), _scale1, 1); + float32x4_t _pa = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[2])), _scale1, 2); + float32x4_t _pb = vmulq_laneq_f32(bfloat2float(vget_low_u16(_q.val[3])), _scale1, 3); + float32x4_t _pc = vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[0])), _scale1, 0); + float32x4_t _pd = 
vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[1])), _scale1, 1); + float32x4_t _pe = vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[2])), _scale1, 2); + float32x4_t _pf = vmulq_laneq_f32(bfloat2float(vget_high_u16(_q.val[3])), _scale1, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale0); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale0); + _p8 = vmulq_f32(_p8, _scale1); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale1); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale1); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale1); + _pf = vmulq_f32(_pf, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + uint16x4x4_t _q = vld4_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = 
vmulq_laneq_f32(bfloat2float(_p.val[0]), _scale0, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(_p.val[1]), _scale0, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(_p.val[2]), _scale0, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(_p.val[3]), _scale0, 3); + float32x4_t _p4 = vmulq_laneq_f32(bfloat2float(_q.val[0]), _scale1, 0); + float32x4_t _p5 = vmulq_laneq_f32(bfloat2float(_q.val[1]), _scale1, 1); + float32x4_t _p6 = vmulq_laneq_f32(bfloat2float(_q.val[2]), _scale1, 2); + float32x4_t _p7 = vmulq_laneq_f32(bfloat2float(_q.val[3]), _scale1, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale0); + _p4 = vmulq_f32(_p4, _scale1); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale1); + _p7 = vmulq_f32(_p7, _scale1); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep * 4); + + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p0n = bfloat2float(vget_high_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p1n = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p0n = vmulq_f32(_p0n, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p1n = vmulq_f32(_p1n, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p0n, _p1n); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = 
bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 0); + _p2 = vmulq_laneq_f32(_p2, _scale0, 1); + _p3 = vmulq_laneq_f32(_p3, _scale0, 1); + _p4 = vmulq_laneq_f32(_p4, _scale0, 2); + _p5 = vmulq_laneq_f32(_p5, _scale0, 2); + _p6 = vmulq_laneq_f32(_p6, _scale0, 3); + _p7 = vmulq_laneq_f32(_p7, _scale0, 3); + _p8 = vmulq_laneq_f32(_p8, _scale1, 0); + _p9 = vmulq_laneq_f32(_p9, _scale1, 0); + _pa = vmulq_laneq_f32(_pa, _scale1, 1); + _pb = vmulq_laneq_f32(_pb, _scale1, 1); + _pc = vmulq_laneq_f32(_pc, _scale1, 2); + _pd = vmulq_laneq_f32(_pd, _scale1, 2); + _pe = vmulq_laneq_f32(_pe, _scale1, 3); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = 
bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + A_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + A_hstep * 6)); + float32x4_t _p7 = bfloat2float(vld1_u16(p0 + A_hstep * 7)); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + float32x4x2_t _scale01 = vzipq_f32(_scale0, _scale0); + float32x4x2_t _scale23 = vzipq_f32(_scale1, _scale1); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + _p45 = vmulq_f32(_p45, _scale23.val[0]); + _p67 = vmulq_f32(_p67, _scale23.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + float32x4_t _p0 = 
bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + + float32x4_t _p0 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale, 3); + float32x4_t _p4 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale, 0); + float32x4_t _p5 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale, 1); + float32x4_t _p6 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale, 2); + float32x4_t _p7 = vmulq_laneq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale, 3); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + + float32x4_t _p0 = vmulq_laneq_f32(bfloat2float(_p.val[0]), _scale, 0); + float32x4_t _p1 = vmulq_laneq_f32(bfloat2float(_p.val[1]), _scale, 1); + float32x4_t _p2 = vmulq_laneq_f32(bfloat2float(_p.val[2]), _scale, 2); + float32x4_t _p3 = vmulq_laneq_f32(bfloat2float(_p.val[3]), _scale, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = 
bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 0); + _p2 = vmulq_laneq_f32(_p2, _scale, 1); + _p3 = vmulq_laneq_f32(_p3, _scale, 1); + _p4 = vmulq_laneq_f32(_p4, _scale, 2); + _p5 = vmulq_laneq_f32(_p5, _scale, 2); + _p6 = vmulq_laneq_f32(_p6, _scale, 3); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = 
bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + float32x4x2_t _scale01 = vzipq_f32(_scale, _scale); + + _p01 = vmulq_f32(_p01, _scale01.val[0]); + _p23 = vmulq_f32(_p23, _scale01.val[1]); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[A_hstep], _p, 1); + _p = vset_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[A_hstep * 3], _p, 3); + float32x4_t _p0 = bfloat2float(_p); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale0); + _p2 = vmulq_f32(_p2, _scale1); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = 
float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale0); + pp[2] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); + pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale0); + // pp[2] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); + // pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); + // pp += 4; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep + k; + + const float scale = scales[ii]; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + // pp += 2; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0++; + } + } + } +} + +static void transpose_compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) +{ + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const int K = A.dims == 3 ? 
A.c : A.h; + + // NCNN_LOGE("transpose_compute_A_tile_bf16_int8_scales %d %d", max_ii, elempack); + + const float v127_B_scale = 127.f * B_scale; + +#if __ARM_NEON +#if __aarch64__ + float32x4_t _v127 = vdupq_n_f32(127.f); + float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); +#endif +#endif + + float* ps = scales; + float* pods = out_descales; + +#if __ARM_NEON + if (elempack == 4) + { + int ii = 0; + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + for (int kk = 0; kk < K; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); + float32x2_t _aa2 = vmax_f32(vget_low_f32(_absmax2), vget_high_f32(_absmax2)); + float32x2_t _aa3 = vmax_f32(vget_low_f32(_absmax3), vget_high_f32(_absmax3)); + float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); + float32x2_t _aa23 = vpmax_f32(_aa2, _aa3); + float32x4_t _absmax = vcombine_f32(_aa01, _aa23); + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax); + float32x4_t _out_descale = vdivq_f32(_absmax, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } + for (; ii < max_ii; ii++) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 8)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 12)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) 
+ { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep * 4; + } + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int ii = 0; +#if __ARM_NEON + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii); + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + p0 += A_hstep * 4; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 2; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk < K; kk++) + { + float32x4_t _p = bfloat2float(vld1_u16(p0)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); + p0 += A_hstep; + } + +#if __aarch64__ + float32x4_t _scale = vdivq_f32(_v127, _absmax0); + float32x4_t _out_descale = vdivq_f32(_absmax0, _v127_B_scale); + + vst1q_f32(ps, _scale); + vst1q_f32(pods, _out_descale); +#else + float tmp[4]; + vst1q_f32(tmp, _absmax0); + + ps[0] = 127.f / tmp[0]; + ps[1] = 127.f / tmp[1]; + ps[2] = 127.f / tmp[2]; + ps[3] = 127.f / tmp[3]; + + pods[0] = tmp[0] / v127_B_scale; + pods[1] = tmp[1] / v127_B_scale; + pods[2] = tmp[2] / v127_B_scale; + pods[3] = tmp[3] / v127_B_scale; + + // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); + // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); + // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); +#endif + + ps += 4; + pods += 4; + } +#endif // __ARM_NEON + for (; ii < max_ii; ii++) + { + const unsigned short* p0 = (const unsigned short*)A + (i + ii); + + float absmax = 0.f; + for (int kk = 0; kk < K; kk++) + { + absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(p0[0]))); + p0 += A_hstep; + } + + ps[0] = 127.f / absmax; + pods[0] = absmax / v127_B_scale; + ps++; + pods++; + } + } +} + +static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales) +{ +#if 
NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_A_tile_bf16_to_int8_i8mm(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_A_tile_bf16_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales); + return; + } +#endif + + const int elempack = A.elempack; + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + + // NCNN_LOGE("transpose_pack_A_tile_bf16_to_int8 %d %d", max_ii, elempack); + + signed char* pp = AT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); + float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + _p8 = vmulq_laneq_f32(_p8, _scale0, 0); + _p9 = vmulq_laneq_f32(_p9, _scale0, 1); + _pa = vmulq_laneq_f32(_pa, _scale0, 2); + _pb = vmulq_laneq_f32(_pb, _scale0, 3); + _pc = vmulq_laneq_f32(_pc, _scale1, 0); + _pd = vmulq_laneq_f32(_pd, _scale1, 1); + _pe = vmulq_laneq_f32(_pe, _scale1, 2); + _pf = vmulq_laneq_f32(_pf, _scale1, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = 
float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_laneq_f32(_p0, _scale0, 0); + _p1 = vmulq_laneq_f32(_p1, _scale0, 1); + _p2 = vmulq_laneq_f32(_p2, _scale0, 2); + _p3 = vmulq_laneq_f32(_p3, _scale0, 3); + _p4 = vmulq_laneq_f32(_p4, _scale1, 0); + _p5 = vmulq_laneq_f32(_p5, _scale1, 1); + _p6 = vmulq_laneq_f32(_p6, _scale1, 2); + _p7 = vmulq_laneq_f32(_p7, _scale1, 3); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + A_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + A_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + A_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 
= bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + _p8 = vmulq_f32(_p8, _scale0); + _p9 = vmulq_f32(_p9, _scale1); + _pa = vmulq_f32(_pa, _scale0); + _pb = vmulq_f32(_pb, _scale1); + _pc = vmulq_f32(_pc, _scale0); + _pd = vmulq_f32(_pd, _scale1); + _pe = vmulq_f32(_pe, _scale0); + _pf = vmulq_f32(_pf, _scale1); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = 
vmulq_f32(_p3, _scale1); + _p4 = vmulq_f32(_p4, _scale0); + _p5 = vmulq_f32(_p5, _scale1); + _p6 = vmulq_f32(_p6, _scale0); + _p7 = vmulq_f32(_p7, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep); + + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + float32x4_t _scale = vld1q_f32((const float*)scales + ii); + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + A_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + A_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + _p4 = vmulq_laneq_f32(_p4, _scale, 0); + _p5 = vmulq_laneq_f32(_p5, _scale, 1); + _p6 = vmulq_laneq_f32(_p6, _scale, 2); + _p7 = vmulq_laneq_f32(_p7, _scale, 3); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = 
vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_laneq_f32(_p0, _scale, 0); + _p1 = vmulq_laneq_f32(_p1, _scale, 1); + _p2 = vmulq_laneq_f32(_p2, _scale, 2); + _p3 = vmulq_laneq_f32(_p3, _scale, 3); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + A_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + A_hstep * 6)); + float32x4_t _p7 = bfloat2float(vld1_u16(p0 + A_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + float32x4_t _p2 = 
bfloat2float(vld1_u16(p0 + A_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 2; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + pp += 4; + p0 += A_hstep; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + const float scale0 = scales[ii]; + const float scale1 = scales[ii + 1]; + +#if __ARM_NEON + float32x4_t _scale0 = vdupq_n_f32(scale0); + float32x4_t _scale1 = vdupq_n_f32(scale1); + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + A_hstep * 4); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + _p2 = vmulq_f32(_p2, _scale0); + _p3 = vmulq_f32(_p3, _scale1); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale0); + _p1 = vmulq_f32(_p1, _scale1); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 
1) + { + int kk = 0; +#if __ARM_NEON + float32x4_t _scale = vzipq_f32(_scale0, _scale1).val[0]; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 4 + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 6 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep + 1], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 3], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 3 + 1], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 5], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 5 + 1], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 7 + 1], _q, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p46 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p57 = bfloat2float(vget_high_u16(_q)); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 5); + _p = 
vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep + 1], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 3 + 1], _p, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_high_u16(_p)); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep + 0]) * scale0); + pp[2] = float2int8(bfloat16_to_float32(p0[1]) * scale1); + pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); + pp += 4; + p0 += A_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale1); + pp += 2; + p0 += A_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + const unsigned short* p0 = (const unsigned short*)A + k * A_hstep + (i + ii) * elempack; + + const float scale = scales[ii]; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 8)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 12)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + A_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[2]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[3]) * scale); + pp += 4; + p0 += A_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = 
vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[A_hstep * 8], _q, 0); + _q = vsetq_lane_u16(p0[A_hstep * 9], _q, 1); + _q = vsetq_lane_u16(p0[A_hstep * 10], _q, 2); + _q = vsetq_lane_u16(p0[A_hstep * 11], _q, 3); + _q = vsetq_lane_u16(p0[A_hstep * 12], _q, 4); + _q = vsetq_lane_u16(p0[A_hstep * 13], _q, 5); + _q = vsetq_lane_u16(p0[A_hstep * 14], _q, 6); + _q = vsetq_lane_u16(p0[A_hstep * 15], _q, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += A_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += A_hstep * 8; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + // pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale); + // pp += 2; + // p0 += A_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0 += A_hstep; + } + } + } +} + +static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + pack_B_tile_bf16_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + pack_B_tile_bf16_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? 
(int)B.cstep : B.w; + + // NCNN_LOGE("pack_B_tile_bf16_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + uint16x8x4_t _q = vld4q_u16(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale); + float32x4_t _p2 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale); + float32x4_t _p4 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale); + float32x4_t _p5 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale); + float32x4_t _p6 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale); + float32x4_t _p7 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale); + float32x4_t _p8 = vmulq_f32(bfloat2float(vget_low_u16(_q.val[0])), _scale); + float32x4_t _p9 = vmulq_f32(bfloat2float(vget_low_u16(_q.val[1])), _scale); + float32x4_t _pa = vmulq_f32(bfloat2float(vget_low_u16(_q.val[2])), _scale); + float32x4_t _pb = vmulq_f32(bfloat2float(vget_low_u16(_q.val[3])), _scale); + float32x4_t _pc = vmulq_f32(bfloat2float(vget_high_u16(_q.val[0])), _scale); + float32x4_t _pd = vmulq_f32(bfloat2float(vget_high_u16(_q.val[1])), _scale); + float32x4_t _pe = vmulq_f32(bfloat2float(vget_high_u16(_q.val[2])), _scale); + float32x4_t _pf = vmulq_f32(bfloat2float(vget_high_u16(_q.val[3])), _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r4 = float2int8(_p8, _pc); + int8x8_t _r5 = float2int8(_p9, _pd); + int8x8_t _r6 = float2int8(_pa, _pe); + int8x8_t _r7 = float2int8(_pb, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p8, _p9); + int8x8_t _r3 = float2int8(_pa, _pb); + int8x8_t _r4 = float2int8(_p4, _p5); + int8x8_t _r5 = float2int8(_p6, _p7); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = 
bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p8), float2int8(_p2, _pa)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p9), float2int8(_p3, _pb)); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(float2int8(_p4, _pc), float2int8(_p6, _pe)); + _r23.val[1] = vcombine_s8(float2int8(_p5, _pd), float2int8(_p7, _pf)); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + uint16x4x4_t _q = vld4_u16(p0 + B_hstep * 4); + + float32x4_t _p0 = vmulq_f32(bfloat2float(_p.val[0]), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(_p.val[1]), _scale); + float32x4_t _p2 = vmulq_f32(bfloat2float(_p.val[2]), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(_p.val[3]), _scale); + float32x4_t _p4 = vmulq_f32(bfloat2float(_q.val[0]), _scale); + float32x4_t _p5 = vmulq_f32(bfloat2float(_q.val[1]), _scale); + float32x4_t _p6 = vmulq_f32(bfloat2float(_q.val[2]), _scale); + float32x4_t _p7 = vmulq_f32(bfloat2float(_q.val[3]), _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p4), float2int8(_p2, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p5), float2int8(_p3, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep * 4); + float32x4_t _p0 = 
bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); + int16x4_t _t3 = 
vreinterpret_s16_s8(float2int8(_pc, _pe)); + int16x4_t _t4 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t5 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4_t _t6 = vreinterpret_s16_s8(float2int8(_p9, _pb)); + int16x4_t _t7 = vreinterpret_s16_s8(float2int8(_pd, _pf)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int16x4x2_t _t45 = vuzp_s16(_t4, _t5); + int16x4x2_t _t67 = vuzp_s16(_t6, _t7); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); + int8x8_t _r4 = vreinterpret_s8_s16(_t45.val[0]); + int8x8_t _r5 = vreinterpret_s8_s16(_t67.val[0]); + int8x8_t _r6 = vreinterpret_s8_s16(_t45.val[1]); + int8x8_t _r7 = vreinterpret_s8_s16(_t67.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); + + pp += 64; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + B_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + B_hstep * 6)); + float32x4_t _p7 = bfloat2float(vld1_u16(p0 + B_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r2 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 4], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 
7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0++; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8x4_t _p = vld4q_u16(p0); + + float32x4_t _p0 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[0])), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[1])), _scale); + float32x4_t _p2 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[2])), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(vget_low_u16(_p.val[3])), _scale); + float32x4_t _p4 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[0])), _scale); + float32x4_t _p5 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[1])), _scale); + float32x4_t _p6 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[2])), _scale); + float32x4_t _p7 = vmulq_f32(bfloat2float(vget_high_u16(_p.val[3])), _scale); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + 
int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += 32; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x4x4_t _p = vld4_u16(p0); + + float32x4_t _p0 = vmulq_f32(bfloat2float(_p.val[0]), _scale); + float32x4_t _p1 = vmulq_f32(bfloat2float(_p.val[1]), _scale); + float32x4_t _p2 = vmulq_f32(bfloat2float(_p.val[2]), _scale); + float32x4_t _p3 = vmulq_f32(bfloat2float(_p.val[3]), _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += 16; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = 
vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[B_hstep], _p, 1); + _p = vset_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[B_hstep * 3], _p, 3); + float32x4_t _p0 = bfloat2float(_p); + + _p0 = vmulq_f32(_p0, _scale); + int8x8_t _r0 = float2int8(_p0, _p0); + + pp[0] = vget_lane_s8(_r0, 0); + pp[1] = vget_lane_s8(_r0, 1); + pp[2] = vget_lane_s8(_r0, 2); + pp[3] = vget_lane_s8(_r0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t 
_r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); + float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); + float32x4_t _t3 = vcombine_f32(vget_high_f32(_p1), vget_high_f32(_p3)); + int8x8_t _r0 = float2int8(_t0, _t1); + int8x8_t _r1 = float2int8(_t2, _t3); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + vst1_s8(pp + 8, _r1); + + pp += 16; + p0 += 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r0 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + // pp[2] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); + // pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); + // pp += 4; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); + pp += 2; + p0++; + } + } + } + for (; jj < max_jj; jj += 1) + { + const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k; + + // if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += 8; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + // pp += 2; + // p0 += 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0++; + } + } + } +} + +static void 
transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_i8mm()) + { + transpose_pack_B_tile_bf16_to_int8_i8mm(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_pack_B_tile_bf16_to_int8_asimddp(B, BT, j, max_jj, k, max_kk, scale); + return; + } +#endif + + const int elempack = B.elempack; + const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + + // NCNN_LOGE("transpose_pack_B_tile_bf16_to_int8 %d %d", max_jj, elempack); + + signed char* pp = BT; + +#if __ARM_NEON + float32x4_t _scale = vdupq_n_f32(scale); +#endif + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 4 + 8); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 4 + 16); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 4 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p8); + int8x8_t _r1 = float2int8(_p1, _p9); + int8x8_t _r2 = float2int8(_p2, _pa); + int8x8_t _r3 = float2int8(_p3, _pb); + int8x8_t _r4 = float2int8(_p4, _pc); + int8x8_t _r5 = float2int8(_p5, _pd); + int8x8_t _r6 = float2int8(_p6, _pe); + int8x8_t _r7 = float2int8(_p7, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = 
float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); + vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8_t _r45 = vreinterpretq_s16_s8(vcombine_s8(_r4, _r5)); + int16x8_t _r67 = vreinterpretq_s16_s8(vcombine_s8(_r6, _r7)); + int16x8x2_t _rr0 = vuzpq_s16(_r01, _r23); + int16x8x2_t _rr1 = vuzpq_s16(_r45, _r67); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr0.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr0.val[1])); + vst1q_s8(pp + 32, vreinterpretq_s8_s16(_rr1.val[0])); + vst1q_s8(pp + 48, vreinterpretq_s8_s16(_rr1.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + 16); + uint16x8_t _s = vld1q_u16(p0 + 24); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + +#if __ARM_FEATURE_DOTPROD + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); +#else // __ARM_FEATURE_DOTPROD + int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); + int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); + int16x8x2_t _rr = vuzpq_s16(_r01, _r23); + + vst1q_s8(pp, vreinterpretq_s8_s16(_rr.val[0])); + vst1q_s8(pp + 16, vreinterpretq_s8_s16(_rr.val[1])); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + uint16x8_t _t = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _u = vld1q_u16(p0 + B_hstep * 5); + uint16x8_t _v = vld1q_u16(p0 + B_hstep * 6); + uint16x8_t _w = vld1q_u16(p0 + B_hstep * 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 
= bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + float32x4_t _p8 = bfloat2float(vget_low_u16(_t)); + float32x4_t _p9 = bfloat2float(vget_high_u16(_t)); + float32x4_t _pa = bfloat2float(vget_low_u16(_u)); + float32x4_t _pb = bfloat2float(vget_high_u16(_u)); + float32x4_t _pc = bfloat2float(vget_low_u16(_v)); + float32x4_t _pd = bfloat2float(vget_high_u16(_v)); + float32x4_t _pe = bfloat2float(vget_low_u16(_w)); + float32x4_t _pf = bfloat2float(vget_high_u16(_w)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + _p8 = vmulq_f32(_p8, _scale); + _p9 = vmulq_f32(_p9, _scale); + _pa = vmulq_f32(_pa, _scale); + _pb = vmulq_f32(_pb, _scale); + _pc = vmulq_f32(_pc, _scale); + _pd = vmulq_f32(_pd, _scale); + _pe = vmulq_f32(_pe, _scale); + _pf = vmulq_f32(_pf, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); + int8x8_t _r4 = float2int8(_p8, _p9); + int8x8_t _r5 = float2int8(_pa, _pb); + int8x8_t _r6 = float2int8(_pc, _pd); + int8x8_t _r7 = float2int8(_pe, _pf); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r04 = vzip_s8(_r0, _r4); + int8x8x2_t _r15 = vzip_s8(_r1, _r5); + int8x8x2_t _r26 = vzip_s8(_r2, _r6); + int8x8x2_t _r37 = vzip_s8(_r3, _r7); + int8x16x4_t _r0123; + _r0123.val[0] = vcombine_s8(_r04.val[0], _r04.val[1]); + _r0123.val[1] = vcombine_s8(_r15.val[0], _r15.val[1]); + _r0123.val[2] = vcombine_s8(_r26.val[0], _r26.val[1]); + _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); + + vst4q_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = _r0; + _r0123.val[1] = _r1; + _r0123.val[2] = _r2; + _r0123.val[3] = _r3; + int8x8x4_t _r4567; + _r4567.val[0] = _r4; + _r4567.val[1] = _r5; + _r4567.val[2] = _r6; + _r4567.val[3] = _r7; + + vst4_s8(pp, _r0123); + vst4_s8(pp + 32, _r4567); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(_r0, _r2); + _r01.val[1] = vcombine_s8(_r1, _r3); + int8x16x2_t _r23; + _r23.val[0] = vcombine_s8(_r4, _r6); + _r23.val[1] = vcombine_s8(_r5, _r7); + + vst2q_s8(pp, _r01); + vst2q_s8(pp + 32, _r23); +#endif // __ARM_FEATURE_DOTPROD + + pp += 64; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 2); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 3); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD + 
int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p1); + _r0123.val[1] = float2int8(_p2, _p3); + _r0123.val[2] = float2int8(_p4, _p5); + _r0123.val[3] = float2int8(_p6, _p7); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); + _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p1); + _r01.val[1] = float2int8(_p2, _p3); + + vst2_s8(pp, _r01); + + pp += 16; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r0 = float2int8(_p0, _p1); + + vst1_s8(pp, _r0); + + pp += 8; + p0 += B_hstep; + } + } + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + uint16x8_t _r = vld1q_u16(p0 + B_hstep * 4); + uint16x8_t _s = vld1q_u16(p0 + B_hstep * 4 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + float32x4_t _p4 = bfloat2float(vget_low_u16(_r)); + float32x4_t _p5 = bfloat2float(vget_high_u16(_r)); + float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); + float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p4); + int8x8_t _r1 = float2int8(_p1, _p5); + int8x8_t _r2 = float2int8(_p2, _p6); + int8x8_t _r3 = float2int8(_p3, _p7); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); + int8x8_t _r2 = float2int8(_p4, _p5); + int8x8_t _r3 = float2int8(_p6, _p7); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p6, _p7)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int16x4x2_t _t23 = vuzp_s16(_t2, _t3); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); + int8x8_t _r2 = vreinterpret_s8_s16(_t23.val[0]); + int8x8_t _r3 = vreinterpret_s8_s16(_t23.val[1]); +#endif // 
__ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + 8); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); + int16x4x2_t _t01 = vuzp_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 4; + } + } + if (elempack == 1) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + float32x4_t _p4 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p5 = bfloat2float(vld1_u16(p0 + B_hstep * 5)); + float32x4_t _p6 = bfloat2float(vld1_u16(p0 + B_hstep * 6)); + float32x4_t _p7 = bfloat2float(vld1_u16(p0 + B_hstep * 7)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + _p4 = vmulq_f32(_p4, _scale); + _p5 = vmulq_f32(_p5, _scale); + _p6 = vmulq_f32(_p6, _scale); + _p7 = vmulq_f32(_p7, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + float32x4x2_t _p04 = vzipq_f32(_p0, _p4); + float32x4x2_t _p15 = vzipq_f32(_p1, _p5); + float32x4x2_t _p26 = vzipq_f32(_p2, _p6); + float32x4x2_t _p37 = vzipq_f32(_p3, _p7); + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p04.val[0], _p04.val[1]); + _r0123.val[1] = float2int8(_p15.val[0], _p15.val[1]); + _r0123.val[2] = float2int8(_p26.val[0], _p26.val[1]); + _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); + + vst4_s8(pp, _r0123); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x4_t _r0123; + _r0123.val[0] = float2int8(_p0, _p4); + _r0123.val[1] = float2int8(_p1, _p5); + _r0123.val[2] = float2int8(_p2, _p6); + _r0123.val[3] = float2int8(_p3, _p7); + + vst4_s8(pp, _r0123); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int8x16x2_t _r01; + _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); + _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); + + vst2q_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 32; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 2)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 3)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD + transpose4x4_ps(_p0, _p1, _p2, _p3); + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, 
_p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); +#else // __ARM_FEATURE_DOTPROD + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p0, _p2); + _r01.val[1] = float2int8(_p1, _p3); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + float32x4x2_t _p01 = vzipq_f32(_p0, _p1); + int8x8_t _r01 = float2int8(_p01.val[0], _p01.val[1]); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 2; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[2]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[3]) * scale); + pp += 4; + p0 += B_hstep; + } + } + } +#endif // __ARM_NEON + for (; jj + 1 < max_jj; jj += 2) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = vld1q_u16(p0); + uint16x8_t _q = vld1q_u16(p0 + B_hstep * 4); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + +#if __ARM_FEATURE_DOTPROD +#if __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p1, _p3); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8_t _r0 = float2int8(_p0, _p1); + int8x8_t _r1 = float2int8(_p2, _p3); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4x2_t _t01 = vzip_s16(_t0, _t1); + int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); + int8x8_t _r1 = vreinterpret_s8_s16(_t01.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1q_s8(pp, vcombine_s8(_r0, _r1)); + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + uint16x8_t _p = vld1q_u16(p0); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + +#if __ARM_FEATURE_DOTPROD + int8x8_t _r01 = float2int8(_p0, _p1); +#else // __ARM_FEATURE_DOTPROD + float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); + float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); + int8x8_t _r01 = float2int8(_t0, _t1); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 7 < max_kk; kk += 8) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 4], _q, 
0); + _q = vsetq_lane_u16(p0[B_hstep * 4 + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 6], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 6 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p45 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p67 = bfloat2float(vget_high_u16(_q)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + _p45 = vmulq_f32(_p45, _scale); + _p67 = vmulq_f32(_p67, _scale); + + int8x8_t _r0 = float2int8(_p01, _p23); + int8x8_t _r1 = float2int8(_p45, _p67); + +#if __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vuzp_s8(_r0, _r1); + + vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); +#else // __ARM_FEATURE_MATMUL_INT8 + int8x8x2_t _r01 = vtrn_s8(_r0, _r1); + int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); + + vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); +#endif // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 4 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 6 + 1], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep + 1], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 3], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 3 + 1], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 5], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 5 + 1], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 7], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 7 + 1], _q, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p46 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p57 = bfloat2float(vget_high_u16(_q)); + + _p02 = vmulq_f32(_p02, _scale); + _p46 = vmulq_f32(_p46, _scale); + _p13 = vmulq_f32(_p13, _scale); + _p57 = vmulq_f32(_p57, _scale); + + int8x8x2_t _r01; + _r01.val[0] = float2int8(_p02, _p46); + _r01.val[1] = float2int8(_p13, _p57); + + vst2_s8(pp, _r01); +#endif // __ARM_FEATURE_DOTPROD + + pp += 16; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { +#if __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p01 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p23 = bfloat2float(vget_high_u16(_p)); + + _p01 = vmulq_f32(_p01, _scale); + _p23 = vmulq_f32(_p23, _scale); + + float32x4x2_t _pp = vuzpq_f32(_p01, _p23); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#else // __ARM_FEATURE_DOTPROD + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[1], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 2 + 1], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep], _p, 4); + 
_p = vsetq_lane_u16(p0[B_hstep + 1], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 3 + 1], _p, 7); + float32x4_t _p02 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p13 = bfloat2float(vget_high_u16(_p)); + + _p02 = vmulq_f32(_p02, _scale); + _p13 = vmulq_f32(_p13, _scale); + + float32x4x2_t _pp = vzipq_f32(_p02, _p13); + int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); +#endif // __ARM_FEATURE_DOTPROD + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 4; + } + for (; kk + 1 < max_kk; kk += 2) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep + 0]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); + pp += 4; + p0 += B_hstep * 2; + } +#endif // __ARM_NEON + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp += 2; + p0 += B_hstep; + } + } + } + for (; jj < max_jj; jj += 1) + { + const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; + +#if __ARM_NEON + if (elempack == 4) + { + int kk = 0; + for (; kk + 15 < max_kk; kk += 16) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + float32x4_t _p2 = bfloat2float(vld1_u16(p0 + B_hstep * 8)); + float32x4_t _p3 = bfloat2float(vld1_u16(p0 + B_hstep * 12)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + float32x4_t _p0 = bfloat2float(vld1_u16(p0)); + float32x4_t _p1 = bfloat2float(vld1_u16(p0 + B_hstep * 4)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } + for (; kk + 3 < max_kk; kk += 4) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); + pp[2] = float2int8(bfloat16_to_float32(p0[2]) * scale); + pp[3] = float2int8(bfloat16_to_float32(p0[3]) * scale); + pp += 4; + p0 += B_hstep * 4; + } + } +#endif // __ARM_NEON + if (elempack == 1) + { + int kk = 0; +#if __ARM_NEON + for (; kk + 15 < max_kk; kk += 16) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + uint16x8_t _q = uint16x8_t(); + _q = vsetq_lane_u16(p0[B_hstep * 8], _q, 0); + _q = vsetq_lane_u16(p0[B_hstep * 9], _q, 1); + _q = vsetq_lane_u16(p0[B_hstep * 10], _q, 2); + _q = vsetq_lane_u16(p0[B_hstep * 11], _q, 3); + _q = vsetq_lane_u16(p0[B_hstep * 12], _q, 4); + _q = vsetq_lane_u16(p0[B_hstep * 13], _q, 5); + _q = vsetq_lane_u16(p0[B_hstep * 14], _q, 6); + _q = vsetq_lane_u16(p0[B_hstep * 15], _q, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); + float32x4_t _p3 = 
bfloat2float(vget_high_u16(_q)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + _p2 = vmulq_f32(_p2, _scale); + _p3 = vmulq_f32(_p3, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + int8x8_t _r23 = float2int8(_p2, _p3); + + vst1q_s8(pp, vcombine_s8(_r01, _r23)); + + pp += 16; + p0 += B_hstep * 16; + } + for (; kk + 7 < max_kk; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[B_hstep], _p, 1); + _p = vsetq_lane_u16(p0[B_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[B_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[B_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[B_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[B_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[B_hstep * 7], _p, 7); + float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); + float32x4_t _p1 = bfloat2float(vget_high_u16(_p)); + + _p0 = vmulq_f32(_p0, _scale); + _p1 = vmulq_f32(_p1, _scale); + + int8x8_t _r01 = float2int8(_p0, _p1); + + vst1_s8(pp, _r01); + + pp += 8; + p0 += B_hstep * 8; + } +#endif // __ARM_NEON + // for (; kk + 1 < max_kk; kk += 2) + // { + // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + // pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); + // pp += 2; + // p0 += B_hstep * 2; + // } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); + pp += 1; + p0 += B_hstep; + } + } + } +} + +static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? 
(int)C.cstep : C.w; + const int c_elempack = C.elempack; + const unsigned short* pC = C; + + // NCNN_LOGE("unpack_output_tile_int32_to_bf16 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 
b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), 
_descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); + uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); + uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); + uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + transpose8x4_u16(_c89, _cab, _ccd, _cef); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); + float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); + float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); + float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); + float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); + float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); + float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); + float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); + uint16x8_t _cab = vld1q_u16(pC + c_hstep * 4 + 8); + uint16x8_t _ccd = vld1q_u16(pC + 
c_hstep * 4 + 16); + uint16x8_t _cef = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); + float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); + float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); + float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); + float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); + float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); + float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); + float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); + float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); + float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); + float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); + vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f8), float2bfloat(_f9))); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_fa), float2bfloat(_fb))); + vst1q_u16(p0 + out_hstep * 4 + 16, vcombine_u16(float2bfloat(_fc), float2bfloat(_fd))); + vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(float2bfloat(_fe), float2bfloat(_ff))); + + pp += 64; + p0 += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 
f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + uint16x4_t _cc4 = vld1_u16(pC + c_hstep * 4); + uint16x4_t _cc5 = vld1_u16(pC + c_hstep * 5); + uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6); + uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + transpose4x4_u16(_cc4, _cc5, _cc6, _cc7); + _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); + _f1 = vaddq_f32(_f1, bfloat2float(_cc1)); + _f2 = vaddq_f32(_f2, bfloat2float(_cc2)); + _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); + _f4 = vaddq_f32(_f4, bfloat2float(_cc4)); + _f5 = vaddq_f32(_f5, bfloat2float(_cc5)); + _f6 = vaddq_f32(_f6, bfloat2float(_cc6)); + _f7 = vaddq_f32(_f7, bfloat2float(_cc7)); + pC += 
4; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 4); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); + + pp += 32; + p0 += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = 
vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + uint16x8_t _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[1], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 2; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + + pp += 16; + p0 += 8; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); + + pp += 8; + p0 += 4; + } + } + 
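For reference, the NEON block above (and the out_elempack == 1 branch that follows) is a vectorized form of a simple per-element dequantization: each int32 accumulator is multiplied by the per-output-row descale factor, the optional C term is added according to broadcast_type_C, and the float result is narrowed to bf16. A minimal scalar sketch of that step is shown below, under stated assumptions: a plain row-major accumulator layout (the real kernel reads the packed tile order documented in the from/to comments), a per-row bias corresponding to broadcast_type_C 1/2, and a truncating float-to-bf16 conversion (ncnn's own float32_to_bfloat16 may round depending on build flags); the helper names here are illustrative, not part of the patch.

// Scalar reference for the int32 -> bf16 unpack step; illustrative only.
#include <stdint.h>

static inline uint16_t f32_to_bf16_trunc(float v)
{
    // keep sign, exponent and the top 7 mantissa bits (truncation, no rounding)
    union { float f; uint32_t u; } t;
    t.f = v;
    return (uint16_t)(t.u >> 16);
}

static void unpack_tile_scalar(const int32_t* sums, const float* descales,
                               const float* bias, uint16_t* out, int rows, int cols)
{
    // assumption: sums is rows x cols in row-major order, one descale/bias per row
    for (int r = 0; r < rows; r++)
    {
        for (int n = 0; n < cols; n++)
        {
            float f = (float)sums[r * cols + n] * descales[r] + bias[r];
            out[r * cols + n] = f32_to_bf16_trunc(f);
        }
    }
}

The NEON paths do exactly this, four (or eight) outputs at a time, with the extra shuffle work only needed to undo the accumulator interleaving produced by the dotprod / non-dotprod packing layouts.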
if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); + int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); + int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); + _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = 
vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); + float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); + float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); + float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); + float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); + float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); + float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); + float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); + float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = 
vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); + float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); + float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); + float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); + float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + _f8 = vaddq_f32(_f8, _cc4); + _f9 = vaddq_f32(_f9, _cc4); + _fa = vaddq_f32(_fa, _cc5); + _fb = vaddq_f32(_fb, _cc5); + _fc = vaddq_f32(_fc, _cc6); + _fd = vaddq_f32(_fd, _cc6); + _fe = vaddq_f32(_fe, _cc7); + _ff = vaddq_f32(_ff, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); + uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); + uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); + uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); + float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); + float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); + float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); + float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); + float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); + float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); + float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8x4_t _cc0 = vld4q_u16(pC); + uint16x8x4_t _cc1 = vld4q_u16(pC + c_hstep * 4); + 
_f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0.val[0]))); + _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0.val[0]))); + _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc0.val[1]))); + _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc0.val[1]))); + _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc0.val[2]))); + _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc0.val[2]))); + _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc0.val[3]))); + _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc0.val[3]))); + _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc1.val[0]))); + _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc1.val[0]))); + _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc1.val[1]))); + _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc1.val[1]))); + _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc1.val[2]))); + _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc1.val[2]))); + _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc1.val[3]))); + _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc1.val[3]))); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c1); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f8), float2bfloat(_f9))); + vst1q_u16(p0 + out_hstep * 5, vcombine_u16(float2bfloat(_fa), float2bfloat(_fb))); + vst1q_u16(p0 + out_hstep * 6, vcombine_u16(float2bfloat(_fc), float2bfloat(_fd))); + vst1q_u16(p0 + out_hstep * 7, vcombine_u16(float2bfloat(_fe), float2bfloat(_ff))); + + pp += 64; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), 
vget_high_s32(_t1.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); + float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); + float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); + float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); + float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 
2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); + float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); + float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); + float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); + float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + _f4 = vaddq_f32(_f4, _cc4); + _f5 = vaddq_f32(_f5, _cc5); + _f6 = vaddq_f32(_f6, _cc6); + _f7 = vaddq_f32(_f7, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 1)); + float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 4; + } + if (c_elempack == 4) + { + uint16x4x4_t _cc0 = vld4_u16(pC); + uint16x4x4_t _cc1 = vld4_u16(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); + _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2])); + _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3])); + _f4 = vaddq_f32(_f4, bfloat2float(_cc1.val[0])); + _f5 = vaddq_f32(_f5, bfloat2float(_cc1.val[1])); + _f6 = vaddq_f32(_f6, bfloat2float(_cc1.val[2])); + _f7 = vaddq_f32(_f7, bfloat2float(_cc1.val[3])); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + pC += 4; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); + vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); + vst1_u16(p0 + out_hstep * 4, float2bfloat(_f4)); + vst1_u16(p0 + out_hstep * 5, float2bfloat(_f5)); + vst1_u16(p0 + out_hstep * 6, float2bfloat(_f6)); + vst1_u16(p0 + out_hstep * 7, float2bfloat(_f7)); + + pp += 32; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // e0 e1 f0 f1 + // g0 g1 h0 h1 + { + int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _sum13 = vzipq_s32(_sum2, _sum3); + _sum0 = _sum02.val[0]; + _sum1 = _sum02.val[1]; + _sum2 = 
_sum13.val[0]; + _sum3 = _sum13.val[1]; + } +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + // e0 e1 f0 f1 + // g0 g1 h0 h1 + { + int32x4x2_t _t0 = vuzpq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vuzpq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_t0.val[0], _t1.val[0]); + int32x4x2_t _t3 = vzipq_s32(_t1.val[1], _t0.val[1]); + _sum0 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4x2_t _descale01 = vzipq_f32(_descale0, _descale0); + float32x4x2_t _descale23 = vzipq_f32(_descale1, _descale1); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale23.val[0]); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale23.val[1]); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); + float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc0.val[1]); + _f2 = vaddq_f32(_f2, _cc1.val[0]); + _f3 = vaddq_f32(_f3, _cc1.val[1]); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _cc0 = uint16x8_t(); + _cc0 = vsetq_lane_u16(pC[0], _cc0, 0); + _cc0 = vsetq_lane_u16(pC[1], _cc0, 1); + _cc0 = vsetq_lane_u16(pC[c_hstep * 1], _cc0, 2); + _cc0 = vsetq_lane_u16(pC[c_hstep * 1 + 1], _cc0, 3); + _cc0 = vsetq_lane_u16(pC[c_hstep * 2], _cc0, 4); + _cc0 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _cc0, 5); + _cc0 = vsetq_lane_u16(pC[c_hstep * 3], _cc0, 6); + _cc0 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _cc0, 7); + uint16x8_t _cc1 = uint16x8_t(); + _cc1 = vsetq_lane_u16(pC[c_hstep * 4], _cc1, 0); + _cc1 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _cc1, 1); + _cc1 = vsetq_lane_u16(pC[c_hstep * 5], _cc1, 2); + _cc1 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _cc1, 3); + _cc1 = vsetq_lane_u16(pC[c_hstep * 6], _cc1, 4); + _cc1 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _cc1, 5); + _cc1 = vsetq_lane_u16(pC[c_hstep * 7], _cc1, 6); + _cc1 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _cc1, 7); + _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0))); + _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0))); + _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc1))); + _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc1))); + pC += 2; + } + if (c_elempack == 4) + { + // TODO optimize + uint16x8_t _cc0 = vld1q_u16(pC); + uint16x8_t _cc1 = vld1q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_cc0)); + _c1 = bfloat2float(vget_high_u16(_cc0)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_cc1)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_cc1)); + float32x4x2_t _c01 = vzipq_f32(_c0, _c1); + float32x4x2_t _c23 = vzipq_f32(_c2, _c3); + _f0 = vaddq_f32(_f0, vcombine_f32(vget_low_f32(_c01.val[0]), vget_low_f32(_c01.val[1]))); + _f1 = vaddq_f32(_f1, vcombine_f32(vget_high_f32(_c01.val[0]), vget_high_f32(_c01.val[1]))); + _f2 = vaddq_f32(_f2, vcombine_f32(vget_low_f32(_c23.val[0]), vget_low_f32(_c23.val[1]))); + _f3 = vaddq_f32(_f3, 
vcombine_f32(vget_high_f32(_c23.val[0]), vget_high_f32(_c23.val[1]))); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[0], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + pC += 2; + } + } + + uint16x4_t _fb0 = float2bfloat(_f0); + uint16x4_t _fb1 = float2bfloat(_f1); + uint16x4_t _fb2 = float2bfloat(_f2); + uint16x4_t _fb3 = float2bfloat(_f3); + + p0[0] = vget_lane_u16(_fb0, 0); + p0[1] = vget_lane_u16(_fb0, 1); + p0[out_hstep] = vget_lane_u16(_fb0, 2); + p0[out_hstep + 1] = vget_lane_u16(_fb0, 3); + p0[out_hstep * 2] = vget_lane_u16(_fb1, 0); + p0[out_hstep * 2 + 1] = vget_lane_u16(_fb1, 1); + p0[out_hstep * 3] = vget_lane_u16(_fb1, 2); + p0[out_hstep * 3 + 1] = vget_lane_u16(_fb1, 3); + p0[out_hstep * 4] = vget_lane_u16(_fb2, 0); + p0[out_hstep * 4 + 1] = vget_lane_u16(_fb2, 1); + p0[out_hstep * 5] = vget_lane_u16(_fb2, 2); + p0[out_hstep * 5 + 1] = vget_lane_u16(_fb2, 3); + p0[out_hstep * 6] = vget_lane_u16(_fb3, 0); + p0[out_hstep * 6 + 1] = vget_lane_u16(_fb3, 1); + p0[out_hstep * 7] = vget_lane_u16(_fb3, 2); + p0[out_hstep * 7 + 1] = vget_lane_u16(_fb3, 3); + + pp += 16; + p0 += 2; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[c_hstep * 4], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep * 5], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 6], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 7], _c, 7); + _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c))); + _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c))); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + uint16x4_t _fb0 = float2bfloat(_f0); + uint16x4_t _fb1 = float2bfloat(_f1); + + p0[0] = vget_lane_u16(_fb0, 0); + p0[out_hstep] = vget_lane_u16(_fb0, 1); + p0[out_hstep * 2] = vget_lane_u16(_fb0, 2); + p0[out_hstep * 3] = vget_lane_u16(_fb0, 3); + p0[out_hstep * 4] = vget_lane_u16(_fb1, 0); + p0[out_hstep * 5] = vget_lane_u16(_fb1, 1); + p0[out_hstep * 6] = vget_lane_u16(_fb1, 2); + p0[out_hstep * 7] = vget_lane_u16(_fb1, 3); + + pp += 8; + p0++; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = 
vdupq_n_f32(bfloat16_to_float32(pC[0])); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + _c0 = bfloat2float(vld1_u16(pC)); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = 
vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); + float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); + float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); + float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); + vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); + + pp += 32; + p0 += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + 
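// _sum2/_sum3 were reversed and rotated above so that the zip/combine steps below can reassemble plain a/b/c/d rows (see the from/to diagram)
+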
int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep * 1); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); + _f1 = vaddq_f32(_f1, bfloat2float(_cc1)); + _f2 = vaddq_f32(_f2, bfloat2float(_cc2)); + _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); + pC += 4; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + + pp += 16; + p0 += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 
3) + { + if (c_elempack == 1) + { + uint16x8_t _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + + pp += 8; + p0 += 8; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + _f0 = vaddq_f32(_f0, bfloat2float(_c)); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + + pp += 4; + p0 += 4; + } + } + if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = 
vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + 
float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8x4_t _cc = vld4q_u16(pC); + _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc.val[0]))); + _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc.val[0]))); + _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc.val[1]))); + _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc.val[1]))); + _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc.val[2]))); + _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc.val[2]))); + _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc.val[3]))); + _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc.val[3]))); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); + + pp += 32; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + _sum1 = vextq_s32(_sum1, _sum1, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); + float32x4_t 
_f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { +#if __aarch64__ + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); +#endif + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = bfloat2float(vld1_u16(pC)); + float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep * 1)); + float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + if (c_elempack == 4) + { + uint16x4x4_t _c = vld4_u16(pC); + _f0 = vaddq_f32(_f0, bfloat2float(_c.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_c.val[1])); + _f2 = vaddq_f32(_f2, bfloat2float(_c.val[2])); + _f3 = vaddq_f32(_f3, bfloat2float(_c.val[3])); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + pC += 4; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); + vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); + + pp += 16; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + { + int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + _sum0 = _sum01.val[0]; + _sum1 = _sum01.val[1]; + } +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 a1 b0 b1 + // c0 c1 d0 d1 + { + int32x4_t _t0 = vuzpq_s32(_sum0, _sum1).val[0]; + int32x4_t _t1 = vuzpq_s32(_sum1, _sum0).val[1]; + int32x4x2_t _t3 = vuzpq_s32(_t0, _t1); + _sum0 = _t3.val[0]; + _sum1 = _t3.val[1]; + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4x2_t _descale01 = vzipq_f32(_descale, _descale); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); + _f0 = vaddq_f32(_f0, _cc0.val[0]); + _f1 = vaddq_f32(_f1, _cc0.val[1]); + 
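// zipping _c0 with itself gives a a b b / c c d d bias lanes, matching the a0 a1 b0 b1 / c0 c1 d0 d1 order of the interleaved sums
+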
} + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[1], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 3); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c))); + _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c))); + pC += 2; + } + if (c_elempack == 4) + { + uint16x8_t _cc = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_cc)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_cc)); + float32x4x2_t _c01 = vzipq_f32(_c0, _c1); + _f0 = vaddq_f32(_f0, _c01.val[0]); + _f1 = vaddq_f32(_f1, _c01.val[1]); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[0], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 2; + } + } + + uint16x4_t _fb0 = float2bfloat(_f0); + uint16x4_t _fb1 = float2bfloat(_f1); + + p0[0] = vget_lane_u16(_fb0, 0); + p0[1] = vget_lane_u16(_fb0, 1); + p0[out_hstep] = vget_lane_u16(_fb0, 2); + p0[out_hstep + 1] = vget_lane_u16(_fb0, 3); + p0[out_hstep * 2] = vget_lane_u16(_fb1, 0); + p0[out_hstep * 2 + 1] = vget_lane_u16(_fb1, 1); + p0[out_hstep * 3] = vget_lane_u16(_fb1, 2); + p0[out_hstep * 3 + 1] = vget_lane_u16(_fb1, 3); + + // vst1_f32(p0, vget_low_f32(_f0)); + // vst1_f32(p1, vget_high_f32(_f0)); + // vst1_f32(p2, vget_low_f32(_f1)); + // vst1_f32(p3, vget_high_f32(_f1)); + + pp += 8; + p0 += 2; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + _f0 = vaddq_f32(_f0, bfloat2float(_c)); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + uint16x4_t _fb0 = float2bfloat(_f0); + + p0[0] = vget_lane_u16(_fb0, 0); + p0[out_hstep] = vget_lane_u16(_fb0, 1); + p0[out_hstep * 2] = vget_lane_u16(_fb0, 2); + p0[out_hstep * 3] = vget_lane_u16(_fb0, 3); + + pp += 4; + p0++; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = 
bfloat16_to_float32(pC[0]); + c1 = bfloat16_to_float32(pC[1]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + // if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + + pp += 16; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // TODO neon optimize + float f00 = pp[0] * descale0; + float f01 = pp[1] * descale0; + float f10 = pp[2] * descale1; + float f11 = pp[3] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f00 += c0; + f01 += c0; + f10 += c0; + f11 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f00 += c0; + f01 += 
c0; + f10 += c1; + f11 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f00 += bfloat16_to_float32(pC[0]); + f01 += bfloat16_to_float32(pC[1]); + f10 += bfloat16_to_float32(pC[c_hstep]); + f11 += bfloat16_to_float32(pC[c_hstep + 1]); + pC += 2; + } + if (broadcast_type_C == 4) + { + f00 += bfloat16_to_float32(pC[0]); + f01 += bfloat16_to_float32(pC[1]); + f10 += bfloat16_to_float32(pC[0]); + f11 += bfloat16_to_float32(pC[1]); + pC += 2; + } + } + + p0[0] = float32_to_bfloat16(f00); + p0[1] = float32_to_bfloat16(f01); + p0[out_hstep] = float32_to_bfloat16(f10); + p0[out_hstep + 1] = float32_to_bfloat16(f11); + + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]); + f1 += bfloat16_to_float32(pC[c_hstep]); + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += bfloat16_to_float32(pC[0]); + f1 += bfloat16_to_float32(pC[0]); + pC += 1; + } + } + + p0[0] = float32_to_bfloat16(f0); + p0[out_hstep] = float32_to_bfloat16(f1); + + pp += 2; + p0++; + } + } + } + for (; ii < max_ii; ii += 1) + { + // out_elempack == 1 + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + // if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + + pp += 16; + p0 += 16; + } + for (; 
jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _cc = float32x2_t(); + _cc = vset_lane_f32(bfloat16_to_float32(pC[0]), _cc, 0); + _cc = vset_lane_f32(bfloat16_to_float32(pC[1]), _cc, 1); + _f0 = vadd_f32(_f0, _cc); + pC += 2; + } + } + + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); + p0[1] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); + + pp += 2; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]); + pC += 1; + } + } + + p0[0] = float32_to_bfloat16(f0); + + pp += 1; + p0++; + } + } + } +} + +static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + return; + } +#endif + + const int out_elempack = top_blob.elempack; + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int c_hstep = C.dims == 3 ? 
(int)C.cstep : C.w; + const int c_elempack = C.elempack; + const unsigned short* pC = C; + + // NCNN_LOGE("transpose_unpack_output_tile_int32_to_bf16 %d %d %d %d %d %d %d", i, max_ii, j, max_jj, out_elempack, broadcast_type_C, c_elempack); + + const int* pp = topT; + + int ii = 0; +#if __ARM_NEON + for (; ii + 7 < max_ii; ii += 8) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale0 = vld1q_f32((const float*)descales + ii); + float32x4_t _descale1 = vld1q_f32((const float*)descales + ii + 4); + + float32x4_t _c0; + float32x4_t _c1; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); + int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); + int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), 
vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); + _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); + } +#else // __ARM_FEATURE_DOTPROD + + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + // e0 e1 e2 e3 + // e4 e5 e6 e7 + // f0 f1 f2 f3 + // f4 f5 f6 f7 + // g0 g1 g2 g3 + // g4 g5 g6 g7 + // h0 h1 h2 h3 + // h4 h5 h6 h7 + { + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), 
_descale0, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); + float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); + float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); + float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); + float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); + float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); + float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); + float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); + float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + _f8 = vaddq_f32(_f8, _cc4); + _f9 = vaddq_f32(_f9, _cc4); + _fa = vaddq_f32(_fa, _cc5); + _fb = vaddq_f32(_fb, _cc5); + _fc = vaddq_f32(_fc, _cc6); + _fd = vaddq_f32(_fd, _cc6); + _fe = vaddq_f32(_fe, _cc7); + _ff = vaddq_f32(_ff, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + 4)); + float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep)); + float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep + 4)); + float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 2 + 4)); + float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 3 + 4)); + float32x4_t _c8 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + float32x4_t _c9 = bfloat2float(vld1_u16(pC + c_hstep * 4 + 4)); + float32x4_t _ca = bfloat2float(vld1_u16(pC + c_hstep * 5)); + float32x4_t _cb = bfloat2float(vld1_u16(pC + c_hstep * 5 + 4)); + float32x4_t _cc = bfloat2float(vld1_u16(pC + c_hstep * 6)); + float32x4_t _cd = bfloat2float(vld1_u16(pC + c_hstep * 6 + 4)); + float32x4_t _ce = bfloat2float(vld1_u16(pC + c_hstep * 7)); + float32x4_t _cf = bfloat2float(vld1_u16(pC + c_hstep * 7 + 4)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = 
vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8x4_t _cc0 = vld4q_u16(pC); + uint16x8x4_t _cc1 = vld4q_u16(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0.val[0]))); + _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0.val[0]))); + _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc0.val[1]))); + _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc0.val[1]))); + _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc0.val[2]))); + _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc0.val[2]))); + _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc0.val[3]))); + _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc0.val[3]))); + _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc1.val[0]))); + _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc1.val[0]))); + _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc1.val[1]))); + _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc1.val[1]))); + _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc1.val[2]))); + _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc1.val[2]))); + _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc1.val[3]))); + _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc1.val[3]))); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c1); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6))); + vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f8), float2bfloat(_fa))); + vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_fc), float2bfloat(_fe))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_f5), float2bfloat(_f7))); + vst1q_u16(p0 + out_hstep * 4 + 16, vcombine_u16(float2bfloat(_f9), float2bfloat(_fb))); + vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(float2bfloat(_fd), float2bfloat(_ff))); + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, 
_sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else // __ARM_FEATURE_DOTPROD + + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + // e0 e1 e2 e3 + // f0 f1 f2 f3 + // g0 g1 g2 g3 + // h0 h1 h2 h3 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); + float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); + float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); + float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); + float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = 
vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); + float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); + float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); + float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + _f4 = vaddq_f32(_f4, _cc4); + _f5 = vaddq_f32(_f5, _cc5); + _f6 = vaddq_f32(_f6, _cc6); + _f7 = vaddq_f32(_f7, _cc7); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 4; + } + if (c_elempack == 4) + { + uint16x4x4_t _cc0 = vld4_u16(pC); + uint16x4x4_t _cc1 = vld4_u16(pC + c_hstep * 4); + _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); + _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2])); + _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3])); + _f4 = vaddq_f32(_f4, bfloat2float(_cc1.val[0])); + _f5 = vaddq_f32(_f5, bfloat2float(_cc1.val[1])); + _f6 = vaddq_f32(_f6, bfloat2float(_cc1.val[2])); + _f7 = vaddq_f32(_f7, bfloat2float(_cc1.val[3])); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + pC += 4; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); + vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); + pp += 32; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 
60); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), 
vget_high_s32(_t5.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); +#endif // __ARM_FEATURE_DOTPROD + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); + uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); + uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); + uint16x8_t _cef = 
vld1q_u16(pC + c_hstep * 7); + transpose8x8_u16(_c01, _c23, _c45, _c67, _c89, _cab, _ccd, _cef); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); + float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); + float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); + float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); + float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); + float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); + float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); + float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c2); + _f2 = vaddq_f32(_f2, _c4); + _f3 = vaddq_f32(_f3, _c6); + _f4 = vaddq_f32(_f4, _c8); + _f5 = vaddq_f32(_f5, _ca); + _f6 = vaddq_f32(_f6, _cc); + _f7 = vaddq_f32(_f7, _ce); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c3); + _fa = vaddq_f32(_fa, _c5); + _fb = vaddq_f32(_fb, _c7); + _fc = vaddq_f32(_fc, _c9); + _fd = vaddq_f32(_fd, _cb); + _fe = vaddq_f32(_fe, _cd); + _ff = vaddq_f32(_ff, _cf); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); + uint16x8_t _cab = vld1q_u16(pC + c_hstep * 4 + 8); + uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 4 + 16); + uint16x8_t _cef = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); + float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); + float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); + float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); + float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); + float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); + float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); + float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); + float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); + float32x4_t _c6 = 
vdupq_n_f32(bfloat16_to_float32(pC[6])); + float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f8))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f9))); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), float2bfloat(_fa))); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f3), float2bfloat(_fb))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f4), float2bfloat(_fc))); + vst1q_u16(p0 + out_hstep * 5, vcombine_u16(float2bfloat(_f5), float2bfloat(_fd))); + vst1q_u16(p0 + out_hstep * 6, vcombine_u16(float2bfloat(_f6), float2bfloat(_fe))); + vst1q_u16(p0 + out_hstep * 7, vcombine_u16(float2bfloat(_f7), float2bfloat(_ff))); + + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), 
_descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + uint16x4_t _cc4 = vld1_u16(pC + c_hstep * 4); + uint16x4_t _cc5 = vld1_u16(pC + c_hstep * 5); + uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6); + uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7); + transpose4x8_u16(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7); + _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); + _f1 = vaddq_f32(_f1, bfloat2float(_cc2)); + _f2 = vaddq_f32(_f2, bfloat2float(_cc4)); + _f3 = vaddq_f32(_f3, bfloat2float(_cc6)); + _f4 = vaddq_f32(_f4, bfloat2float(_cc1)); + _f5 = vaddq_f32(_f5, bfloat2float(_cc3)); + _f6 = vaddq_f32(_f6, bfloat2float(_cc5)); + _f7 = vaddq_f32(_f7, bfloat2float(_cc7)); + pC += 4; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 4); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f4))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f5))); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), float2bfloat(_f6))); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f3), float2bfloat(_f7))); + + pp += 32; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + 
int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + + uint16x8_t _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[1], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); + + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c2 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 2; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC 
+= 2; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); + + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + pp += 8; + p0 += out_hstep; + } + } + } + for (; ii + 3 < max_ii; ii += 4) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) + { + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + _c0 = bfloat2float(vld1_u16(pC)); + } + if (broadcast_type_C == 3) + { + pC = (const unsigned short*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); + int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); + int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = 
vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + } +#else // __ARM_FEATURE_DOTPROD + + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + // c0 c1 c2 c3 + // c4 c5 c6 c7 + // d0 d1 d2 d3 + // d4 d5 d6 d7 + { + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); + float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); + float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); + float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); + float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc0); + _f2 = vaddq_f32(_f2, _cc1); + _f3 = vaddq_f32(_f3, _cc1); + _f4 = vaddq_f32(_f4, _cc2); + _f5 = vaddq_f32(_f5, _cc2); + _f6 = vaddq_f32(_f6, _cc3); + _f7 = vaddq_f32(_f7, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + 
float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8x4_t _c = vld4q_u16(pC); + _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c.val[0]))); + _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c.val[0]))); + _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_c.val[1]))); + _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_c.val[1]))); + _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_c.val[2]))); + _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_c.val[2]))); + _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_c.val[3]))); + _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_c.val[3]))); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c1); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_f5), float2bfloat(_f7))); + + pp += 32; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); + } +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 a1 a2 a3 + // b0 b1 b2 b3 + // c0 c1 c2 c3 + // d0 d1 d2 d3 + { + _sum1 = vextq_s32(_sum1, _sum1, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + +#if __aarch64__ + float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = 
vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); + float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); + float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); +#else + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); +#endif + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); + float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); + float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); + float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); + _f0 = vaddq_f32(_f0, _cc0); + _f1 = vaddq_f32(_f1, _cc1); + _f2 = vaddq_f32(_f2, _cc2); + _f3 = vaddq_f32(_f3, _cc3); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + _c0 = bfloat2float(vld1_u16(pC)); + float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + if (c_elempack == 4) + { + uint16x4x4_t _c = vld4_u16(pC); + _f0 = vaddq_f32(_f0, bfloat2float(_c.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_c.val[1])); + _f2 = vaddq_f32(_f2, bfloat2float(_c.val[2])); + _f3 = vaddq_f32(_f3, bfloat2float(_c.val[3])); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + pC += 4; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + + pp += 16; + p0 += out_hstep * 4; + } + } + if (out_elempack == 1) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = 
vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + 
_f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 32; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); + float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); + float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); + float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); + vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); + vst1_u16(p0 + out_hstep * 4, float2bfloat(_f4)); + vst1_u16(p0 + out_hstep * 5, float2bfloat(_f5)); + vst1_u16(p0 + out_hstep * 6, float2bfloat(_f6)); + vst1_u16(p0 + out_hstep * 7, float2bfloat(_f7)); + + pp += 32; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 +#else + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); + _f1 = vaddq_f32(_f1, bfloat2float(_cc1)); + _f2 = vaddq_f32(_f2, bfloat2float(_cc2)); + _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); + pC += 4; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + 
uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); + float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); + vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 +#else + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x8_t _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); + + pp += 8; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); 
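+ // gather one bf16 bias value from each of the four C rows (c_hstep apart)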
+ _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + _f0 = vaddq_f32(_f0, bfloat2float(_c)); + pC += 1; + } + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + pp += 4; + p0 += out_hstep; + } + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale01 = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]); + c1 = bfloat16_to_float32(pC[1]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + +#if __ARM_NEON + if (out_elempack == 4) + { + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; + } + } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = 
vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 4; + } + if (broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + 4, float2bfloat(_f1)); + + pp += 8; + p0 += out_hstep * 4; + } + } +#endif // __ARM_NEON + if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + // a0 a1 a2 a3 + // a4 a5 a6 a7 + // b0 b1 b2 b3 + // b4 b5 b6 b7 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _sum13 = vzipq_s32(_sum1, _sum3); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum02.val[0]), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum02.val[1]), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum13.val[0]), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum13.val[1]), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + _f1 = vaddq_f32(_f1, _cc); + _f2 = vaddq_f32(_f2, _cc); + _f3 = vaddq_f32(_f3, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x4x2_t _c02 = vzip_u16(vget_low_u16(_c01), vget_low_u16(_c23)); + uint16x4x2_t _c13 = vzip_u16(vget_high_u16(_c01), vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, bfloat2float(_c02.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_c02.val[1])); + _f2 = vaddq_f32(_f2, bfloat2float(_c13.val[0])); + _f3 = vaddq_f32(_f3, bfloat2float(_c13.val[1])); + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x4x2_t _cc0 = vzip_u16(vget_low_u16(_c01), vget_low_u16(_c01)); + uint16x4x2_t _cc1 = vzip_u16(vget_high_u16(_c01), vget_high_u16(_c01)); + _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); + _f2 = vaddq_f32(_f2, bfloat2float(_cc1.val[0])); + _f3 = vaddq_f32(_f3, bfloat2float(_cc1.val[1])); + pC += 8; + } + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 2] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 3] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 4 + 1] = 
vget_lane_u16(_bf2, 1); + p0[out_hstep * 5] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 5 + 1] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 6] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 6 + 1] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 7] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 7 + 1] = vget_lane_u16(_bf3, 3); + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 + + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum01.val[0]), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum01.val[1]), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + _f1 = vaddq_f32(_f1, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4x2_t _c01 = vzip_u16(_cc0, _cc1); + _f0 = vaddq_f32(_f0, bfloat2float(_c01.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_c01.val[1])); + pC += 4; + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = vld1_u16(pC); + uint16x4x2_t _cc = vzip_u16(_c, _c); + _f0 = vaddq_f32(_f0, bfloat2float(_cc.val[0])); + _f1 = vaddq_f32(_f1, bfloat2float(_cc.val[1])); + pC += 4; + } + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 2] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 3] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // TODO neon optimize + // a0 a1 b0 b1 + + float f00 = pp[0] * descale0; + float f01 = pp[2] * descale1; + float f10 = pp[1] * descale0; + float f11 = pp[3] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f00 += c0; + f01 += c0; + f10 += c0; + f11 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f00 += c0; + f01 += c1; + f10 += c0; + f11 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f00 += bfloat16_to_float32(pC[0]); + f01 += bfloat16_to_float32(pC[c_hstep]); + f10 += bfloat16_to_float32(pC[1]); + f11 += bfloat16_to_float32(pC[c_hstep + 1]); + pC += 2; + } + if (broadcast_type_C == 4) + { + float c0 = bfloat16_to_float32(pC[0]); + float c1 = bfloat16_to_float32(pC[1]); + f00 += c0; + f01 += c0; + f10 += c1; + f11 += c1; + pC += 2; + } + } + + p0[0] = float32_to_bfloat16(f00); + p0[1] = float32_to_bfloat16(f01); + p0[out_hstep] = float32_to_bfloat16(f10); + p0[out_hstep + 1] = float32_to_bfloat16(f11); + + pp += 4; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + + if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]); + f1 += bfloat16_to_float32(pC[c_hstep]); + pC += 1; + } + if (broadcast_type_C 
== 4) + { + c0 = bfloat16_to_float32(pC[0]); + f0 += c0; + f1 += c0; + pC += 1; + } + } + + p0[0] = float32_to_bfloat16(f0); + p0[1] = float32_to_bfloat16(f1); + pp += 2; + p0 += out_hstep; + } + } + } + for (; ii < max_ii; ii += 1) + { + unsigned short* p0 = (unsigned short*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + const float descale = descales[ii]; +#if __ARM_NEON + float32x4_t _descale = vdupq_n_f32(descale); +#endif + + float c0; +#if __ARM_NEON + float32x4_t _c0; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]); +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const unsigned short*)C + j; + } + } + +#if __ARM_NEON + if (out_elempack == 4) + { + int jj = 0; + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); + vst1_u16(p0 + out_hstep * 8, float2bfloat(_f2)); + vst1_u16(p0 + out_hstep * 12, float2bfloat(_f3)); + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || 
broadcast_type_C == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + + vst1_u16(p0, float2bfloat(_f0)); + pp += 4; + p0 += out_hstep * 4; + } + } +#endif // __ARM_NEON + if (out_elempack == 1) + { + int jj = 0; +#if __ARM_NEON + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 16; + } + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 8; + } + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 
4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vaddq_f32(_f0, _c0); + pC += 4; + } + } + + uint16x4_t _bf0 = float2bfloat(_f0); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _c = float32x2_t(); + _c = vset_lane_f32(bfloat16_to_float32(pC[0]), _c, 0); + _c = vset_lane_f32(bfloat16_to_float32(pC[1]), _c, 1); + _f0 = vadd_f32(_f0, _c); + pC += 2; + } + } + + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); + p0[out_hstep] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); + + pp += 2; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]); + pC += 1; + } + } + + p0[0] = float32_to_bfloat16(f0); + + pp += 1; + p0 += out_hstep; + } + } + } +} diff --git a/src/layer/gemm.cpp b/src/layer/gemm.cpp index de6b2adeb956..0ebe5974d0b7 100644 --- a/src/layer/gemm.cpp +++ b/src/layer/gemm.cpp @@ -39,10 +39,19 @@ int Gemm::load_param(const ParamDict& pd) output_elempack = pd.get(12, 0); output_elemtype = pd.get(13, 0); output_transpose = pd.get(14, 0); + int8_scale_term = pd.get(18, 0); constant_TILE_M = pd.get(20, 0); constant_TILE_N = pd.get(21, 0); constant_TILE_K = pd.get(22, 0); + if (int8_scale_term) + { +#if !NCNN_INT8 + NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference"); + return -1; +#endif + } + if (constantA == 1 && (constantM == 0 || constantK == 0)) { NCNN_LOGE("constantM and constantK must be non-zero when constantA enabled"); @@ -111,9 +120,175 @@ int Gemm::load_model(const ModelBin& mb) return -100; } +#if NCNN_INT8 + if (int8_scale_term) + { + if (constantA == 1) + { + A_data_int8_scales = mb.load(constantM, 1); + } + + if (constantB == 1) + { + B_data_int8_scale = mb.load(1, 1)[0]; + } + } +#endif // NCNN_INT8 + return 0; } +static void gemm_transB(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, float alpha, float beta, int broadcast_type_C, int output_transpose, const Option& opt) +{ + const int M = A.dims == 3 ? A.c : A.h; + const int N = BT.dims == 3 ? BT.c : BT.h; + const int K = A.w; // assert A.w == BT.w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < M; i++) + { + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; + const int BT_hstep = BT.dims == 3 ? 
(int)BT.cstep : BT.w; + + const float* ptrA = (const float*)A + i * A_hstep; + const float* ptrC = C; + + for (int j = 0; j < N; j++) + { + const float* ptrBT = (const float*)BT + j * BT_hstep; + + float sum = 0.f; + if (ptrC) + { + if (broadcast_type_C == 0) + { + sum = ptrC[0]; + } + if (broadcast_type_C == 1) + { + sum = ptrC[i]; + } + if (broadcast_type_C == 2) + { + sum = ptrC[i]; + } + if (broadcast_type_C == 3) + { + sum = ptrC[i * N + j]; + } + if (broadcast_type_C == 4) + { + sum = ptrC[j]; + } + + sum *= beta; + } + + for (int k = 0; k < K; k++) + { + sum += ptrA[k] * ptrBT[k]; + } + + sum *= alpha; + + if (output_transpose) + { + top_blob[j * out_hstep + i] = sum; + } + else + { + top_blob[i * out_hstep + j] = sum; + } + } + } +} + +#if NCNN_INT8 +static inline signed char float2int8(float v) +{ + int int32 = static_cast(round(v)); + if (int32 > 127) return 127; + if (int32 < -127) return -127; + return (signed char)int32; +} + +static void gemm_transB_int8(const Mat& A_int8, const Mat& BT_int8, const Mat& A_int8_scales, float BT_int8_scale, const Mat& C, Mat& top_blob, float alpha, float beta, int broadcast_type_C, int output_transpose, const Option& opt) +{ + const int M = A_int8.h; + const int N = BT_int8.h; + const int K = A_int8.w; // assert A_int8.w == BT_int8.w + + // NCNN_LOGE("naive ds %f %f", A_int8_scales[0], BT_int8_scale); + + // #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < M; i++) + { + const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + + const signed char* ptrA = A_int8.row(i); + const float* ptrC = C; + + const float descale = 1.f / (A_int8_scales[i] * BT_int8_scale); + + // NCNN_LOGE("descale %f", descale); + + for (int j = 0; j < N; j++) + { + const signed char* ptrBT = BT_int8.row(j); + + int sum = 0; + for (int k = 0; k < K; k++) + { + // NCNN_LOGE("ptrA[%d] %d", k, ptrA[k]); + sum += ptrA[k] * ptrBT[k]; + } + + float sum_fp32 = sum * descale; + + if (ptrC) + { + float c = 0.f; + if (broadcast_type_C == 0) + { + c = ptrC[0]; + } + if (broadcast_type_C == 1) + { + c = ptrC[i]; + } + if (broadcast_type_C == 2) + { + c = ptrC[i]; + } + if (broadcast_type_C == 3) + { + c = ptrC[i * N + j]; + } + if (broadcast_type_C == 4) + { + c = ptrC[j]; + } + + sum_fp32 += c * beta; + } + + sum_fp32 *= alpha; + + if (output_transpose) + { + top_blob[j * out_hstep + i] = sum_fp32; + } + else + { + top_blob[i * out_hstep + j] = sum_fp32; + } + } + } +} +#endif // NCNN_INT8 + int Gemm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { std::vector bottom_blobs(1, bottom_blob); @@ -125,6 +300,13 @@ int Gemm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif // NCNN_INT8 + const Mat& A0 = constantA ? A_data : bottom_blobs[0]; const Mat& B0 = constantB ? B_data : constantA ? bottom_blobs[0] : bottom_blobs[1]; @@ -152,18 +334,18 @@ int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_bl } } - Mat B; + Mat BT; if (transB == 0) { // transpose B to col-major - B.create((B0.dims == 3 ? B0.c : B0.h), B0.w, elemsize, opt.workspace_allocator); + BT.create((B0.dims == 3 ? B0.c : B0.h), B0.w, elemsize, opt.workspace_allocator); const int B0_hstep = B0.dims == 3 ? 
(int)B0.cstep : B0.w; - for (int i = 0; i < B.h; i++) + for (int i = 0; i < BT.h; i++) { - float* ptr = B.row(i); - for (int j = 0; j < B.w; j++) + float* ptr = BT.row(i); + for (int j = 0; j < BT.w; j++) { ptr[j] = B0[j * B0_hstep + i]; } @@ -171,43 +353,36 @@ int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_bl } else { - B = B0; + BT = B0; } const int M = A.dims == 3 ? A.c : A.h; - const int K = A.w; // assert A.w == B.w - const int N = B.dims == 3 ? B.c : B.h; + const int N = BT.dims == 3 ? BT.c : BT.h; - const float* ptrC = 0; + Mat C; int broadcast_type_C = 0; if (constantC) { - ptrC = C_data; + C = C_data; broadcast_type_C = constant_broadcast_type_C; } else { - if (constantA && constantB) + if (constantA && constantB && bottom_blobs.size() == 1) { - ptrC = bottom_blobs.size() == 1 ? bottom_blobs[0] : 0; + C = bottom_blobs[0]; } - else if (constantA) + else if ((constantA || constantB) && bottom_blobs.size() == 2) { - ptrC = bottom_blobs.size() == 2 ? bottom_blobs[1] : 0; + C = bottom_blobs[1]; } - else if (constantB) - { - ptrC = bottom_blobs.size() == 2 ? bottom_blobs[1] : 0; - } - else + else if (bottom_blobs.size() == 3) { - ptrC = bottom_blobs.size() == 3 ? bottom_blobs[2] : 0; + C = bottom_blobs[2]; } - if (ptrC) + if (!C.empty()) { - const Mat& C = bottom_blobs[bottom_blobs.size() - 1]; - if (C.dims == 1 && C.w == 1) { // scalar @@ -260,66 +435,226 @@ int Gemm::forward(const std::vector& bottom_blobs, std::vector& top_bl if (top_blob.empty()) return -100; - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < M; i++) - { - const int out_hstep = top_blob.dims == 3 ? (int)top_blob.cstep : top_blob.w; + gemm_transB(A, BT, C, top_blob, alpha, beta, broadcast_type_C, output_transpose, opt); - const int A_hstep = A.dims == 3 ? (int)A.cstep : A.w; - const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; + return 0; +} - const float* ptrA = (const float*)A + i * A_hstep; +#if NCNN_INT8 +int Gemm::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& A0 = constantA ? A_data : bottom_blobs[0]; + const Mat& B0 = constantB ? B_data : constantA ? bottom_blobs[0] : bottom_blobs[1]; - for (int j = 0; j < N; j++) + Mat A; + if (transA == 0) + { + A = A0; + } + else + { + // transpose A to row-major + if (A0.elemsize == 1) { - const float* ptrB = (const float*)B + j * B_hstep; + A.create(A0.h, A0.w, (size_t)1u, 1, opt.workspace_allocator); - float sum = 0.f; - if (ptrC) + for (int i = 0; i < A.h; i++) { - if (broadcast_type_C == 0) - { - sum = ptrC[0]; - } - if (broadcast_type_C == 1) - { - sum = ptrC[i]; - } - if (broadcast_type_C == 2) + signed char* ptr = A.row(i); + for (int j = 0; j < A.w; j++) { - sum = ptrC[i]; + ptr[j] = A0.row(j)[i]; } - if (broadcast_type_C == 3) - { - sum = ptrC[i * N + j]; - } - if (broadcast_type_C == 4) + } + } + else + { + A.create(A0.dims == 3 ? A0.c : A0.h, A0.w, (size_t)4u, 1, opt.workspace_allocator); + + for (int i = 0; i < A.h; i++) + { + float* ptr = A.row(i); + for (int j = 0; j < A.w; j++) { - sum = ptrC[j]; + ptr[j] = A0.dims == 3 ? A0.channel(j)[i] : A0.row(j)[i]; } + } + } + } - sum *= beta; + // dynamic quantize A + Mat A_int8 = A; + Mat A_int8_scales = A_data_int8_scales; + if (A_int8.elemsize != 1) + { + A_int8.create(A.w, A.dims == 3 ? A.c : A.h, (size_t)1u, 1, opt.workspace_allocator); + A_int8_scales.create(A_int8.h, (size_t)4u, 1, opt.workspace_allocator); + + for (int i = 0; i < A_int8.h; i++) + { + const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; + const float* ptr = (const float*)A + i * A_hstep; + + float absmax = 0.f; + for (int k = 0; k < A_int8.w; k++) + { + absmax = std::max(absmax, (float)fabs(ptr[k])); } - for (int k = 0; k < K; k++) + // NCNN_LOGE("A[%d] absmax %f", i, absmax); + + float A_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax; + A_int8_scales[i] = A_int8_scale; + + signed char* ptrAi = A_int8.row(i); + + for (int k = 0; k < A_int8.w; k++) { - sum += ptrA[k] * ptrB[k]; + ptrAi[k] = float2int8(ptr[k] * A_int8_scale); } + } + } - sum *= alpha; + // dynamic quantize B + Mat B0_int8 = B0; + float B_int8_scale = B_data_int8_scale; + if (B0_int8.elemsize != 1) + { + B0_int8.create(B0.w, B0.dims == 3 ? B0.c : B0.h, (size_t)1u, 1, opt.workspace_allocator); - if (output_transpose) + float absmax = 0.f; + for (int i = 0; i < B0_int8.h; i++) + { + const int B_hstep = B0.dims == 3 ? (int)B0.cstep : B0.w; + const float* ptr = (const float*)B0 + i * B_hstep; + + for (int k = 0; k < B0_int8.w; k++) { - top_blob[j * out_hstep + i] = sum; + absmax = std::max(absmax, (float)fabs(ptr[k])); } - else + } + + // NCNN_LOGE("B0 absmax %f", absmax); + + B_int8_scale = absmax == 0.f ? 1.f : 127.f / absmax; + + for (int i = 0; i < B0_int8.h; i++) + { + const int B_hstep = B0.dims == 3 ? (int)B0.cstep : B0.w; + const float* ptr = (const float*)B0 + i * B_hstep; + + signed char* ptrBi = B0_int8.row(i); + + for (int k = 0; k < B0_int8.w; k++) { - top_blob[i * out_hstep + j] = sum; + ptrBi[k] = float2int8(ptr[k] * B_int8_scale); } } } + Mat BT_int8; + if (transB == 0) + { + // transpose B to col-major + BT_int8.create(B0_int8.h, B0_int8.w, (size_t)1u, 1, opt.workspace_allocator); + + for (int i = 0; i < BT_int8.h; i++) + { + signed char* ptr = BT_int8.row(i); + for (int j = 0; j < BT_int8.w; j++) + { + ptr[j] = B0_int8.row(j)[i]; + } + } + } + else + { + BT_int8 = B0_int8; + } + + const int M = A_int8.h; + const int N = BT_int8.h; + + Mat C; + int broadcast_type_C = 0; + if (constantC) + { + C = C_data; + broadcast_type_C = constant_broadcast_type_C; + } + else + { + if (constantA && constantB && bottom_blobs.size() == 1) + { + C = bottom_blobs[0]; + } + else if ((constantA || constantB) && bottom_blobs.size() == 2) + { + C = bottom_blobs[1]; + } + else if (bottom_blobs.size() == 3) + { + C = bottom_blobs[2]; + } + + if (!C.empty()) + { + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } + } + } + + Mat& top_blob = top_blobs[0]; + if (output_transpose) + { + if (output_N1M) + top_blob.create(M, 1, N, 4u, opt.blob_allocator); + else + top_blob.create(M, N, 4u, opt.blob_allocator); + } + else + { + if (output_N1M) + top_blob.create(N, 1, M, 4u, opt.blob_allocator); + else + top_blob.create(N, M, 4u, opt.blob_allocator); + } + if (top_blob.empty()) + return -100; + + gemm_transB_int8(A_int8, BT_int8, A_int8_scales, B_int8_scale, C, top_blob, alpha, beta, broadcast_type_C, output_transpose, opt); + return 0; } +#endif // NCNN_INT8 } // namespace ncnn diff --git a/src/layer/gemm.h b/src/layer/gemm.h index e006114a1498..8408772c12c0 100644 --- 
a/src/layer/gemm.h +++ b/src/layer/gemm.h @@ -32,6 +32,11 @@ class Gemm : public Layer virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +protected: +#if NCNN_INT8 + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif + public: float alpha; float beta; @@ -50,6 +55,8 @@ class Gemm : public Layer int output_elemtype; // 0=auto 1=fp32 int output_transpose; + int int8_scale_term; + int constant_TILE_M; int constant_TILE_N; int constant_TILE_K; @@ -58,6 +65,11 @@ class Gemm : public Layer Mat A_data; Mat B_data; Mat C_data; + +#if NCNN_INT8 + Mat A_data_int8_scales; + float B_data_int8_scale; +#endif }; } // namespace ncnn diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp new file mode 100644 index 000000000000..295625f12ad1 --- /dev/null +++ b/tests/test_gemm_3.cpp @@ -0,0 +1,316 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M = 0) +{ + alpha = 1.f;//TODO + + // transA = 0;//TODO FIXME HACK + // transB = 1;//TODO FIXME HACK + // constantA = 0;//TODO FIXME HACK + // constantB = 0;//TODO FIXME HACK + // output_transpose = 0;//TODO FIXME HACK + + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, 1.f); // beta + pd.set(2, transA); + pd.set(3, transB); + pd.set(4, constantA); + pd.set(5, constantB); + pd.set(6, 1); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, -1); + pd.set(11, output_N1M); + pd.set(13, output_elemtype); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + std::vector weights; + if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M))); + if (constantB) weights.push_back(transB ? (output_N1M ? RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K))); + if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); + if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); + + std::vector a; + if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); + if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? 
ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i], -10.f, 10.f); + } + + // fprintf(stderr, "test_gemm_int8 M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M); + + int ret = test_layer("Gemm", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M); + } + + return ret; +} + +static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int constantC) +{ + alpha = 1.f;//TODO + beta = 1.f;//TODO + + // transA = 0;//TODO FIXME HACK + // transB = 1;//TODO FIXME HACK + // constantA = 0;//TODO FIXME HACK + // constantB = 0;//TODO FIXME HACK + // constantC = 0;//TODO FIXME HACK + // output_transpose = 0;//TODO FIXME HACK + + int broadcast_type_C = 0; + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } + + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, beta); + pd.set(2, transA); + pd.set(3, transB); + pd.set(4, constantA); + pd.set(5, constantB); + pd.set(6, constantC); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, broadcast_type_C); + // pd.set(12, 1); // output_elempack + pd.set(13, output_elemtype); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + std::vector weights; + if (constantA) weights.push_back(transA ? RandomS8Mat(M, K) : RandomS8Mat(K, M)); + if (constantB) weights.push_back(transB ? RandomS8Mat(K, N) : RandomS8Mat(N, K)); + if (constantC) weights.push_back(C); + if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); + if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); + + std::vector a; + if (!constantA) a.push_back(transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M)); + if (!constantB) a.push_back(transB ? 
ncnn::Mat(K, N) : ncnn::Mat(N, K)); + if (!constantC) a.push_back(C); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i], -10.f, 10.f); + } + + // fprintf(stderr, "test_gemm_int8_bias M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_elemtype, output_transpose, constantA, constantB, constantC); + int ret = test_layer("Gemm", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8_bias failed M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_elemtype, output_transpose, constantA, constantB, constantC); + } + + return ret; +} + +static int test_gemm_0(int M, int N, int K) +{ + // return 0 + // || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0) + // || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 0, 0, 0) + // || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 0, 0, 0) + // || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 0, 0, 0) + // + // || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0, 1) + // || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 0, 0, 0, 1) + // || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 0, 0, 0, 1) + // || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 0, 0, 0, 1) + // + // || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 1, 0, 0) + // || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 1, 0, 0) + // || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 1, 0, 0) + // || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 1, 0, 0) + // + // || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 1, 0, 0, 1) + // || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 1, 0, 0, 1) + // || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 1, 0, 0, 1) + // || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 1, 0, 0, 1); + + + return 0 + || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0) + || test_gemm_int8(M, N, K, 3.1f, 1, 1, 0, 1, 1, 0) + || test_gemm_int8(M, N, K, 4.1f, 0, 0, 1, 0, 0, 1) + || test_gemm_int8(M, N, K, 5.1f, 1, 0, 1, 1, 1, 1) + || test_gemm_int8(M, N, K, 2.1f, 0, 1, 2, 0, 1, 0) + || test_gemm_int8(M, N, K, 3.1f, 1, 1, 2, 1, 0, 1) + || test_gemm_int8(M, N, K, 4.1f, 0, 0, 3, 0, 1, 1) + || test_gemm_int8(M, N, K, 5.1f, 1, 0, 3, 1, 0, 0) + + || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0, 1) + || test_gemm_int8(M, N, K, 3.1f, 1, 1, 0, 1, 1, 0, 1) + || test_gemm_int8(M, N, K, 4.1f, 0, 0, 1, 0, 0, 1, 1) + || test_gemm_int8(M, N, K, 5.1f, 1, 0, 1, 1, 1, 1, 1) + || test_gemm_int8(M, N, K, 2.1f, 0, 1, 2, 0, 1, 0, 1) + || test_gemm_int8(M, N, K, 3.1f, 1, 1, 2, 1, 0, 1, 1) + || test_gemm_int8(M, N, K, 4.1f, 0, 0, 3, 0, 1, 1, 1) + || test_gemm_int8(M, N, K, 5.1f, 1, 0, 3, 1, 0, 0, 1); +} + +static int test_gemm_1(int M, int N, int K) +{ + return 0 + || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 1, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 2, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 3, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 1, 0, 0, 0, 0) + + || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 3, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, 
RandomMat(1, M), 4.1f, 0.7f, 1, 0, 0, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 3, 0, 1, 1, 1); +} + +int main() +{ + SRAND(7767517); + + int mnk[][3] = { + {1, 1, 1}, + {2, 2, 2}, + {3, 3, 3}, + {4, 4, 4}, + {5, 5, 5}, + {6, 6, 6}, + {7, 7, 7}, + {8, 8, 8}, + {15, 15, 15}, + {16, 16, 16}, + {31, 31, 31}, + {40, 40, 40}, + {1, 1, 23}, + {1, 31, 1}, + {23, 1, 1}, + {12, 12, 23}, + {12, 31, 12}, + {23, 12, 12}, + {1, 1, 47}, + {1, 35, 1}, + {47, 1, 1}, + {24, 24, 47}, + {24, 35, 24}, + {47, 24, 24}, + {1, 35, 47}, + {23, 31, 1}, + {23, 1, 23}, + {23, 31, 23}, + {31, 7, 3}, + {28, 20, 7}, + {32, 32, 9}, + {44, 19, 7}, + {47, 35, 48}, + {47, 48, 47}, + {48, 35, 47} + }; + + int mnk_count = sizeof(mnk) / sizeof(int) / 3; + + for (int i = 0; i < mnk_count; i++) + { + int M = mnk[i][0]; + int N = mnk[i][1]; + int K = mnk[i][2]; + + int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); + + // int ret = test_gemm_0(M, N, K); + + // ncnn::Mat C(N, M); + // Randomize(C, -1000.f, 1000.f); + // + // fprintf(stderr, "C %f\n", C[0]); + // + // int ret = test_gemm_int8_bias(M, N, K, C, 1.f, 1.f, 0, 0, 0, 0, 0, 0, 1); + + if (ret != 0) + return 0; + } + + for (int M = 1; M <= 15; M++) + { + for (int N = 1; N <= 15; N++) + { + for (int K = 1; K <= 15; K++) + { + // int ret = 0 + // || test_gemm_0(M, N, K) + // || test_gemm_1(M, N, K); + + // int ret = test_gemm_0(M, N, K); + + ncnn::Mat C(N, M); + Randomize(C, -100.f, 100.f); + + // fprintf(stderr, "C %f\n", C[0]); + + int ret = test_gemm_int8_bias(M, N, K, C, 1.f, 1.f, 0, 0, 0, 0, 0, 0, 1); + + if (ret != 0) + return 0; + + ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); + if (ret != 0) + return 0; + } + } + } + + return 0; +} From ab9f5553422fae025ab1cf8df240bc039b2ee352 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 19 Sep 2024 11:12:44 +0000 Subject: [PATCH 02/55] apply code-format changes --- src/layer/arm/gemm_int8.h | 291 ++++++++++++++++---------------- src/layer/arm/gemm_int8_bf16s.h | 242 +++++++++++++------------- tests/test_gemm_3.cpp | 49 +++--- 3 files changed, 288 insertions(+), 294 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 9ccd880ea62c..21de12253bff 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -162,7 +162,7 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in int8x16_t _r1 = vcombine_s8(_p2, _p3); int8x16_t _r2 = vcombine_s8(_p4, _p5); int8x16_t _r3 = vcombine_s8(_p6, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x2x2_t _p01 = vzip_s32(vreinterpret_s32_s8(_p0), vreinterpret_s32_s8(_p1)); int32x2x2_t _p23 = vzip_s32(vreinterpret_s32_s8(_p2), vreinterpret_s32_s8(_p3)); int32x2x2_t _p45 = vzip_s32(vreinterpret_s32_s8(_p4), vreinterpret_s32_s8(_p5)); @@ -172,7 +172,7 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[1], _p23.val[1])); int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[1], _p67.val[1])); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8_t _p04 = vreinterpretq_s16_s8(vcombine_s8(_p0, _p4)); int16x8_t _p15 = vreinterpretq_s16_s8(vcombine_s8(_p1, _p5)); int16x8_t _p26 = vreinterpretq_s16_s8(vcombine_s8(_p2, _p6)); @@ -235,7 +235,7 
@@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in pp[29] = p7[1]; pp[30] = p7[2]; pp[31] = p7[3]; -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -382,7 +382,7 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in #if __ARM_FEATURE_MATMUL_INT8 vst1q_s8(pp, vcombine_s8(_p0, _p1)); vst1q_s8(pp + 16, vcombine_s8(_p2, _p3)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x2x4_t _r0123; _r0123.val[0] = vreinterpret_s32_s8(_p0); _r0123.val[1] = vreinterpret_s32_s8(_p1); @@ -390,7 +390,7 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in _r0123.val[3] = vreinterpret_s32_s8(_p3); vst4_s32((int*)pp, _r0123); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4x4_t _r0123; _r0123.val[0] = vreinterpret_s16_s8(_p0); _r0123.val[1] = vreinterpret_s16_s8(_p1); @@ -423,7 +423,7 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in pp[13] = p3[1]; pp[14] = p3[2]; pp[15] = p3[3]; -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -494,13 +494,13 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in _r01.val[0] = vreinterpretq_s64_s8(_p0); _r01.val[1] = vreinterpretq_s64_s8(_p1); vst2q_s64((int64_t*)pp, _r01); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x4x2_t _r01; _r01.val[0] = vreinterpretq_s32_s8(_p0); _r01.val[1] = vreinterpretq_s32_s8(_p1); vst2q_s32((int*)pp, _r01); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8x2_t _r01; _r01.val[0] = vreinterpretq_s16_s8(_p0); _r01.val[1] = vreinterpretq_s16_s8(_p1); @@ -517,13 +517,13 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 vst1q_s8(pp, vcombine_s8(_p0, _p1)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x2x2_t _r01; _r01.val[0] = vreinterpret_s32_s8(_p0); _r01.val[1] = vreinterpret_s32_s8(_p1); vst2_s32((int*)pp, _r01); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4x2_t _r01; _r01.val[0] = vreinterpret_s16_s8(_p0); _r01.val[1] = vreinterpret_s16_s8(_p1); @@ -544,7 +544,7 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in pp[5] = p1[1]; pp[6] = p1[2]; pp[7] = p1[3]; -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -999,7 +999,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in int8x16_t _r1 = vcombine_s8(_p2, _p3); int8x16_t _r2 = vcombine_s8(_p4, _p5); int8x16_t _r3 = vcombine_s8(_p6, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x2x2_t _p01 = vzip_s32(vreinterpret_s32_s8(_p0), vreinterpret_s32_s8(_p1)); int32x2x2_t _p23 = vzip_s32(vreinterpret_s32_s8(_p2), vreinterpret_s32_s8(_p3)); int32x2x2_t _p45 = vzip_s32(vreinterpret_s32_s8(_p4), vreinterpret_s32_s8(_p5)); @@ -1009,7 +1009,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in int8x16_t _r2 = vreinterpretq_s8_s32(vcombine_s32(_p01.val[1], _p23.val[1])); int8x16_t _r3 = vreinterpretq_s8_s32(vcombine_s32(_p45.val[1], _p67.val[1])); #endif // __ARM_FEATURE_MATMUL_INT8 -#else 
// __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8_t _p04 = vreinterpretq_s16_s8(vcombine_s8(_p0, _p4)); int16x8_t _p15 = vreinterpretq_s16_s8(vcombine_s8(_p1, _p5)); int16x8_t _p26 = vreinterpretq_s16_s8(vcombine_s8(_p2, _p6)); @@ -1072,7 +1072,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in pp[29] = p7[1]; pp[30] = p7[2]; pp[31] = p7[3]; -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -1220,7 +1220,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in #if __ARM_FEATURE_MATMUL_INT8 vst1q_s8(pp, vcombine_s8(_p0, _p1)); vst1q_s8(pp + 16, vcombine_s8(_p2, _p3)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x2x4_t _r0123; _r0123.val[0] = vreinterpret_s32_s8(_p0); _r0123.val[1] = vreinterpret_s32_s8(_p1); @@ -1228,7 +1228,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in _r0123.val[3] = vreinterpret_s32_s8(_p3); vst4_s32((int*)pp, _r0123); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4x4_t _r0123; _r0123.val[0] = vreinterpret_s16_s8(_p0); _r0123.val[1] = vreinterpret_s16_s8(_p1); @@ -1261,7 +1261,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in pp[13] = p3[1]; pp[14] = p3[2]; pp[15] = p3[3]; -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -1332,13 +1332,13 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in _r01.val[0] = vreinterpretq_s64_s8(_p0); _r01.val[1] = vreinterpretq_s64_s8(_p1); vst2q_s64((int64_t*)pp, _r01); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x4x2_t _r01; _r01.val[0] = vreinterpretq_s32_s8(_p0); _r01.val[1] = vreinterpretq_s32_s8(_p1); vst2q_s32((int*)pp, _r01); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8x2_t _r01; _r01.val[0] = vreinterpretq_s16_s8(_p0); _r01.val[1] = vreinterpretq_s16_s8(_p1); @@ -1355,13 +1355,13 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 vst1q_s8(pp, vcombine_s8(_p0, _p1)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int32x2x2_t _r01; _r01.val[0] = vreinterpret_s32_s8(_p0); _r01.val[1] = vreinterpret_s32_s8(_p1); vst2_s32((int*)pp, _r01); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4x2_t _r01; _r01.val[0] = vreinterpret_s16_s8(_p0); _r01.val[1] = vreinterpret_s16_s8(_p1); @@ -1382,7 +1382,7 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in pp[5] = p1[1]; pp[6] = p1[2]; pp[7] = p1[3]; -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD pp[0] = p0[0]; pp[1] = p0[1]; pp[2] = p1[0]; @@ -1937,7 +1937,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r5 = float2int8(_p9, _pd); int8x8_t _r6 = float2int8(_pa, _pe); int8x8_t _r7 = float2int8(_pb, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p8, _p9); @@ -1952,7 +1952,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, 
vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -2023,7 +2023,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -2138,7 +2138,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r5 = float2int8(_pa, _pb); int8x8_t _r6 = float2int8(_pc, _pd); int8x8_t _r7 = float2int8(_pe, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p8, _pa); @@ -2148,7 +2148,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r6 = float2int8(_p9, _pb); int8x8_t _r7 = float2int8(_pd, _pf); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); @@ -2204,7 +2204,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -2309,7 +2309,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -2318,7 +2318,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -2361,7 +2361,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r0, _r1)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -2443,13 +2443,13 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p1, _p3); int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = 
vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); @@ -2483,7 +2483,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -2569,11 +2569,11 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); @@ -2598,7 +2598,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r0 = float2int8(_t0, _t1); @@ -2618,16 +2618,16 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 4; p0 += 2; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(p0[0] * scale0); - // pp[1] = float2int8(p0[1] * scale0); - // pp[2] = float2int8(p0[A_hstep] * scale1); - // pp[3] = float2int8(p0[A_hstep + 1] * scale1); - // pp += 4; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(p0[0] * scale0); \ +// pp[1] = float2int8(p0[1] * scale0); \ +// pp[2] = float2int8(p0[A_hstep] * scale1); \ +// pp[3] = float2int8(p0[A_hstep + 1] * scale1); \ +// pp += 4; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale0); @@ -2683,14 +2683,14 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 8; p0 += 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(p0[0] * scale); - // pp[1] = float2int8(p0[1] * scale); - // pp += 2; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(p0[0] * scale); \ +// pp[1] = float2int8(p0[1] * scale); \ +// pp += 2; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -3080,7 +3080,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -3095,7 +3095,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 
48, vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -3149,7 +3149,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_DOTPROD vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); int16x8x2_t _rr = vuzpq_s16(_r01, _r23); @@ -3223,7 +3223,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); vst4q_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = _r0; _r0123.val[1] = _r1; @@ -3238,7 +3238,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int vst4_s8(pp, _r0123); vst4_s8(pp + 32, _r4567); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(_r0, _r2); _r01.val[1] = vcombine_s8(_r1, _r3); @@ -3281,7 +3281,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int _r0123.val[3] = float2int8(_p6, _p7); vst4_s8(pp, _r0123); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); @@ -3365,13 +3365,13 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -3405,7 +3405,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -3455,7 +3455,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); vst4_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = float2int8(_p0, _p4); _r0123.val[1] = float2int8(_p1, _p5); @@ -3464,7 +3464,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int vst4_s8(pp, _r0123); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); _r01.val[1] = vcombine_s8(float2int8(_p1, 
_p3), float2int8(_p5, _p7)); @@ -3494,7 +3494,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int8x8_t _r23 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r01, _r23)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8x2_t _r01; _r01.val[0] = float2int8(_p0, _p2); _r01.val[1] = float2int8(_p1, _p3); @@ -3567,11 +3567,11 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); int16x4x2_t _t01 = vzip_s16(_t0, _t1); @@ -3594,7 +3594,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_DOTPROD int8x8_t _r01 = float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r01 = float2int8(_t0, _t1); @@ -3643,13 +3643,13 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int int8x8x2_t _r01 = vuzp_s8(_r0, _r1); vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x2_t _r01 = vtrn_s8(_r0, _r1); int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p02 = vcombine_f32(_p0, _p2); float32x4_t _p46 = vcombine_f32(_p4, _p6); float32x4_t _p13 = vcombine_f32(_p1, _p3); @@ -3686,7 +3686,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int float32x4x2_t _pp = vuzpq_f32(_p01, _p23); int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p02 = vcombine_f32(_p0, _p2); float32x4_t _p13 = vcombine_f32(_p1, _p3); @@ -3841,14 +3841,14 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 8; p0 += A_hstep * 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(p0[0] * scale); - // pp[1] = float2int8(p0[A_hstep] * scale); - // pp += 2; - // p0 += A_hstep * 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(p0[0] * scale); \ +// pp[1] = float2int8(p0[A_hstep] * scale); \ +// pp += 2; \ +// p0 += A_hstep * 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -3932,7 +3932,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r5 = float2int8(_p9, _pd); int8x8_t _r6 = float2int8(_pa, _pe); int8x8_t _r7 = float2int8(_pb, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p8, _p9); @@ -3947,7 +3947,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, vcombine_s8(_r4, 
_r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -4018,7 +4018,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -4133,7 +4133,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r5 = float2int8(_pa, _pb); int8x8_t _r6 = float2int8(_pc, _pd); int8x8_t _r7 = float2int8(_pe, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p8, _pa); @@ -4143,7 +4143,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r6 = float2int8(_p9, _pb); int8x8_t _r7 = float2int8(_pd, _pf); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); @@ -4199,7 +4199,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -4300,7 +4300,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -4309,7 +4309,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -4352,7 +4352,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r0, _r1)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p0 = vld1q_f32(p0); float32x4_t _p1 = vld1q_f32(p0 + 4); float32x4_t _p2 = vld1q_f32(p0 + 8); @@ -4434,13 +4434,13 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p1, _p3); int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = 
vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); @@ -4474,7 +4474,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -4553,11 +4553,11 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); @@ -4582,7 +4582,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r0 = float2int8(_t0, _t1); @@ -4602,16 +4602,16 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 4; p0 += 2; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(p0[0] * scale); - // pp[1] = float2int8(p0[1] * scale); - // pp[2] = float2int8(p0[B_hstep] * scale); - // pp[3] = float2int8(p0[B_hstep + 1] * scale); - // pp += 4; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(p0[0] * scale); \ +// pp[1] = float2int8(p0[1] * scale); \ +// pp[2] = float2int8(p0[B_hstep] * scale); \ +// pp[3] = float2int8(p0[B_hstep + 1] * scale); \ +// pp += 4; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -4664,14 +4664,14 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 8; p0 += 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(p0[0] * scale); - // pp[1] = float2int8(p0[1] * scale); - // pp += 2; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(p0[0] * scale); \ +// pp[1] = float2int8(p0[1] * scale); \ +// pp += 2; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -4772,7 +4772,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -4787,7 +4787,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, 
vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -4841,7 +4841,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_DOTPROD vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); int16x8x2_t _rr = vuzpq_s16(_r01, _r23); @@ -4915,7 +4915,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); vst4q_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = _r0; _r0123.val[1] = _r1; @@ -4930,7 +4930,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int vst4_s8(pp, _r0123); vst4_s8(pp + 32, _r4567); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(_r0, _r2); _r01.val[1] = vcombine_s8(_r1, _r3); @@ -4973,7 +4973,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int _r0123.val[3] = float2int8(_p6, _p7); vst4_s8(pp, _r0123); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); @@ -5056,13 +5056,13 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -5096,7 +5096,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -5146,7 +5146,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); vst4_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = float2int8(_p0, _p4); _r0123.val[1] = float2int8(_p1, _p5); @@ -5155,7 +5155,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int vst4_s8(pp, _r0123); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); _r01.val[1] = vcombine_s8(float2int8(_p1, 
_p3), float2int8(_p5, _p7)); @@ -5184,7 +5184,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int int8x8_t _r23 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r01, _r23)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8x2_t _r01; _r01.val[0] = float2int8(_p0, _p2); _r01.val[1] = float2int8(_p1, _p3); @@ -5247,11 +5247,11 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); int16x4x2_t _t01 = vzip_s16(_t0, _t1); @@ -5274,7 +5274,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_DOTPROD int8x8_t _r01 = float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r01 = float2int8(_t0, _t1); @@ -5320,13 +5320,13 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int int8x8x2_t _r01 = vuzp_s8(_r0, _r1); vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x2_t _r01 = vtrn_s8(_r0, _r1); int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p02 = vcombine_f32(_p0, _p2); float32x4_t _p46 = vcombine_f32(_p4, _p6); float32x4_t _p13 = vcombine_f32(_p1, _p3); @@ -5363,7 +5363,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int float32x4x2_t _pp = vuzpq_f32(_p01, _p23); int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _p02 = vcombine_f32(_p0, _p2); float32x4_t _p13 = vcombine_f32(_p1, _p3); @@ -5515,14 +5515,14 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 8; p0 += B_hstep * 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(p0[0] * scale); - // pp[1] = float2int8(p0[B_hstep] * scale); - // pp += 2; - // p0 += B_hstep * 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(p0[0] * scale); \ +// pp[1] = float2int8(p0[B_hstep] * scale); \ +// pp += 2; \ +// p0 += B_hstep * 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -8039,7 +8039,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - // if (out_elempack == 1) { int jj = 0; @@ -8285,7 +8284,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - // if (out_elempack == 1) { int jj = 0; @@ -8580,7 +8578,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); } -#else // 
__ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD // from // a0 b1 c2 d3 @@ -8878,7 +8876,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); } -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD // from // a0 b1 c2 d3 @@ -8955,9 +8953,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f5 = vaddq_f32(_f5, _c0); _f6 = vaddq_f32(_f6, _c0); _f7 = vaddq_f32(_f7, _c0); - - - } if (broadcast_type_C == 1 || broadcast_type_C == 2) { @@ -9820,7 +9815,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); } -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD // from // a0 b1 c2 d3 diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index dac2a901a5d1..8f5b0da86967 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -255,7 +255,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r5 = float2int8(_p9, _pd); int8x8_t _r6 = float2int8(_pa, _pe); int8x8_t _r7 = float2int8(_pb, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p8, _p9); @@ -270,7 +270,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); uint16x8_t _r = vld1q_u16(p0 + 16); @@ -349,7 +349,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); uint16x8_t _r = vld1q_u16(p0 + A_hstep * 4); @@ -479,7 +479,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r5 = float2int8(_pa, _pb); int8x8_t _r6 = float2int8(_pc, _pd); int8x8_t _r7 = float2int8(_pe, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p8, _pa); @@ -489,7 +489,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r6 = float2int8(_p9, _pb); int8x8_t _r7 = float2int8(_pd, _pf); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p8, _pa)); @@ -545,7 +545,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = 
vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -659,7 +659,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -668,7 +668,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); uint16x8_t _r = vld1q_u16(p0 + 16); @@ -715,7 +715,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r0, _r1)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); @@ -804,13 +804,13 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p1, _p3); int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); @@ -844,7 +844,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -937,11 +937,11 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); @@ -966,7 +966,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r0 = float2int8(_t0, _t1); @@ -986,16 +986,16 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 4; p0 += 2; } -#endif // __ARM_NEON - // 
for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); - // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale0); - // pp[2] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); - // pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); - // pp += 4; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); \ +// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale0); \ +// pp[2] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); \ +// pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); \ +// pp += 4; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); @@ -1054,14 +1054,14 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 8; p0 += 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); - // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); - // pp += 2; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ +// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); \ +// pp += 2; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -1395,7 +1395,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -1410,7 +1410,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -1468,7 +1468,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_DOTPROD vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); int16x8x2_t _rr = vuzpq_s16(_r01, _r23); @@ -1550,7 +1550,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); vst4q_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = _r0; _r0123.val[1] = _r1; @@ -1565,7 +1565,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int vst4_s8(pp, _r0123); vst4_s8(pp + 32, _r4567); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(_r0, _r2); _r01.val[1] = vcombine_s8(_r1, _r3); @@ -1612,7 +1612,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int _r0123.val[3] = float2int8(_p6, _p7); vst4_s8(pp, _r0123); -#else // 
__ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); @@ -1704,13 +1704,13 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -1746,7 +1746,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -1796,7 +1796,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); vst4_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = float2int8(_p0, _p4); _r0123.val[1] = float2int8(_p1, _p5); @@ -1805,7 +1805,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int vst4_s8(pp, _r0123); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); @@ -1835,7 +1835,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int int8x8_t _r23 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r01, _r23)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8x2_t _r01; _r01.val[0] = float2int8(_p0, _p2); _r01.val[1] = float2int8(_p1, _p3); @@ -1910,11 +1910,11 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); int16x4x2_t _t01 = vzip_s16(_t0, _t1); @@ -1938,7 +1938,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_FEATURE_DOTPROD int8x8_t _r01 = float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r01 = float2int8(_t0, _t1); @@ -1994,13 +1994,13 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int int8x8x2_t 
_r01 = vuzp_s8(_r0, _r1); vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x2_t _r01 = vtrn_s8(_r0, _r1); int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = uint16x8_t(); _p = vsetq_lane_u16(p0[0], _p, 0); _p = vsetq_lane_u16(p0[1], _p, 1); @@ -2059,7 +2059,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int float32x4x2_t _pp = vuzpq_f32(_p01, _p23); int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = uint16x8_t(); _p = vsetq_lane_u16(p0[0], _p, 0); _p = vsetq_lane_u16(p0[1], _p, 1); @@ -2226,14 +2226,14 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int pp += 8; p0 += A_hstep * 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); - // pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale); - // pp += 2; - // p0 += A_hstep * 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ +// pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale); \ +// pp += 2; \ +// p0 += A_hstep * 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -2315,7 +2315,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r5 = float2int8(_p9, _pd); int8x8_t _r6 = float2int8(_pa, _pe); int8x8_t _r7 = float2int8(_pb, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p8, _p9); @@ -2330,7 +2330,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); uint16x8_t _r = vld1q_u16(p0 + 16); @@ -2409,7 +2409,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); uint16x8_t _r = vld1q_u16(p0 + B_hstep * 4); @@ -2538,7 +2538,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r5 = float2int8(_pa, _pb); int8x8_t _r6 = float2int8(_pc, _pd); int8x8_t _r7 = float2int8(_pe, _pf); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p8, _pa); @@ -2548,7 +2548,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r6 = float2int8(_p9, _pb); int8x8_t _r7 = float2int8(_pd, _pf); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = 
vreinterpret_s16_s8(float2int8(_p8, _pa)); @@ -2604,7 +2604,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -2714,7 +2714,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -2723,7 +2723,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); uint16x8_t _r = vld1q_u16(p0 + 16); @@ -2770,7 +2770,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r0, _r1)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = vld1q_u16(p0); uint16x8_t _q = vld1q_u16(p0 + 8); float32x4_t _p0 = bfloat2float(vget_low_u16(_p)); @@ -2859,13 +2859,13 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p4, _p6); int8x8_t _r2 = float2int8(_p1, _p3); int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); @@ -2899,7 +2899,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -2985,11 +2985,11 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p2)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p2)); float32x4_t _t2 = vcombine_f32(vget_low_f32(_p1), vget_low_f32(_p3)); @@ -3014,7 +3014,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = 
float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r0 = float2int8(_t0, _t1); @@ -3034,16 +3034,16 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 4; p0 += 2; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); - // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); - // pp[2] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); - // pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); - // pp += 4; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ +// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); \ +// pp[2] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); \ +// pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); \ +// pp += 4; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -3099,14 +3099,14 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 8; p0 += 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); - // pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); - // pp += 2; - // p0 += 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ +// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); \ +// pp += 2; \ +// p0 += 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -3215,7 +3215,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -3230,7 +3230,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); @@ -3288,7 +3288,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_DOTPROD vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); int16x8x2_t _rr = vuzpq_s16(_r01, _r23); @@ -3370,7 +3370,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int _r0123.val[3] = vcombine_s8(_r37.val[0], _r37.val[1]); vst4q_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = _r0; _r0123.val[1] = _r1; @@ -3385,7 +3385,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int vst4_s8(pp, _r0123); 
vst4_s8(pp + 32, _r4567); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(_r0, _r2); _r01.val[1] = vcombine_s8(_r1, _r3); @@ -3432,7 +3432,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int _r0123.val[3] = float2int8(_p6, _p7); vst4_s8(pp, _r0123); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p1), float2int8(_p4, _p5)); _r01.val[1] = vcombine_s8(float2int8(_p2, _p3), float2int8(_p6, _p7)); @@ -3522,13 +3522,13 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int int8x8_t _r1 = float2int8(_p1, _p5); int8x8_t _r2 = float2int8(_p2, _p6); int8x8_t _r3 = float2int8(_p3, _p7); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p4, _p5)); @@ -3564,7 +3564,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p1)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p2, _p3)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); @@ -3614,7 +3614,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int _r0123.val[3] = float2int8(_p37.val[0], _p37.val[1]); vst4_s8(pp, _r0123); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x4_t _r0123; _r0123.val[0] = float2int8(_p0, _p4); _r0123.val[1] = float2int8(_p1, _p5); @@ -3623,7 +3623,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int vst4_s8(pp, _r0123); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x16x2_t _r01; _r01.val[0] = vcombine_s8(float2int8(_p0, _p2), float2int8(_p4, _p6)); _r01.val[1] = vcombine_s8(float2int8(_p1, _p3), float2int8(_p5, _p7)); @@ -3652,7 +3652,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int int8x8_t _r23 = float2int8(_p2, _p3); vst1q_s8(pp, vcombine_s8(_r01, _r23)); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int8x8x2_t _r01; _r01.val[0] = float2int8(_p0, _p2); _r01.val[1] = float2int8(_p1, _p3); @@ -3717,11 +3717,11 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p2); int8x8_t _r1 = float2int8(_p1, _p3); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p3)); int16x4x2_t _t01 = vzip_s16(_t0, _t1); @@ -3745,7 +3745,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int #if __ARM_FEATURE_DOTPROD int8x8_t 
_r01 = float2int8(_p0, _p1); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD float32x4_t _t0 = vcombine_f32(vget_low_f32(_p0), vget_low_f32(_p1)); float32x4_t _t1 = vcombine_f32(vget_high_f32(_p0), vget_high_f32(_p1)); int8x8_t _r01 = float2int8(_t0, _t1); @@ -3800,13 +3800,13 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int int8x8x2_t _r01 = vuzp_s8(_r0, _r1); vst1q_s8(pp, vcombine_s8(_r01.val[0], _r01.val[1])); -#else // __ARM_FEATURE_MATMUL_INT8 +#else // __ARM_FEATURE_MATMUL_INT8 int8x8x2_t _r01 = vtrn_s8(_r0, _r1); int8x8x2_t _rr01 = vuzp_s8(_r01.val[0], _r01.val[1]); vst1q_s8(pp, vcombine_s8(_rr01.val[0], _rr01.val[1])); #endif // __ARM_FEATURE_MATMUL_INT8 -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = uint16x8_t(); _p = vsetq_lane_u16(p0[0], _p, 0); _p = vsetq_lane_u16(p0[1], _p, 1); @@ -3865,7 +3865,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int float32x4x2_t _pp = vuzpq_f32(_p01, _p23); int8x8_t _r01 = float2int8(_pp.val[0], _pp.val[1]); -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD uint16x8_t _p = uint16x8_t(); _p = vsetq_lane_u16(p0[0], _p, 0); _p = vsetq_lane_u16(p0[1], _p, 1); @@ -4029,14 +4029,14 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int pp += 8; p0 += B_hstep * 8; } -#endif // __ARM_NEON - // for (; kk + 1 < max_kk; kk += 2) - // { - // pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); - // pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); - // pp += 2; - // p0 += B_hstep * 2; - // } +#endif // __ARM_NEON \ +// for (; kk + 1 < max_kk; kk += 2) \ +// { \ +// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ +// pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); \ +// pp += 2; \ +// p0 += B_hstep * 2; \ +// } for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -7185,7 +7185,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); } -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD // from // a0 b1 c2 d3 @@ -7474,7 +7474,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); } -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD // from // a0 b1 c2 d3 @@ -8429,7 +8429,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); } -#else // __ARM_FEATURE_DOTPROD +#else // __ARM_FEATURE_DOTPROD // from // a0 b1 c2 d3 diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index 295625f12ad1..11b78a263383 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -16,7 +16,7 @@ static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M = 0) { - alpha = 1.f;//TODO + alpha = 1.f; //TODO // transA = 0;//TODO FIXME HACK // transB = 1;//TODO FIXME HACK @@ -56,7 +56,7 @@ static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int tran 
Randomize(a[i], -10.f, 10.f); } - // fprintf(stderr, "test_gemm_int8 M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M); + // fprintf(stderr, "test_gemm_int8 M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M); int ret = test_layer("Gemm", pd, weights, a); if (ret != 0) @@ -69,8 +69,8 @@ static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int tran static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float alpha, float beta, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int constantC) { - alpha = 1.f;//TODO - beta = 1.f;//TODO + alpha = 1.f; //TODO + beta = 1.f; //TODO // transA = 0;//TODO FIXME HACK // transB = 1;//TODO FIXME HACK @@ -146,7 +146,7 @@ static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float al Randomize(a[i], -10.f, 10.f); } - // fprintf(stderr, "test_gemm_int8_bias M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_elemtype, output_transpose, constantA, constantB, constantC); + // fprintf(stderr, "test_gemm_int8_bias M=%d N=%d K=%d C.dims=%d C=(%d %d %d) alpha=%f beta=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d constantC=%d\n", M, N, K, C.dims, C.w, C.h, C.c, alpha, beta, transA, transB, output_elemtype, output_transpose, constantA, constantB, constantC); int ret = test_layer("Gemm", pd, weights, a); if (ret != 0) { @@ -179,7 +179,6 @@ static int test_gemm_0(int M, int N, int K) // || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 1, 0, 0, 1) // || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 1, 0, 0, 1); - return 0 || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0) || test_gemm_int8(M, N, K, 3.1f, 1, 1, 0, 1, 1, 0) @@ -285,31 +284,31 @@ int main() for (int M = 1; M <= 15; M++) { - for (int N = 1; N <= 15; N++) - { - for (int K = 1; K <= 15; K++) - { - // int ret = 0 - // || test_gemm_0(M, N, K) - // || test_gemm_1(M, N, K); + for (int N = 1; N <= 15; N++) + { + for (int K = 1; K <= 15; K++) + { + // int ret = 0 + // || test_gemm_0(M, N, K) + // || test_gemm_1(M, N, K); - // int ret = test_gemm_0(M, N, K); + // int ret = test_gemm_0(M, N, K); - ncnn::Mat C(N, M); - Randomize(C, -100.f, 100.f); + ncnn::Mat C(N, M); + Randomize(C, -100.f, 100.f); - // fprintf(stderr, "C %f\n", C[0]); + // fprintf(stderr, "C %f\n", C[0]); - int ret = test_gemm_int8_bias(M, N, K, C, 1.f, 1.f, 0, 0, 0, 0, 0, 0, 1); + int ret = test_gemm_int8_bias(M, N, K, C, 1.f, 1.f, 0, 0, 0, 0, 0, 0, 1); - if (ret != 0) - return 0; + if (ret != 0) + return 0; - ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); - if (ret != 0) - return 0; - } - } + ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); + if (ret != 0) + return 0; + } + } } return 0; From dd6bf5cb63ea4a87997ef0c11e0512aab2532946 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 20 Sep 2024 09:47:00 +0800 Subject: [PATCH 03/55] clean --- src/layer/arm/gemm_int8.h | 58 ++++----------------------------- src/layer/arm/gemm_int8_bf16s.h | 58 ++++----------------------------- 2 files changed, 12 
insertions(+), 104 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 21de12253bff..a532f7cf5981 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -2618,16 +2618,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 4; p0 += 2; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(p0[0] * scale0); \ -// pp[1] = float2int8(p0[1] * scale0); \ -// pp[2] = float2int8(p0[A_hstep] * scale1); \ -// pp[3] = float2int8(p0[A_hstep + 1] * scale1); \ -// pp += 4; \ -// p0 += 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale0); @@ -2683,14 +2674,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 8; p0 += 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(p0[0] * scale); \ -// pp[1] = float2int8(p0[1] * scale); \ -// pp += 2; \ -// p0 += 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -3841,14 +3825,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int pp += 8; p0 += A_hstep * 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(p0[0] * scale); \ -// pp[1] = float2int8(p0[A_hstep] * scale); \ -// pp += 2; \ -// p0 += A_hstep * 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -4602,16 +4579,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 4; p0 += 2; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(p0[0] * scale); \ -// pp[1] = float2int8(p0[1] * scale); \ -// pp[2] = float2int8(p0[B_hstep] * scale); \ -// pp[3] = float2int8(p0[B_hstep + 1] * scale); \ -// pp += 4; \ -// p0 += 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -4664,14 +4632,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 8; p0 += 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(p0[0] * scale); \ -// pp[1] = float2int8(p0[1] * scale); \ -// pp += 2; \ -// p0 += 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); @@ -5515,14 +5476,7 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int pp += 8; p0 += B_hstep * 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(p0[0] * scale); \ -// pp[1] = float2int8(p0[B_hstep] * scale); \ -// pp += 2; \ -// p0 += B_hstep * 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(p0[0] * scale); diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 8f5b0da86967..b6ce8b417644 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -986,16 +986,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 4; p0 += 2; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); \ -// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale0); \ -// pp[2] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale1); \ -// pp[3] = float2int8(bfloat16_to_float32(p0[A_hstep + 1]) * scale1); \ -// pp += 4; \ -// p0 += 2; \ -// } +#endif // 
__ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale0); @@ -1054,14 +1045,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp += 8; p0 += 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ -// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); \ -// pp += 2; \ -// p0 += 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -2226,14 +2210,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int pp += 8; p0 += A_hstep * 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ -// pp[1] = float2int8(bfloat16_to_float32(p0[A_hstep]) * scale); \ -// pp += 2; \ -// p0 += A_hstep * 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -3034,16 +3011,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 4; p0 += 2; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ -// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); \ -// pp[2] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); \ -// pp[3] = float2int8(bfloat16_to_float32(p0[B_hstep + 1]) * scale); \ -// pp += 4; \ -// p0 += 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -3099,14 +3067,7 @@ static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i pp += 8; p0 += 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ -// pp[1] = float2int8(bfloat16_to_float32(p0[1]) * scale); \ -// pp += 2; \ -// p0 += 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); @@ -4029,14 +3990,7 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int pp += 8; p0 += B_hstep * 8; } -#endif // __ARM_NEON \ -// for (; kk + 1 < max_kk; kk += 2) \ -// { \ -// pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); \ -// pp[1] = float2int8(bfloat16_to_float32(p0[B_hstep]) * scale); \ -// pp += 2; \ -// p0 += B_hstep * 2; \ -// } +#endif // __ARM_NEON for (; kk < max_kk; kk++) { pp[0] = float2int8(bfloat16_to_float32(p0[0]) * scale); From 86b03dfe4ae21673ff4fd1996e50a902edb2986d Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 20 Sep 2024 10:28:23 +0800 Subject: [PATCH 04/55] build for armv7 --- src/layer/arm/gemm_int8.h | 114 ++++++++++++++++++++++++++++++++ src/layer/arm/gemm_int8_bf16s.h | 114 ++++++++++++++++++++++++++++++++ 2 files changed, 228 insertions(+) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index a532f7cf5981..12d6d51b7efc 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -2111,6 +2111,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _pe = vld1q_f32(p0 + A_hstep * 7); float32x4_t _pf = vld1q_f32(p0 + A_hstep * 7 + 4); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 0); _p2 = vmulq_laneq_f32(_p2, _scale0, 1); @@ -2127,6 +2128,24 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i _pd = 
vmulq_laneq_f32(_pd, _scale1, 2); _pe = vmulq_laneq_f32(_pe, _scale1, 3); _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale0), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale0), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale0), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale0), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale0), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale1), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale1), 0); + _pa = vmulq_lane_f32(_pa, vget_low_f32(_scale1), 1); + _pb = vmulq_lane_f32(_pb, vget_low_f32(_scale1), 1); + _pc = vmulq_lane_f32(_pc, vget_high_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_high_f32(_scale1), 0); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 1); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -2190,6 +2209,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 6); float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 7); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 1); _p2 = vmulq_laneq_f32(_p2, _scale0, 2); @@ -2198,6 +2218,16 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i _p5 = vmulq_laneq_f32(_p5, _scale1, 1); _p6 = vmulq_laneq_f32(_p6, _scale1, 2); _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); @@ -2428,6 +2458,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 3); float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 3 + 4); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 0); _p2 = vmulq_laneq_f32(_p2, _scale, 1); @@ -2436,6 +2467,16 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i _p5 = vmulq_laneq_f32(_p5, _scale, 2); _p6 = vmulq_laneq_f32(_p6, _scale, 3); _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -2475,10 +2516,17 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _p2 = vld1q_f32(p0 + A_hstep * 2); float32x4_t _p3 = vld1q_f32(p0 + A_hstep * 3); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 1); _p2 = vmulq_laneq_f32(_p2, _scale, 2); _p3 = 
vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); @@ -3032,6 +3080,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _pe = vld1q_f32(p0 + A_hstep * 4 + 24); float32x4_t _pf = vld1q_f32(p0 + A_hstep * 4 + 28); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 1); _p2 = vmulq_laneq_f32(_p2, _scale0, 2); @@ -3048,6 +3097,24 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int _pd = vmulq_laneq_f32(_pd, _scale1, 1); _pe = vmulq_laneq_f32(_pe, _scale1, 2); _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale0), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale0), 1); + _pa = vmulq_lane_f32(_pa, vget_high_f32(_scale0), 0); + _pb = vmulq_lane_f32(_pb, vget_high_f32(_scale0), 1); + _pc = vmulq_lane_f32(_pc, vget_low_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_low_f32(_scale1), 1); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 0); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -3116,6 +3183,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _p6 = vld1q_f32(p0 + 24); float32x4_t _p7 = vld1q_f32(p0 + 28); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 1); _p2 = vmulq_laneq_f32(_p2, _scale0, 2); @@ -3124,6 +3192,16 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int _p5 = vmulq_laneq_f32(_p5, _scale1, 1); _p6 = vmulq_laneq_f32(_p6, _scale1, 2); _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); @@ -3334,6 +3412,7 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _p6 = vld1q_f32(p0 + A_hstep * 4 + 8); float32x4_t _p7 = vld1q_f32(p0 + A_hstep * 4 + 12); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 1); _p2 = vmulq_laneq_f32(_p2, _scale, 2); @@ -3342,6 +3421,16 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int _p5 = vmulq_laneq_f32(_p5, _scale, 1); _p6 = vmulq_laneq_f32(_p6, _scale, 2); _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + 
_p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -3381,10 +3470,17 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _p2 = vld1q_f32(p0 + 8); float32x4_t _p3 = vld1q_f32(p0 + 12); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 1); _p2 = vmulq_laneq_f32(_p2, _scale, 2); _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); @@ -8910,6 +9006,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } if (broadcast_type_C == 1 || broadcast_type_C == 2) { +#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); @@ -8918,6 +9015,16 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); + float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); + float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); + float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); + float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); @@ -9993,10 +10100,17 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } if (broadcast_type_C == 1 || broadcast_type_C == 2) { +#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); +#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index b6ce8b417644..9cecdc892989 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -452,6 +452,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _pe = bfloat2float(vget_low_u16(_w)); float32x4_t _pf = bfloat2float(vget_high_u16(_w)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 0); _p2 = vmulq_laneq_f32(_p2, _scale0, 1); @@ -468,6 +469,24 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i _pd = 
vmulq_laneq_f32(_pd, _scale1, 2); _pe = vmulq_laneq_f32(_pe, _scale1, 3); _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale0), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale0), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale0), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale0), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale0), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale1), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale1), 0); + _pa = vmulq_lane_f32(_pa, vget_low_f32(_scale1), 1); + _pb = vmulq_lane_f32(_pb, vget_low_f32(_scale1), 1); + _pc = vmulq_lane_f32(_pc, vget_high_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_high_f32(_scale1), 0); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 1); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -531,6 +550,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _p6 = bfloat2float(vld1_u16(p0 + A_hstep * 6)); float32x4_t _p7 = bfloat2float(vld1_u16(p0 + A_hstep * 7)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 1); _p2 = vmulq_laneq_f32(_p2, _scale0, 2); @@ -539,6 +559,16 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i _p5 = vmulq_laneq_f32(_p5, _scale1, 1); _p6 = vmulq_laneq_f32(_p6, _scale1, 2); _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); @@ -789,6 +819,7 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 0); _p2 = vmulq_laneq_f32(_p2, _scale, 1); @@ -797,6 +828,16 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i _p5 = vmulq_laneq_f32(_p5, _scale, 2); _p6 = vmulq_laneq_f32(_p6, _scale, 3); _p7 = vmulq_laneq_f32(_p7, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 0); + _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale), 1); + _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale), 0); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 1); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -836,10 +877,17 @@ static void pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _p2 = bfloat2float(vld1_u16(p0 + A_hstep * 2)); float32x4_t _p3 = bfloat2float(vld1_u16(p0 + A_hstep * 3)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 1); 
_p2 = vmulq_laneq_f32(_p2, _scale, 2); _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); @@ -1347,6 +1395,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _pe = bfloat2float(vget_low_u16(_w)); float32x4_t _pf = bfloat2float(vget_high_u16(_w)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 1); _p2 = vmulq_laneq_f32(_p2, _scale0, 2); @@ -1363,6 +1412,24 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int _pd = vmulq_laneq_f32(_pd, _scale1, 1); _pe = vmulq_laneq_f32(_pe, _scale1, 2); _pf = vmulq_laneq_f32(_pf, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); + _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale0), 0); + _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale0), 1); + _pa = vmulq_lane_f32(_pa, vget_high_f32(_scale0), 0); + _pb = vmulq_lane_f32(_pb, vget_high_f32(_scale0), 1); + _pc = vmulq_lane_f32(_pc, vget_low_f32(_scale1), 0); + _pd = vmulq_lane_f32(_pd, vget_low_f32(_scale1), 1); + _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 0); + _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -1435,6 +1502,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 1); _p2 = vmulq_laneq_f32(_p2, _scale0, 2); @@ -1443,6 +1511,16 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int _p5 = vmulq_laneq_f32(_p5, _scale1, 1); _p6 = vmulq_laneq_f32(_p6, _scale1, 2); _p7 = vmulq_laneq_f32(_p7, _scale1, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale0), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale0), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale1), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale1), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale1), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale1), 1); +#endif int8x8_t _r0 = float2int8(_p0, _p1); int8x8_t _r1 = float2int8(_p2, _p3); @@ -1673,6 +1751,7 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _p6 = bfloat2float(vget_low_u16(_s)); float32x4_t _p7 = bfloat2float(vget_high_u16(_s)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 1); _p2 = vmulq_laneq_f32(_p2, _scale, 2); @@ -1681,6 +1760,16 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int _p5 = vmulq_laneq_f32(_p5, _scale, 1); _p6 = vmulq_laneq_f32(_p6, _scale, 2); _p7 = vmulq_laneq_f32(_p7, _scale, 3); 
+#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); + _p4 = vmulq_lane_f32(_p4, vget_low_f32(_scale), 0); + _p5 = vmulq_lane_f32(_p5, vget_low_f32(_scale), 1); + _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 0); + _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -1722,10 +1811,17 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _p2 = bfloat2float(vget_low_u16(_q)); float32x4_t _p3 = bfloat2float(vget_high_u16(_q)); +#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 1); _p2 = vmulq_laneq_f32(_p2, _scale, 2); _p3 = vmulq_laneq_f32(_p3, _scale, 3); +#else + _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); + _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 1); + _p2 = vmulq_lane_f32(_p2, vget_high_f32(_scale), 0); + _p3 = vmulq_lane_f32(_p3, vget_high_f32(_scale), 1); +#endif #if __ARM_FEATURE_DOTPROD int8x8_t _r0 = float2int8(_p0, _p1); @@ -7508,6 +7604,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 1 || broadcast_type_C == 2) { +#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); @@ -7516,6 +7613,16 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); + float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); + float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); + float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); + float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); @@ -8607,10 +8714,17 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 1 || broadcast_type_C == 2) { +#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); +#else + float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); + float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); + float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); + float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); +#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); From 7f2d1daf133239baa7dead20ca4a825de9090d84 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 20 Sep 2024 14:36:15 +0800 Subject: [PATCH 05/55] quantize gemm --- tools/quantize/ncnn2int8.cpp | 109 +++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/tools/quantize/ncnn2int8.cpp b/tools/quantize/ncnn2int8.cpp index 5e92b333aa57..686accc6089c 100644 --- a/tools/quantize/ncnn2int8.cpp +++ b/tools/quantize/ncnn2int8.cpp @@ -134,6 +134,7 @@ class NetQuantize : public ModelWriter int quantize_gru(); int quantize_embed(); + int 
quantize_gemm();
 
     int fuse_requantize();
 };
@@ -613,6 +614,113 @@ int NetQuantize::quantize_embed()
     return 0;
 }
 
+int NetQuantize::quantize_gemm()
+{
+    for (size_t i = 0; i < layers.size(); i++)
+    {
+        if (layers[i]->type != "Gemm")
+            continue;
+
+        // Gemm - quantize weight from fp32 to int8
+        ncnn::Gemm* gemm = (ncnn::Gemm*)layers[i];
+
+        fprintf(stderr, "quantize_gemm %s\n", gemm->name.c_str());
+
+        // TODO move to ncnn2table
+
+        if (gemm->constantA)
+        {
+            if (gemm->transA == 1)
+            {
+                // transpose for easier quantization
+                ncnn::Mat A_data_transposed(gemm->constantK * gemm->constantM);
+                for (int i = 0; i < gemm->constantM; i++)
+                {
+                    float* ptr = (float*)A_data_transposed + i * gemm->constantK;
+                    for (int j = 0; j < gemm->constantK; j++)
+                    {
+                        ptr[j] = gemm->A_data[j * gemm->constantM + i];
+                    }
+                }
+                gemm->A_data = A_data_transposed;
+                gemm->transA = 0;
+            }
+
+            gemm->A_data_int8_scales.create(gemm->constantM);
+            for (int i = 0; i < gemm->constantM; i++)
+            {
+                float absmax = 0.f;
+
+                const float* ptr = (const float*)gemm->A_data + i * gemm->constantK;
+                for (int j = 0; j < gemm->constantK; j++)
+                {
+                    absmax = std::max(absmax, (float)fabs(ptr[j]));
+                }
+
+                gemm->A_data_int8_scales[i] = absmax == 0.f ? 1.f : 127 / absmax;
+            }
+
+            ncnn::Mat A_data = gemm->A_data.reshape(gemm->constantK, gemm->constantM);
+            ncnn::Mat A_data_int8;
+
+            ncnn::Option opt_q = opt;
+            opt_q.blob_allocator = A_data.allocator;
+            opt_q.use_packing_layout = false;
+            ncnn::quantize_to_int8(A_data, A_data_int8, gemm->A_data_int8_scales, opt_q);
+            if (A_data_int8.empty())
+                return -100;
+
+            gemm->A_data = A_data_int8.reshape(gemm->constantK * gemm->constantM);
+        }
+
+        if (gemm->constantB)
+        {
+            if (gemm->transB == 0)
+            {
+                // transpose for easier quantization
+                ncnn::Mat B_data_transposed(gemm->constantK * gemm->constantN);
+                for (int i = 0; i < gemm->constantN; i++)
+                {
+                    float* ptr = (float*)B_data_transposed + i * gemm->constantK;
+                    for (int j = 0; j < gemm->constantK; j++)
+                    {
+                        ptr[j] = gemm->B_data[j * gemm->constantN + i];
+                    }
+                }
+                gemm->B_data = B_data_transposed;
+                gemm->transB = 1;
+            }
+
+            const float* ptr = gemm->B_data;
+            float absmax = 0.f;
+            for (int j = 0; j < gemm->B_data.w; j++)
+            {
+                absmax = std::max(absmax, (float)fabs(ptr[j]));
+            }
+
+            gemm->B_data_int8_scale = absmax == 0.f ? 1.f : 127 / absmax;
+
+            ncnn::Mat B_data_int8_scales(1);
+            B_data_int8_scales[0] = gemm->B_data_int8_scale;
+
+            ncnn::Mat B_data_int8;
+
+            ncnn::Option opt_q = opt;
+            opt_q.blob_allocator = gemm->B_data.allocator;
+            opt_q.use_packing_layout = false;
+            ncnn::quantize_to_int8(gemm->B_data, B_data_int8, B_data_int8_scales, opt_q);
+            if (B_data_int8.empty())
+                return -100;
+
+            gemm->B_data = B_data_int8;
+        }
+
+        gemm->int8_scale_term = 2;
+    }
+
+    return 0;
+}
+
 int NetQuantize::fuse_requantize()
 {
     const size_t layer_count = layers.size();
@@ -861,6 +969,7 @@ int main(int argc, char** argv)
     quantizer.quantize_lstm();
     quantizer.quantize_gru();
     quantizer.quantize_embed();
+    quantizer.quantize_gemm();
 
     quantizer.fuse_requantize();

From 37208ab57e0025bae7e18ae46e7e906a640c143d Mon Sep 17 00:00:00 2001
From: nihuini
Date: Fri, 20 Sep 2024 14:37:55 +0800
Subject: [PATCH 06/55] write gemm quantize scales

---
 tools/modelwriter.h | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tools/modelwriter.h b/tools/modelwriter.h
index ff86338bca9c..218b211901fb 100644
--- a/tools/modelwriter.h
+++ b/tools/modelwriter.h
@@ -1773,6 +1773,7 @@ int ModelWriter::save(const char* parampath, const char* binpath)
             fprintf_param_value(" 12=%d", output_elempack)
             fprintf_param_value(" 13=%d", output_elemtype)
             fprintf_param_value(" 14=%d", output_transpose)
+            fprintf_param_value(" 18=%d", int8_scale_term)
             fprintf_param_value(" 20=%d", constant_TILE_M)
             fprintf_param_value(" 21=%d", constant_TILE_N)
             fprintf_param_value(" 22=%d", constant_TILE_K)
@@ -1789,6 +1790,23 @@ int ModelWriter::save(const char* parampath, const char* binpath)
             {
                 fwrite_weight_tag_data(op->C_data, bp);
             }
+
+#if NCNN_INT8
+            // write int8_scale data
+            if (op->int8_scale_term)
+            {
+                if (op->constantA == 1)
+                {
+                    fwrite_weight_data(op->A_data_int8_scales, bp, 90, 100);
+                }
+                if (op->constantB == 1)
+                {
+                    ncnn::Mat B_data_int8_scales(1);
+                    B_data_int8_scales[0] = op->B_data_int8_scale;
+                    fwrite_weight_data(B_data_int8_scales, bp, 90, 100);
+                }
+            }
+#endif // NCNN_INT8
         }
         else if (layer->type == "GLU")
         {

From 47cd674976a53b01a2b773a66ff52dcda99aa287 Mon Sep 17 00:00:00 2001
From: nihuini
Date: Fri, 20 Sep 2024 14:40:31 +0800
Subject: [PATCH 07/55] update doc

---
 docs/developer-guide/operators.md | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md
index de4d6b428e99..28f1ce626466 100644
--- a/docs/developer-guide/operators.md
+++ b/docs/developer-guide/operators.md
@@ -942,15 +942,18 @@ y = (gemm(a, b) + c * beta) * alpha
 | 12 | output_elempack | int | 0 | |
 | 13 | output_elemtype | int | 0 | |
 | 14 | output_transpose | int| 0 | |
+| 18 | int8_scale_term | int | 0 | |
 | 20 | constant_TILE_M | int | 0 | |
 | 21 | constant_TILE_N | int | 0 | |
 | 22 | constant_TILE_K | int | 0 | |
 
 | weight | type | shape |
 | ------------- | ----- | --------------------- |
-| A_data | float | [M, K] or [K, M] |
-| B_data | float | [N, K] or [K, N] |
+| A_data | float/fp16/int8 | [M, K] or [K, M] |
+| B_data | float/fp16/int8 | [N, K] or [K, N] |
 | C_data | float | [1], [M] or [N] or [1, M] or [N,1] or [N, M] |
+| A_data_int8_scales| float | [M] |
+| B_data_int8_scales| float | [1] |
 
 # GridSample
 ```

From ecf2e3b5e05229957f64520949f29d6810a26d7c Mon Sep 17 00:00:00 2001
From: nihuini
Date: Fri, 20 Sep 2024 19:30:20 +0800
Subject: [PATCH 08/55] wip

---
 src/layer/arm/gemm_arm.cpp | 31 +-
 src/layer/arm/gemm_int8.h | 1626 +++++++++++++++++++++++++++++-------
 2 files changed, 1315
insertions(+), 342 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index c1e0c3e0d697..b209f24b9c24 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -6160,21 +6160,6 @@ int Gemm_arm::forward_int8(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector Date: Mon, 23 Sep 2024 14:50:07 +0800 Subject: [PATCH 09/55] fp32 alpha beta --- src/layer/arm/gemm_arm.cpp | 57 +- src/layer/arm/gemm_arm_asimddp.cpp | 8 +- src/layer/arm/gemm_int8.h | 1145 ++++++++++++++++++++-------- tests/test_gemm_3.cpp | 82 +- 4 files changed, 846 insertions(+), 446 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index b209f24b9c24..62efb5f6cacd 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -47,7 +47,7 @@ Gemm_arm::Gemm_arm() #endif // __ARM_NEON #if NCNN_BF16 - support_bf16_storage = true; + // support_bf16_storage = true; #endif nT = 0; @@ -5303,7 +5303,7 @@ int Gemm_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector Date: Wed, 25 Sep 2024 17:55:50 +0800 Subject: [PATCH 10/55] stash --- src/layer/arm/gemm_arm.cpp | 16 +- src/layer/arm/gemm_arm_asimddp.cpp | 8 +- src/layer/arm/gemm_int8.h | 600 ++++---- src/layer/arm/gemm_int8_bf16s.h | 2056 +++++++++++++++++++++------- 4 files changed, 1790 insertions(+), 890 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 62efb5f6cacd..522ffd7d7049 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -5530,14 +5530,14 @@ static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob if (output_transpose) { if (top_blob.elembits() == 16) - transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } else { if (top_blob.elembits() == 16) - unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } @@ -5728,14 +5728,14 @@ static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& if (output_transpose) { if (top_blob.elembits() == 16) - transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } else { if (top_blob.elembits() == 16) - unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } @@ 
-5849,14 +5849,14 @@ static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, con if (output_transpose) { if (top_blob.elembits() == 16) - transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } else { if (top_blob.elembits() == 16) - unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } @@ -5923,14 +5923,14 @@ static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Ma if (output_transpose) { if (top_blob.elembits() == 16) - transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + transpose_unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else transpose_unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } else { if (top_blob.elembits() == 16) - unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales); + unpack_output_tile_int32_to_bf16(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); else unpack_output_tile_int32_to_fp32(topT_tile, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, output_descales, alpha, beta); } diff --git a/src/layer/arm/gemm_arm_asimddp.cpp b/src/layer/arm/gemm_arm_asimddp.cpp index 821e3d2812e8..518fbd46c76b 100644 --- a/src/layer/arm/gemm_arm_asimddp.cpp +++ b/src/layer/arm/gemm_arm_asimddp.cpp @@ -100,14 +100,14 @@ void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, in transpose_pack_B_tile_bf16_to_int8(B, BT, j, max_jj, k, max_kk, scale); } -void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) { - unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); } -void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) { - transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + transpose_unpack_output_tile_int32_to_bf16(topT, C, top_blob, 
broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); } #endif // NCNN_BF16 diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 42779200c34e..10045d2f7221 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -6007,7 +6007,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _cc0 = vmulq_f32(_cc0, _beta); _cc1 = vmulq_f32(_cc1, _beta); } -#if __aarch64__ _c0 = vdupq_laneq_f32(_cc0, 0); _c1 = vdupq_laneq_f32(_cc0, 1); float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); @@ -6016,16 +6015,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1); - float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0); - float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1); - float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0); - float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -6194,23 +6183,13 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& { float32x4_t _c2; float32x4_t _c3; - float32x4_t _c4; - float32x4_t _c5; - float32x4_t _c6; - float32x4_t _c7; if (c_elempack == 1) { _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + c_hstep); _c2 = vld1q_f32(pC + c_hstep * 2); _c3 = vld1q_f32(pC + c_hstep * 3); - _c4 = vld1q_f32(pC + c_hstep * 4); - _c5 = vld1q_f32(pC + c_hstep * 5); - _c6 = vld1q_f32(pC + c_hstep * 6); - _c7 = vld1q_f32(pC + c_hstep * 7); transpose4x4_ps(_c0, _c1, _c2, _c3); - transpose4x4_ps(_c4, _c5, _c6, _c7); - pC += 4; } else // if (c_elempack == 4) { @@ -6218,11 +6197,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _c1 = vld1q_f32(pC + 4); _c2 = vld1q_f32(pC + 8); _c3 = vld1q_f32(pC + 12); - _c4 = vld1q_f32(pC + c_hstep * 4); - _c5 = vld1q_f32(pC + c_hstep * 4 + 4); - _c6 = vld1q_f32(pC + c_hstep * 4 + 8); - _c7 = vld1q_f32(pC + c_hstep * 4 + 12); - pC += 16; } if (beta == 1.f) { @@ -6230,10 +6204,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); } else { @@ -6242,19 +6212,44 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _f1 = vmlaq_f32(_f1, _c1, _beta); _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + pC += 16; + } + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + 
_f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } } if (broadcast_type_C == 4) { float32x4_t _c = vld1q_f32(pC); - if (beta != 1.f) - { - _c = vmulq_n_f32(_c, beta); - } + _c = vmulq_n_f32(_c, beta); #if __aarch64__ _c0 = vdupq_laneq_f32(_c, 0); _c1 = vdupq_laneq_f32(_c, 1); @@ -6416,10 +6411,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 4) { float32x2_t _c = vld1_f32(pC); - if (beta != 1.f) - { - _c = vmul_n_f32(_c, beta); - } + _c = vmul_n_f32(_c, beta); _c0 = vdupq_lane_f32(_c, 0); _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); @@ -6481,7 +6473,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + c_hstep * 4); @@ -6727,7 +6719,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 1 || broadcast_type_C == 2) { -#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); @@ -6736,16 +6727,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc0); _f2 = vaddq_f32(_f2, _cc1); @@ -7087,28 +7068,31 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); #else float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); #endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); _f3 = vaddq_f32(_f3, _cc3); - _f4 = vaddq_f32(_f4, _cc4); - _f5 = vaddq_f32(_f5, _cc5); - _f6 = vaddq_f32(_f6, _cc6); - _f7 = vaddq_f32(_f7, _cc7); +#if __aarch64__ + _cc0 = vdupq_laneq_f32(_c1, 0); + _cc1 = vdupq_laneq_f32(_c1, 1); + _cc2 = vdupq_laneq_f32(_c1, 2); + _cc3 = vdupq_laneq_f32(_c1, 3); +#else + _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); + _cc1 = 
vdupq_lane_f32(vget_low_f32(_c1), 1); + _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); + _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f4 = vaddq_f32(_f4, _cc0); + _f5 = vaddq_f32(_f5, _cc1); + _f6 = vaddq_f32(_f6, _cc2); + _f7 = vaddq_f32(_f7, _cc3); } if (broadcast_type_C == 3) { @@ -7118,20 +7102,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _c1 = vld1q_f32(pC + c_hstep); float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); } else { @@ -7140,27 +7116,37 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _f1 = vmlaq_f32(_f1, _c1, _beta); _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } pC += 4; } else // if (c_elempack == 4) { float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + c_hstep * 4); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _cc0.val[0]); _f1 = vaddq_f32(_f1, _cc0.val[1]); _f2 = vaddq_f32(_f2, _cc0.val[2]); _f3 = vaddq_f32(_f3, _cc0.val[3]); - _f4 = vaddq_f32(_f4, _cc1.val[0]); - _f5 = vaddq_f32(_f5, _cc1.val[1]); - _f6 = vaddq_f32(_f6, _cc1.val[2]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); } else { @@ -7169,10 +7155,22 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta); _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta); _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta); - _f4 = vmlaq_f32(_f4, _cc1.val[0], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[1], _beta); - _f6 = vmlaq_f32(_f6, _cc1.val[2], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); + } + _cc0 = vld4q_f32(pC + c_hstep * 4); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _cc0.val[0]); + _f5 = vaddq_f32(_f5, _cc0.val[1]); + _f6 = vaddq_f32(_f6, _cc0.val[2]); + _f7 = vaddq_f32(_f7, _cc0.val[3]); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta); + _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta); + _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta); + _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta); } pC += 16; } @@ -7303,12 +7301,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& float32x2_t _cc1 = vld1_f32(pC + c_hstep); float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + _c0 = vcombine_f32(_cc0, _cc1); + _c1 = vcombine_f32(_cc2, _cc3); float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); float32x2_t _cc6 = 
vld1_f32(pC + c_hstep * 6); float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); - _c0 = vcombine_f32(_cc0, _cc1); - _c1 = vcombine_f32(_cc2, _cc3); float32x4_t _c2 = vcombine_f32(_cc4, _cc5); float32x4_t _c3 = vcombine_f32(_cc6, _cc7); if (beta == 1.f) @@ -7673,7 +7671,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _cc0 = vmulq_f32(_cc0, _beta); _cc1 = vmulq_f32(_cc1, _beta); } -#if __aarch64__ _c0 = vdupq_laneq_f32(_cc0, 0); float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); @@ -7682,16 +7679,6 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1); - float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0); - float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1); - float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0); - float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -7833,10 +7820,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 4) { float32x4_t _c = vld1q_f32(pC); - if (beta != 1.f) - { - _c = vmulq_n_f32(_c, beta); - } + _c = vmulq_n_f32(_c, beta); #if __aarch64__ _c0 = vdupq_laneq_f32(_c, 0); float32x4_t _c1 = vdupq_laneq_f32(_c, 1); @@ -7951,10 +7935,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 4) { float32x2_t _c = vld1_f32(pC); - if (beta != 1.f) - { - _c = vmul_n_f32(_c, beta); - } + _c = vmul_n_f32(_c, beta); _c0 = vdupq_lane_f32(_c, 0); float32x4_t _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); @@ -8000,19 +7981,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = vld1q_f32(pC); pC += 4; } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - } - else - { - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { @@ -8022,10 +7996,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); vst1q_f32(p0, _f0); @@ -8150,17 +8121,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 1 || broadcast_type_C == 2) { -#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc0); _f2 = vaddq_f32(_f2, _cc1); @@ -8528,7 +8492,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = 
vld1q_f32(pC); float32x4_t _c1 = vld1q_f32(pC + 4); @@ -8597,19 +8561,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = vld1q_f32(pC); pC += 4; } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - } - else - { - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { @@ -8619,10 +8576,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); @@ -8936,39 +8890,20 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { // c_elempack == 1 - if (beta == 1.f) - { - f0 += pC[0]; - f1 += pC[c_hstep]; - } - else - { - f0 += pC[0] * beta; - f1 += pC[c_hstep] * beta; - } + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; pC += 1; } if (broadcast_type_C == 4) { - if (beta == 1.f) - { - f0 += pC[0]; - f1 += pC[0]; - } - else - { - f0 += pC[0] * beta; - f1 += pC[0] * beta; - } + f0 += pC[0] * beta; + f1 += pC[0] * beta; pC += 1; } } - if (alpha != 1.f) - { - f0 *= alpha; - f1 *= alpha; - } + f0 *= alpha; + f1 *= alpha; p0[0] = f0; p0[out_hstep] = f1; @@ -9150,22 +9085,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& { // out_elempack == 1 _c0 = vld1q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - } - else - { - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 4; } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); vst1q_f32(p0, _f0); @@ -9186,22 +9111,12 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& { // out_elempack == 1 float32x2_t _c = vld1_f32(pC); - if (beta == 1.f) - { - _f0 = vadd_f32(_f0, _c); - } - else - { - _f0 = vmla_n_f32(_f0, _c, beta); - } + _f0 = vmla_n_f32(_f0, _c, beta); pC += 2; } } - if (alpha != 1.f) - { - _f0 = vmul_n_f32(_f0, alpha); - } + _f0 = vmul_n_f32(_f0, alpha); vst1_f32(p0, _f0); @@ -9597,8 +9512,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { float32x4x4_t _cc0 = vld4q_f32(pC); float32x4x4_t _cc1 = vld4q_f32(pC + 16); - float32x4x4_t _cc2 = vld4q_f32(pC + c_hstep * 4); - float32x4x4_t _cc3 = vld4q_f32(pC + c_hstep * 4 + 16); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _cc0.val[0]); @@ -9609,14 +9522,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f5 = vaddq_f32(_f5, _cc1.val[2]); _f6 = vaddq_f32(_f6, _cc0.val[3]); _f7 = vaddq_f32(_f7, _cc1.val[3]); - _f8 = vaddq_f32(_f8, _cc2.val[0]); - _f9 = vaddq_f32(_f9, _cc3.val[0]); - _fa = vaddq_f32(_fa, _cc2.val[1]); - _fb = vaddq_f32(_fb, _cc3.val[1]); - _fc = vaddq_f32(_fc, _cc2.val[2]); - _fd = vaddq_f32(_fd, _cc3.val[2]); - _fe = vaddq_f32(_fe, _cc2.val[3]); - _ff = vaddq_f32(_ff, _cc3.val[3]); } else { @@ -9629,14 +9534,31 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - _f8 = vmlaq_f32(_f8, _cc2.val[0], _beta); - _f9 = vmlaq_f32(_f9, _cc3.val[0], _beta); - _fa = vmlaq_f32(_fa, _cc2.val[1], _beta); - _fb = vmlaq_f32(_fb, _cc3.val[1], _beta); - _fc = vmlaq_f32(_fc, 
_cc2.val[2], _beta); - _fd = vmlaq_f32(_fd, _cc3.val[2], _beta); - _fe = vmlaq_f32(_fe, _cc2.val[3], _beta); - _ff = vmlaq_f32(_ff, _cc3.val[3], _beta); + } + _cc0 = vld4q_f32(pC + c_hstep * 4); + _cc1 = vld4q_f32(pC + c_hstep * 4 + 16); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _cc0.val[0]); + _f9 = vaddq_f32(_f9, _cc1.val[0]); + _fa = vaddq_f32(_fa, _cc0.val[1]); + _fb = vaddq_f32(_fb, _cc1.val[1]); + _fc = vaddq_f32(_fc, _cc0.val[2]); + _fd = vaddq_f32(_fd, _cc1.val[2]); + _fe = vaddq_f32(_fe, _cc0.val[3]); + _ff = vaddq_f32(_ff, _cc1.val[3]); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _cc0.val[0], _beta); + _f9 = vmlaq_f32(_f9, _cc1.val[0], _beta); + _fa = vmlaq_f32(_fa, _cc0.val[1], _beta); + _fb = vmlaq_f32(_fb, _cc1.val[1], _beta); + _fc = vmlaq_f32(_fc, _cc0.val[2], _beta); + _fd = vmlaq_f32(_fd, _cc1.val[2], _beta); + _fe = vmlaq_f32(_fe, _cc0.val[3], _beta); + _ff = vmlaq_f32(_ff, _cc1.val[3], _beta); } pC += 32; } @@ -9842,28 +9764,31 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); #else float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); #endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); _f3 = vaddq_f32(_f3, _cc3); - _f4 = vaddq_f32(_f4, _cc4); - _f5 = vaddq_f32(_f5, _cc5); - _f6 = vaddq_f32(_f6, _cc6); - _f7 = vaddq_f32(_f7, _cc7); +#if __aarch64__ + _cc0 = vdupq_laneq_f32(_c1, 0); + _cc1 = vdupq_laneq_f32(_c1, 1); + _cc2 = vdupq_laneq_f32(_c1, 2); + _cc3 = vdupq_laneq_f32(_c1, 3); +#else + _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); + _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); + _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); + _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f4 = vaddq_f32(_f4, _cc0); + _f5 = vaddq_f32(_f5, _cc1); + _f6 = vaddq_f32(_f6, _cc2); + _f7 = vaddq_f32(_f7, _cc3); } if (broadcast_type_C == 3) { @@ -9873,20 +9798,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _c1 = vld1q_f32(pC + c_hstep); float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); } else { @@ -9895,27 +9812,37 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f1 = vmlaq_f32(_f1, _c1, _beta); _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = 
vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } pC += 4; } else // if (c_elempack == 4) { float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + c_hstep * 4); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _cc0.val[0]); _f1 = vaddq_f32(_f1, _cc0.val[1]); _f2 = vaddq_f32(_f2, _cc0.val[2]); _f3 = vaddq_f32(_f3, _cc0.val[3]); - _f4 = vaddq_f32(_f4, _cc1.val[0]); - _f5 = vaddq_f32(_f5, _cc1.val[1]); - _f6 = vaddq_f32(_f6, _cc1.val[2]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); } else { @@ -9924,10 +9851,22 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta); _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta); _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta); - _f4 = vmlaq_f32(_f4, _cc1.val[0], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[1], _beta); - _f6 = vmlaq_f32(_f6, _cc1.val[2], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); + } + _cc0 = vld4q_f32(pC + c_hstep * 4); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _cc0.val[0]); + _f5 = vaddq_f32(_f5, _cc0.val[1]); + _f6 = vaddq_f32(_f6, _cc0.val[2]); + _f7 = vaddq_f32(_f7, _cc0.val[3]); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta); + _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta); + _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta); + _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta); } pC += 16; } @@ -9935,10 +9874,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (broadcast_type_C == 4) { _c0 = vld1q_f32(pC); - if (beta != 1.f) - { - _c0 = vmulq_n_f32(_c0, beta); - } + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -10201,6 +10137,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { if (c_elempack == 1) { + // TODO decompose 8x8 to 8x4 and 8x4 _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + 4); float32x4_t _c2 = vld1q_f32(pC + c_hstep); @@ -10336,7 +10273,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _cc0 = vmulq_f32(_cc0, _beta); _cc1 = vmulq_f32(_cc1, _beta); } -#if __aarch64__ _c0 = vdupq_laneq_f32(_cc0, 0); _c1 = vdupq_laneq_f32(_cc0, 1); float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); @@ -10345,16 +10281,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1); - float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0); - float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1); - float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0); - float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = 
vaddq_f32(_f2, _c2); @@ -10525,6 +10451,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { if (c_elempack == 1) { + // TODO decompose 4x8 to 4x4 and 4x4 _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + c_hstep); float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); @@ -10565,20 +10492,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _c1 = vld1q_f32(pC + 4); float32x4_t _c2 = vld1q_f32(pC + 8); float32x4_t _c3 = vld1q_f32(pC + 12); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 4 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 4 + 8); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 4 + 12); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); } else { @@ -10587,10 +10506,25 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f1 = vmlaq_f32(_f1, _c1, _beta); _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } pC += 16; } @@ -10598,10 +10532,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (broadcast_type_C == 4) { float32x4_t _cc = vld1q_f32(pC); - if (beta != 1.f) - { - _cc = vmulq_n_f32(_cc, beta); - } + _cc = vmulq_n_f32(_cc, beta); #if __aarch64__ _c0 = vdupq_laneq_f32(_cc, 0); _c1 = vdupq_laneq_f32(_cc, 1); @@ -10715,7 +10646,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (c_elempack == 1) { float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep * 1); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); @@ -10772,10 +10703,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (broadcast_type_C == 4) { float32x2_t _cc = vld1_f32(pC); - if (beta != 1.f) - { - _cc = vmul_n_f32(_cc, beta); - } + _cc = vmul_n_f32(_cc, beta); _c0 = vdupq_lane_f32(_cc, 0); _c1 = vdupq_lane_f32(_cc, 1); _f0 = vaddq_f32(_f0, _c0); @@ -11291,10 +11219,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (broadcast_type_C == 4) { _c0 = vld1q_f32(pC); - if (beta != 1.f) - { - _c0 = vmulq_n_f32(_c0, beta); - } + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -11489,7 +11414,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { float32x4_t _cc0 = vld1q_f32(pC); float32x4_t _cc1 = vld1q_f32(pC + 4); -#if __aarch64__ + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, 
_beta); + } _c0 = vdupq_laneq_f32(_cc0, 0); float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); @@ -11498,16 +11428,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc0), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc0), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc0), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc0), 1); - float32x4_t _c4 = vdupq_lane_f32(vget_low_f32(_cc1), 0); - float32x4_t _c5 = vdupq_lane_f32(vget_low_f32(_cc1), 1); - float32x4_t _c6 = vdupq_lane_f32(vget_high_f32(_cc1), 0); - float32x4_t _c7 = vdupq_lane_f32(vget_high_f32(_cc1), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -11649,10 +11569,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (broadcast_type_C == 4) { float32x4_t _cc = vld1q_f32(pC); - if (beta != 1.f) - { - _cc = vmulq_n_f32(_cc, beta); - } + _cc = vmulq_n_f32(_cc, beta); #if __aarch64__ _c0 = vdupq_laneq_f32(_cc, 0); float32x4_t _c1 = vdupq_laneq_f32(_cc, 1); @@ -11766,8 +11683,10 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(pC[0]); - float32x4_t _c1 = vdupq_n_f32(pC[1]); + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + float32x4_t _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); pC += 2; @@ -11816,14 +11735,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _c0 = vld1q_f32(pC); pC += 4; } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - } - else - { - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { @@ -11833,10 +11745,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); vst1q_f32(p0, _f0); pp += 4; @@ -11954,8 +11863,9 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _c1 = vld1q_f32(pC + 4); if (beta != 1.f) { - _c0 = vmulq_n_f32(_c0, beta); - _c1 = vmulq_n_f32(_c1, beta); + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); @@ -12027,10 +11937,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (broadcast_type_C == 4) { _c0 = vld1q_f32(pC); - if (beta != 1.f) - { - _c0 = vmulq_n_f32(_c0, beta); - } + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 4; @@ -12128,8 +12035,9 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _c1 = vld1q_f32(pC + 4); if (beta != 1.f) { - _c0 = vmulq_n_f32(_c0, beta); - _c1 = vmulq_n_f32(_c1, beta); + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); @@ -12213,10 +12121,7 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma if (broadcast_type_C == 4) { _c0 = vld1q_f32(pC); - if (beta != 1.f) - { - _c0 = vmulq_n_f32(_c0, beta); - } + _c0 = vmulq_n_f32(_c0, 
beta); float32x4x2_t _cc = vzipq_f32(_c0, _c0); _f0 = vaddq_f32(_f0, _cc.val[0]); _f1 = vaddq_f32(_f1, _cc.val[1]); @@ -12333,11 +12238,8 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } - if (alpha != 1.f) - { - f0 *= alpha; - f1 *= alpha; - } + f0 *= alpha; + f1 *= alpha; p0[0] = f0; p0[1] = f1; @@ -12516,22 +12418,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { // c_elempack == 1 _c0 = vld1q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - } - else - { - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 4; } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); vst1q_f32(p0, _f0); pp += 4; @@ -12687,22 +12579,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { // out_elempack == 1 _c0 = vld1q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - } - else - { - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 4; } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); @@ -12726,22 +12608,12 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { // c_elempack == 1 float32x2_t _c = vld1_f32(pC); - if (beta == 1.f) - { - _f0 = vadd_f32(_f0, _c); - } - else - { - _f0 = vmla_n_f32(_f0, _c, beta); - } + _f0 = vmla_n_f32(_f0, _c, beta); pC += 2; } } - if (alpha != 1.f) - { - _f0 = vmul_n_f32(_f0, alpha); - } + _f0 = vmul_n_f32(_f0, alpha); p0[0] = vget_lane_f32(_f0, 0); p0[out_hstep] = vget_lane_f32(_f0, 1); diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 9cecdc892989..e1a6c0c14994 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -24,8 +24,8 @@ void pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, void transpose_pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales); void pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); void transpose_pack_B_tile_bf16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale); -void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales); -void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales); +void unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta); +void transpose_unpack_output_tile_int32_to_bf16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta); #endif static void compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii) @@ -4097,12 +4097,12 @@ static void transpose_pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int } } -static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& 
descales) +static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) { #if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_asimddp()) { - unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); return; } #endif @@ -4381,12 +4381,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7); transpose8x4_u16(_c01, _c23, _c45, _c67); - transpose8x4_u16(_c89, _cab, _ccd, _cef); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -4395,42 +4390,73 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = 
bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); uint16x8_t _c45 = vld1q_u16(pC + 16); uint16x8_t _c67 = vld1q_u16(pC + 24); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 4 + 8); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 4 + 16); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 4 + 24); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -4439,43 +4465,86 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = 
vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 32; } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); - float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); - float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); - float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); - float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + uint16x8_t _cc = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_cc)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_cc)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -4496,6 +4565,27 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -4619,53 +4709,114 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x4_t _cc1 = vld1_u16(pC + c_hstep); uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - uint16x4_t _cc4 = vld1_u16(pC + c_hstep * 4); - uint16x4_t _cc5 = vld1_u16(pC + c_hstep * 5); - uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6); - uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7); transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - transpose4x4_u16(_cc4, _cc5, _cc6, _cc7); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); - _f1 = vaddq_f32(_f1, bfloat2float(_cc1)); - _f2 = vaddq_f32(_f2, bfloat2float(_cc2)); - _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); - _f4 = vaddq_f32(_f4, bfloat2float(_cc4)); - _f5 = vaddq_f32(_f5, bfloat2float(_cc5)); - _f6 = vaddq_f32(_f6, bfloat2float(_cc6)); - _f7 = vaddq_f32(_f7, bfloat2float(_cc7)); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + 
float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 4 + 8); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 16; } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + 
float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -4678,6 +4829,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -4748,6 +4912,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -4770,33 +4936,40 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 8; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 8; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -4805,6 +4978,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -4846,28 +5028,42 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = bfloat2float(vld1_u16(pC)); _c1 = 
bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 1; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); @@ -5080,7 +5276,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 1 || broadcast_type_C == 2) { -#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); @@ -5089,16 +5284,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc0); _f2 = vaddq_f32(_f2, _cc1); @@ -5124,10 +5309,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -5136,36 +5317,69 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if 
(beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8x4_t _cc0 = vld4q_u16(pC); - uint16x8x4_t _cc1 = vld4q_u16(pC + c_hstep * 4); _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0.val[0]))); _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0.val[0]))); _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc0.val[1]))); @@ -5174,14 +5388,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc0.val[2]))); _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc0.val[3]))); _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc0.val[3]))); - _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc1.val[0]))); - _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc1.val[0]))); - _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc1.val[1]))); - _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc1.val[1]))); - _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc1.val[2]))); - _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc1.val[2]))); - _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc1.val[3]))); - _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc1.val[3]))); + _cc0 = vld4q_u16(pC + c_hstep * 4); + _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc0.val[0]))); + _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc0.val[0]))); + _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc0.val[1]))); + _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc0.val[1]))); + _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc0.val[2]))); + _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc0.val[2]))); + _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc0.val[3]))); + _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc0.val[3]))); pC += 32; } } @@ -5190,6 +5405,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c = 
vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -5210,6 +5431,27 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -5383,38 +5625,61 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c1 = bfloat2float(vld1_u16(pC + c_hstep * 1)); float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 5)); - float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 6)); - float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x4x4_t _cc0 = vld4_u16(pC); - uint16x4x4_t _cc1 = vld4_u16(pC + c_hstep * 4); _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2])); _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3])); - _f4 = vaddq_f32(_f4, bfloat2float(_cc1.val[0])); - _f5 = vaddq_f32(_f5, bfloat2float(_cc1.val[1])); - _f6 = vaddq_f32(_f6, bfloat2float(_cc1.val[2])); - _f7 = vaddq_f32(_f7, bfloat2float(_cc1.val[3])); + _cc0 = vld4_u16(pC + c_hstep * 4); + _f4 = vaddq_f32(_f4, bfloat2float(_cc0.val[0])); + _f5 = vaddq_f32(_f5, bfloat2float(_cc0.val[1])); 
+ _f6 = vaddq_f32(_f6, bfloat2float(_cc0.val[2])); + _f7 = vaddq_f32(_f7, bfloat2float(_cc0.val[3])); pC += 16; } } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -5427,6 +5692,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); @@ -5518,6 +5796,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { uint16x8_t _cc0 = uint16x8_t(); @@ -5538,13 +5818,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _cc1 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _cc1, 5); _cc1 = vsetq_lane_u16(pC[c_hstep * 7], _cc1, 6); _cc1 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _cc1, 7); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc1))); - _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc1))); + _c0 = bfloat2float(vget_low_u16(_cc0)); + _c1 = bfloat2float(vget_high_u16(_cc0)); + _c2 = bfloat2float(vget_low_u16(_cc1)); + _c3 = bfloat2float(vget_high_u16(_cc1)); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { // TODO optimize uint16x8_t _cc0 = vld1q_u16(pC); @@ -5555,12 +5835,27 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c3 = bfloat2float(vget_high_u16(_cc1)); float32x4x2_t _c01 = vzipq_f32(_c0, _c1); float32x4x2_t _c23 = vzipq_f32(_c2, _c3); - _f0 = vaddq_f32(_f0, vcombine_f32(vget_low_f32(_c01.val[0]), vget_low_f32(_c01.val[1]))); - _f1 = vaddq_f32(_f1, vcombine_f32(vget_high_f32(_c01.val[0]), vget_high_f32(_c01.val[1]))); - _f2 = vaddq_f32(_f2, vcombine_f32(vget_low_f32(_c23.val[0]), vget_low_f32(_c23.val[1]))); - _f3 = vaddq_f32(_f3, vcombine_f32(vget_high_f32(_c23.val[0]), vget_high_f32(_c23.val[1]))); + _c0 = vcombine_f32(vget_low_f32(_c01.val[0]), vget_low_f32(_c01.val[1])); + _c1 = vcombine_f32(vget_high_f32(_c01.val[0]), vget_high_f32(_c01.val[1])); + _c2 = vcombine_f32(vget_low_f32(_c23.val[0]), vget_low_f32(_c23.val[1])); + _c3 = vcombine_f32(vget_high_f32(_c23.val[0]), vget_high_f32(_c23.val[1])); pC += 8; } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } } if (broadcast_type_C == 4) { @@ -5570,6 +5865,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vset_lane_u16(pC[0], _c, 2); _c = vset_lane_u16(pC[1], _c, 3); _c0 = bfloat2float(_c); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -5578,6 +5874,15 @@ static void unpack_output_tile_int32_to_bf16(const 
Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + uint16x4_t _fb0 = float2bfloat(_f0); uint16x4_t _fb1 = float2bfloat(_f1); uint16x4_t _fb2 = float2bfloat(_f2); @@ -5636,28 +5941,44 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep * 5], _c, 5); _c = vsetq_lane_u16(pC[c_hstep * 6], _c, 6); _c = vsetq_lane_u16(pC[c_hstep * 7], _c, 7); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c))); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = bfloat2float(vld1_u16(pC)); _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 1; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + uint16x4_t _fb0 = float2bfloat(_f0); uint16x4_t _fb1 = float2bfloat(_f1); @@ -5811,45 +6132,37 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; if (c_elempack == 1) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = 
vld1q_u16(pC + 16); - uint16x8_t _c67 = vld1q_u16(pC + 24); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -5858,19 +6171,39 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 32; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); - float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); - float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); - float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); - float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -5883,6 +6216,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -5956,6 +6302,9 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { uint16x4_t _cc0 = vld1_u16(pC); @@ -5963,33 +6312,53 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); - _f1 = vaddq_f32(_f1, bfloat2float(_cc1)); - _f2 = 
vaddq_f32(_f2, bfloat2float(_cc2)); - _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 16; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -5998,6 +6367,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -6047,9 +6425,10 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x8_t _c; if (c_elempack == 1) { - uint16x8_t _c = uint16x8_t(); + _c = uint16x8_t(); _c = vsetq_lane_u16(pC[0], _c, 0); _c = vsetq_lane_u16(pC[c_hstep], _c, 1); _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); @@ -6058,32 +6437,44 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) { - uint16x8_t _c01 = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 8; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, 
_beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); pC += 2; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); pp += 8; @@ -6105,31 +6496,34 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x4_t _c; if (c_elempack == 1) { - uint16x4_t _c = uint16x4_t(); + _c = uint16x4_t(); _c = vset_lane_u16(pC[0], _c, 0); _c = vset_lane_u16(pC[c_hstep], _c, 1); _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); - _f0 = vaddq_f32(_f0, bfloat2float(_c)); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _c = vld1_u16(pC); pC += 4; } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); pC += 1; } } + _f0 = vmulq_n_f32(_f0, alpha); + vst1_u16(p0, float2bfloat(_f0)); pp += 4; @@ -6253,17 +6647,10 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 1 || broadcast_type_C == 2) { -#if __aarch64__ float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc0); _f2 = vaddq_f32(_f2, _cc1); @@ -6275,6 +6662,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; if (c_elempack == 1) { uint16x8_t _c01 = vld1q_u16(pC); @@ -6282,13 +6676,30 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + pC += 8; + } + else // if (c_elempack == 4) + { + uint16x8x4_t _cc = vld4q_u16(pC); + _c0 = 
bfloat2float(vget_low_u16(_cc.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc.val[0])); + _c2 = bfloat2float(vget_low_u16(_cc.val[1])); + _c3 = bfloat2float(vget_high_u16(_cc.val[1])); + _c4 = bfloat2float(vget_low_u16(_cc.val[2])); + _c5 = bfloat2float(vget_high_u16(_cc.val[2])); + _c6 = bfloat2float(vget_low_u16(_cc.val[3])); + _c7 = bfloat2float(vget_high_u16(_cc.val[3])); + pC += 32; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -6297,20 +6708,18 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 8; } - if (c_elempack == 4) + else { - uint16x8x4_t _cc = vld4q_u16(pC); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc.val[0]))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc.val[0]))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc.val[1]))); - _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc.val[1]))); - _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc.val[2]))); - _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc.val[2]))); - _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc.val[3]))); - _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc.val[3]))); - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } } if (broadcast_type_C == 4) @@ -6318,6 +6727,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -6330,6 +6745,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -6433,31 +6861,46 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { _c0 = bfloat2float(vld1_u16(pC)); - float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep * 1)); - float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); + pC += 4; + } + else // if (c_elempack == 4) + { + uint16x4x4_t _c = vld4_u16(pC); + _c0 = bfloat2float(_c.val[0]); + _c1 = bfloat2float(_c.val[1]); + _c2 = bfloat2float(_c.val[2]); + _c3 = bfloat2float(_c.val[3]); 
+ pC += 16; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 4; } - if (c_elempack == 4) + else { - uint16x4x4_t _c = vld4_u16(pC); - _f0 = vaddq_f32(_f0, bfloat2float(_c.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_c.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_c.val[2])); - _f3 = vaddq_f32(_f3, bfloat2float(_c.val[3])); - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -6466,6 +6909,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); @@ -6529,6 +6981,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + float32x4_t _c1; if (c_elempack == 1) { uint16x8_t _c = uint16x8_t(); @@ -6540,20 +6993,29 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 5); _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 6); _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c))); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { - uint16x8_t _cc = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_cc)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_cc)); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); + uint16x8_t _c = vld1q_u16(pC); + uint16x4x2_t _c01 = vzip_u16(vget_low_u16(_c), vget_high_u16(_c)); + _c0 = bfloat2float(_c01.val[0]); + _c1 = bfloat2float(_c01.val[1]); pC += 8; } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } } if (broadcast_type_C == 4) { @@ -6563,12 +7025,20 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vset_lane_u16(pC[0], _c, 2); _c = vset_lane_u16(pC[1], _c, 3); _c0 = bfloat2float(_c); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 2; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + uint16x4_t _fb0 = float2bfloat(_f0); uint16x4_t _fb1 = float2bfloat(_f1); @@ -6581,11 +7051,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& p0[out_hstep * 3] = vget_lane_u16(_fb1, 2); p0[out_hstep * 3 + 1] = vget_lane_u16(_fb1, 3); - // vst1_f32(p0, vget_low_f32(_f0)); - // vst1_f32(p1, vget_high_f32(_f0)); - // vst1_f32(p2, vget_low_f32(_f1)); - // vst1_f32(p3, vget_high_f32(_f1)); - pp += 8; p0 += 2; } @@ -6605,31 +7070,34 @@ static void 
unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + uint16x4_t _c; if (c_elempack == 1) { - uint16x4_t _c = uint16x4_t(); + _c = uint16x4_t(); _c = vset_lane_u16(pC[0], _c, 0); _c = vset_lane_u16(pC[c_hstep], _c, 1); _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); - _f0 = vaddq_f32(_f0, bfloat2float(_c)); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _c = vld1_u16(pC); pC += 4; } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); pC += 1; } } + _f0 = vmulq_n_f32(_f0, alpha); + uint16x4_t _fb0 = float2bfloat(_f0); p0[0] = vget_lane_u16(_fb0, 0); @@ -6732,10 +7200,21 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 8; } if (broadcast_type_C == 4) @@ -6743,6 +7222,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -6751,6 +7236,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -6783,19 +7277,36 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& // c_elempack == 1 _c0 = bfloat2float(vld1_u16(pC)); float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 4; } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 4; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); @@ -6829,22 +7340,30 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C 
== 3) { // c_elempack == 1 - f00 += bfloat16_to_float32(pC[0]); - f01 += bfloat16_to_float32(pC[1]); - f10 += bfloat16_to_float32(pC[c_hstep]); - f11 += bfloat16_to_float32(pC[c_hstep + 1]); + f00 += bfloat16_to_float32(pC[0]) * beta; + f01 += bfloat16_to_float32(pC[1]) * beta; + f10 += bfloat16_to_float32(pC[c_hstep]) * beta; + f11 += bfloat16_to_float32(pC[c_hstep + 1]) * beta; pC += 2; } if (broadcast_type_C == 4) { - f00 += bfloat16_to_float32(pC[0]); - f01 += bfloat16_to_float32(pC[1]); - f10 += bfloat16_to_float32(pC[0]); - f11 += bfloat16_to_float32(pC[1]); + f00 += bfloat16_to_float32(pC[0]) * beta; + f01 += bfloat16_to_float32(pC[1]) * beta; + f10 += bfloat16_to_float32(pC[0]) * beta; + f11 += bfloat16_to_float32(pC[1]) * beta; pC += 2; } } + if (alpha != 1.f) + { + f00 *= alpha; + f01 *= alpha; + f10 *= alpha; + f11 *= alpha; + } + p0[0] = float32_to_bfloat16(f00); p0[1] = float32_to_bfloat16(f01); p0[out_hstep] = float32_to_bfloat16(f10); @@ -6874,18 +7393,24 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]); - f1 += bfloat16_to_float32(pC[c_hstep]); + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[c_hstep]) * beta; pC += 1; } if (broadcast_type_C == 4) { - f0 += bfloat16_to_float32(pC[0]); - f1 += bfloat16_to_float32(pC[0]); + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[0]) * beta; pC += 1; } } + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } + p0[0] = float32_to_bfloat16(f0); p0[out_hstep] = float32_to_bfloat16(f1); @@ -6970,14 +7495,34 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 16; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -7005,12 +7550,28 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c01 = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c01)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 8; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); pp += 8; @@ -7030,11 +7591,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { // 
c_elempack == 1 _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 4; } } + _f0 = vmulq_n_f32(_f0, alpha); + vst1_u16(p0, float2bfloat(_f0)); pp += 4; @@ -7056,11 +7619,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x2_t _cc = float32x2_t(); _cc = vset_lane_f32(bfloat16_to_float32(pC[0]), _cc, 0); _cc = vset_lane_f32(bfloat16_to_float32(pC[1]), _cc, 1); - _f0 = vadd_f32(_f0, _cc); + _f0 = vmla_n_f32(_f0, _cc, beta); pC += 2; } } + _f0 = vmul_n_f32(_f0, alpha); + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); p0[1] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); @@ -7081,11 +7646,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3 || broadcast_type_C == 4) { // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]); + f0 += bfloat16_to_float32(pC[0]) * beta; pC += 1; } } + f0 *= alpha; + p0[0] = float32_to_bfloat16(f0); pp += 1; @@ -7095,12 +7662,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } -static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales) +static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) { #if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 if (ncnn::cpu_support_arm_asimddp()) { - transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales); + transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); return; } #endif @@ -7393,52 +7960,128 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 2 + 4)); float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 3)); float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 3 + 4)); - float32x4_t _c8 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - float32x4_t _c9 = bfloat2float(vld1_u16(pC + c_hstep * 4 + 4)); - float32x4_t _ca = bfloat2float(vld1_u16(pC + c_hstep * 5)); - float32x4_t _cb = bfloat2float(vld1_u16(pC + c_hstep * 5 + 4)); - float32x4_t _cc = bfloat2float(vld1_u16(pC + c_hstep * 6)); - float32x4_t _cd = bfloat2float(vld1_u16(pC + c_hstep * 6 + 4)); - float32x4_t _ce = bfloat2float(vld1_u16(pC + c_hstep * 7)); - float32x4_t _cf = bfloat2float(vld1_u16(pC + c_hstep * 7 + 4)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = 
vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4 + 4)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 5 + 4)); + _c4 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + _c5 = bfloat2float(vld1_u16(pC + c_hstep * 6 + 4)); + _c6 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7 + 4)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8x4_t _cc0 = vld4q_u16(pC); - uint16x8x4_t _cc1 = vld4q_u16(pC + c_hstep * 4); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0.val[0]))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0.val[0]))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc0.val[1]))); - _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc0.val[1]))); - _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc0.val[2]))); - _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc0.val[2]))); - _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc0.val[3]))); - _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc0.val[3]))); - _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc1.val[0]))); - _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc1.val[0]))); - _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc1.val[1]))); - _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc1.val[1]))); - _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc1.val[2]))); - _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc1.val[2]))); - _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc1.val[3]))); - _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc1.val[3]))); + _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); + float32x4_t _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); + float32x4_t _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); + float32x4_t _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); + float32x4_t _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); + float32x4_t _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); + float32x4_t _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _cc0 = 
vld4q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); + _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); + _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); + _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); + _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); + _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); + _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 32; } } @@ -7447,6 +8090,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -7467,6 +8116,27 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6))); vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f8), float2bfloat(_fa))); @@ -7609,28 +8279,31 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); #else float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); #endif _f0 = vaddq_f32(_f0, _cc0); _f1 = vaddq_f32(_f1, _cc1); _f2 = vaddq_f32(_f2, _cc2); _f3 = vaddq_f32(_f3, _cc3); - _f4 = vaddq_f32(_f4, _cc4); - _f5 = vaddq_f32(_f5, _cc5); - _f6 = vaddq_f32(_f6, _cc6); - _f7 = vaddq_f32(_f7, _cc7); +#if __aarch64__ + _cc0 = vdupq_laneq_f32(_c1, 0); + 
_cc1 = vdupq_laneq_f32(_c1, 1); + _cc2 = vdupq_laneq_f32(_c1, 2); + _cc3 = vdupq_laneq_f32(_c1, 3); +#else + _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); + _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); + _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); + _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); +#endif + _f4 = vaddq_f32(_f4, _cc0); + _f5 = vaddq_f32(_f5, _cc1); + _f6 = vaddq_f32(_f6, _cc2); + _f7 = vaddq_f32(_f7, _cc3); } if (broadcast_type_C == 3) { @@ -7640,38 +8313,91 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c1 = bfloat2float(vld1_u16(pC + c_hstep)); float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 5)); - float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 6)); - float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 5)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 6)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 7)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x4x4_t _cc0 = vld4_u16(pC); - uint16x4x4_t _cc1 = vld4_u16(pC + c_hstep * 4); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2])); - _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3])); - _f4 = vaddq_f32(_f4, bfloat2float(_cc1.val[0])); - _f5 = vaddq_f32(_f5, bfloat2float(_cc1.val[1])); - _f6 = vaddq_f32(_f6, bfloat2float(_cc1.val[2])); - _f7 = vaddq_f32(_f7, bfloat2float(_cc1.val[3])); + _c0 = bfloat2float(_cc0.val[0]); + _c1 = bfloat2float(_cc0.val[1]); + float32x4_t _c2 = bfloat2float(_cc0.val[2]); + float32x4_t _c3 = bfloat2float(_cc0.val[3]); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld4_u16(pC + c_hstep * 4); + _c0 = bfloat2float(_cc0.val[0]); + _c1 = bfloat2float(_cc0.val[1]); + _c2 = bfloat2float(_cc0.val[2]); + _c3 = bfloat2float(_cc0.val[3]); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, 
_c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 16; } } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -7684,6 +8410,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); @@ -7917,6 +8656,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (c_elempack == 1) { + // TODO decompose 8x8 to 8x4 and 8x4 uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); @@ -7960,16 +8700,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _ff = vaddq_f32(_ff, _cf); pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); uint16x8_t _c45 = vld1q_u16(pC + 16); uint16x8_t _c67 = vld1q_u16(pC + 24); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 4 + 8); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 4 + 16); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 4 + 24); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -7978,43 +8714,86 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, 
_beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 32; } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); - float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); - float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); - float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); - float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -8035,6 +8814,27 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f8))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f9))); vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), float2bfloat(_fa))); @@ -8156,6 +8956,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (c_elempack == 1) { + // TODO decompose 4x8 to 4x4 
and 4x4 uint16x4_t _cc0 = vld1_u16(pC); uint16x4_t _cc1 = vld1_u16(pC + c_hstep); uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); @@ -8175,37 +8976,68 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f7 = vaddq_f32(_f7, bfloat2float(_cc7)); pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 4 + 8); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 16; } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -8218,6 +9050,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f4))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f5))); vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), 
float2bfloat(_f6))); @@ -8288,6 +9133,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -8312,33 +9159,40 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c2 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + _c2 = bfloat2float(vget_high_u16(_c01)); + _c3 = bfloat2float(vget_high_u16(_c23)); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 8; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 8; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -8347,6 +9201,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); @@ -8385,28 +9248,42 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = bfloat2float(vld1_u16(pC)); _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 1; } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); pp += 8; p0 += out_hstep; @@ -8598,7 +9475,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f7 = vaddq_f32(_f7, _c7); pC += 8; } - 
if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8x4_t _c = vld4q_u16(pC); _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c.val[0]))); @@ -8629,6 +9506,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6))); vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); @@ -8744,7 +9634,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f3 = vaddq_f32(_f3, _c3); pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x4x4_t _c = vld4_u16(pC); _f0 = vaddq_f32(_f0, bfloat2float(_c.val[0])); @@ -8765,6 +9655,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); @@ -8905,7 +9804,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f7 = vaddq_f32(_f7, _c7); pC += 8; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); @@ -8952,6 +9851,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); @@ -9042,7 +9954,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); pC += 4; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); @@ -9071,6 +9983,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); @@ -9139,7 +10060,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f1 = vaddq_f32(_f1, _c1); pC += 2; } - if (c_elempack == 4) + else // if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c01)); @@ -9159,6 +10080,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, 
_alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep, float2bfloat(_f1)); @@ -9191,7 +10119,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f0 = vaddq_f32(_f0, bfloat2float(_c)); pC += 1; } - if (c_elempack == 4) + else // if (c_elempack == 4) { _c0 = bfloat2float(vld1_u16(pC)); _f0 = vaddq_f32(_f0, _c0); @@ -9206,6 +10134,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + _f0 = vmulq_n_f32(_f0, alpha); + } + vst1_u16(p0, float2bfloat(_f0)); pp += 4; p0 += out_hstep; @@ -9320,6 +10253,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); @@ -9368,6 +10310,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + 4, float2bfloat(_f1)); @@ -9446,6 +10395,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + uint16x4_t _bf0 = float2bfloat(_f0); uint16x4_t _bf1 = float2bfloat(_f1); uint16x4_t _bf2 = float2bfloat(_f2); @@ -9520,6 +10478,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + uint16x4_t _bf0 = float2bfloat(_f0); uint16x4_t _bf1 = float2bfloat(_f1); @@ -9582,6 +10547,14 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + f00 *= alpha; + f01 *= alpha; + f10 *= alpha; + f11 *= alpha; + } + p0[0] = float32_to_bfloat16(f00); p0[1] = float32_to_bfloat16(f01); p0[out_hstep] = float32_to_bfloat16(f10); @@ -9624,6 +10597,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } + p0[0] = float32_to_bfloat16(f0); p0[1] = float32_to_bfloat16(f1); pp += 2; @@ -9714,6 +10693,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); vst1_u16(p0 + out_hstep * 8, float2bfloat(_f2)); @@ -9748,6 +10736,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + vst1_u16(p0, float2bfloat(_f0)); vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); pp += 8; @@ -9771,6 +10766,11 @@ static void 
transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + _f0 = vmulq_n_f32(_f0, alpha); + } + vst1_u16(p0, float2bfloat(_f0)); pp += 4; p0 += out_hstep * 4; @@ -9819,6 +10819,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + uint16x4_t _bf0 = float2bfloat(_f0); uint16x4_t _bf1 = float2bfloat(_f1); uint16x4_t _bf2 = float2bfloat(_f2); @@ -9871,6 +10880,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + uint16x4_t _bf0 = float2bfloat(_f0); uint16x4_t _bf1 = float2bfloat(_f1); @@ -9905,6 +10921,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + _f0 = vmulq_n_f32(_f0, alpha); + } + uint16x4_t _bf0 = float2bfloat(_f0); p0[0] = vget_lane_u16(_bf0, 0); @@ -9936,6 +10957,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + if (alpha != 1.f) + { + _f0 = vmul_n_f32(_f0, alpha); + } + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); p0[out_hstep] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); @@ -9961,6 +10987,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + f0 *= alpha; + p0[0] = float32_to_bfloat16(f0); pp += 1; From d59410fcb625646c448cbdc82882d087bda46351 Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 25 Sep 2024 19:32:48 +0800 Subject: [PATCH 11/55] stash --- src/layer/arm/gemm_int8_bf16s.h | 611 ++++++++++++++++++++++---------- 1 file changed, 425 insertions(+), 186 deletions(-) diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index e1a6c0c14994..b3d156759860 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -4133,7 +4133,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { if (broadcast_type_C == 0) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { @@ -4141,6 +4141,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); _c1 = bfloat2float(vget_high_u16(_c)); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); } if (broadcast_type_C == 3) { @@ -5380,23 +5382,69 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& else // if (c_elempack == 4) { uint16x8x4_t _cc0 = vld4q_u16(pC); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_cc0.val[0]))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_cc0.val[0]))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_cc0.val[1]))); - _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_cc0.val[1]))); - _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_cc0.val[2]))); - _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_cc0.val[2]))); - _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_cc0.val[3]))); - _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_cc0.val[3]))); + _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); + float32x4_t _c2 = 
bfloat2float(vget_low_u16(_cc0.val[1])); + float32x4_t _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); + float32x4_t _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); + float32x4_t _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); + float32x4_t _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); + float32x4_t _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } _cc0 = vld4q_u16(pC + c_hstep * 4); - _f8 = vaddq_f32(_f8, bfloat2float(vget_low_u16(_cc0.val[0]))); - _f9 = vaddq_f32(_f9, bfloat2float(vget_high_u16(_cc0.val[0]))); - _fa = vaddq_f32(_fa, bfloat2float(vget_low_u16(_cc0.val[1]))); - _fb = vaddq_f32(_fb, bfloat2float(vget_high_u16(_cc0.val[1]))); - _fc = vaddq_f32(_fc, bfloat2float(vget_low_u16(_cc0.val[2]))); - _fd = vaddq_f32(_fd, bfloat2float(vget_high_u16(_cc0.val[2]))); - _fe = vaddq_f32(_fe, bfloat2float(vget_low_u16(_cc0.val[3]))); - _ff = vaddq_f32(_ff, bfloat2float(vget_high_u16(_cc0.val[3]))); + _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); + _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); + _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); + _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); + _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); + _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); + _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } pC += 32; } } @@ -6007,12 +6055,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { if (broadcast_type_C == 0) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const unsigned short*)C + i + ii; _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); } if (broadcast_type_C == 3) { @@ -7132,7 +7181,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { if (broadcast_type_C == 0) { - c0 = bfloat16_to_float32(pC[0]); + c0 = bfloat16_to_float32(pC[0]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); #endif @@ -7140,8 +7189,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const unsigned short*)C + i + ii; - c0 = bfloat16_to_float32(pC[0]); - c1 = bfloat16_to_float32(pC[1]); + c0 = bfloat16_to_float32(pC[0]) * beta; + c1 = 
bfloat16_to_float32(pC[1]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); _c1 = vdupq_n_f32(c1); @@ -7437,7 +7486,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { if (broadcast_type_C == 0) { - c0 = bfloat16_to_float32(pC[0]); + c0 = bfloat16_to_float32(pC[0]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); #endif @@ -7445,7 +7494,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const unsigned short*)C + i + ii; - c0 = bfloat16_to_float32(pC[0]); + c0 = bfloat16_to_float32(pC[0]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); #endif @@ -7698,7 +7747,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (broadcast_type_C == 0) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { @@ -7706,6 +7755,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); _c1 = bfloat2float(vget_high_u16(_c)); + _c0 = vmulq_n_f32(_c0, beta); + _c1 = vmulq_n_f32(_c1, beta); } if (broadcast_type_C == 3) { @@ -9301,12 +9352,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (broadcast_type_C == 0) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const unsigned short*)C + i + ii; _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); } if (broadcast_type_C == 3) { @@ -9451,6 +9503,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; if (c_elempack == 1) { uint16x8_t _c01 = vld1q_u16(pC); @@ -9458,13 +9517,30 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + pC += 8; + } + else // if (c_elempack == 4) + { + uint16x8x4_t _c = vld4q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c.val[0])); + _c1 = bfloat2float(vget_high_u16(_c.val[0])); + _c2 = bfloat2float(vget_low_u16(_c.val[1])); + _c3 = bfloat2float(vget_high_u16(_c.val[1])); + _c4 = bfloat2float(vget_low_u16(_c.val[2])); + _c5 = bfloat2float(vget_high_u16(_c.val[2])); + _c6 = bfloat2float(vget_low_u16(_c.val[3])); + _c7 = bfloat2float(vget_high_u16(_c.val[3])); + pC += 32; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -9473,20 +9549,18 
@@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 8; } - else // if (c_elempack == 4) + else { - uint16x8x4_t _c = vld4q_u16(pC); - _f0 = vaddq_f32(_f0, bfloat2float(vget_low_u16(_c.val[0]))); - _f1 = vaddq_f32(_f1, bfloat2float(vget_high_u16(_c.val[0]))); - _f2 = vaddq_f32(_f2, bfloat2float(vget_low_u16(_c.val[1]))); - _f3 = vaddq_f32(_f3, bfloat2float(vget_high_u16(_c.val[1]))); - _f4 = vaddq_f32(_f4, bfloat2float(vget_low_u16(_c.val[2]))); - _f5 = vaddq_f32(_f5, bfloat2float(vget_high_u16(_c.val[2]))); - _f6 = vaddq_f32(_f6, bfloat2float(vget_low_u16(_c.val[3]))); - _f7 = vaddq_f32(_f7, bfloat2float(vget_high_u16(_c.val[3]))); - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } } if (broadcast_type_C == 4) @@ -9494,6 +9568,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -9622,31 +9702,46 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { _c0 = bfloat2float(vld1_u16(pC)); - float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); + _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); pC += 4; } else // if (c_elempack == 4) { uint16x4x4_t _c = vld4_u16(pC); - _f0 = vaddq_f32(_f0, bfloat2float(_c.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_c.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_c.val[2])); - _f3 = vaddq_f32(_f3, bfloat2float(_c.val[3])); + _c0 = bfloat2float(_c.val[0]); + _c1 = bfloat2float(_c.val[1]); + _c2 = bfloat2float(_c.val[2]); + _c3 = bfloat2float(_c.val[3]); pC += 16; } + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); _f2 = vaddq_f32(_f2, _c0); @@ -9779,45 +9874,37 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; if (c_elempack == 1) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep 
* 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); pC += 8; } else // if (c_elempack == 4) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + 16); - uint16x8_t _c67 = vld1q_u16(pC + 24); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -9826,19 +9913,39 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 32; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); - float32x4_t _c4 = vdupq_n_f32(bfloat16_to_float32(pC[4])); - float32x4_t _c5 = vdupq_n_f32(bfloat16_to_float32(pC[5])); - float32x4_t _c6 = vdupq_n_f32(bfloat16_to_float32(pC[6])); - float32x4_t _c7 = vdupq_n_f32(bfloat16_to_float32(pC[7])); + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = 
vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -9941,6 +10048,9 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; if (c_elempack == 1) { uint16x4_t _cc0 = vld1_u16(pC); @@ -9948,10 +10058,10 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); - _f1 = vaddq_f32(_f1, bfloat2float(_cc1)); - _f2 = vaddq_f32(_f2, bfloat2float(_cc2)); - _f3 = vaddq_f32(_f3, bfloat2float(_cc3)); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); pC += 4; } else // if (c_elempack == 4) @@ -9959,22 +10069,35 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; + } + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 16; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); - float32x4_t _c2 = vdupq_n_f32(bfloat16_to_float32(pC[2])); - float32x4_t _c3 = vdupq_n_f32(bfloat16_to_float32(pC[3])); + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -10043,9 +10166,10 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + uint16x8_t _c; if (c_elempack == 1) { - uint16x8_t _c = uint16x8_t(); + _c = uint16x8_t(); _c = vsetq_lane_u16(pC[0], _c, 0); _c = vsetq_lane_u16(pC[c_hstep], _c, 1); _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); @@ -10054,26 +10178,31 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); pC += 2; } else // if (c_elempack == 4) { - uint16x8_t _c01 = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = 
bfloat2float(vget_high_u16(_c01)); + _c = vld1q_u16(pC); + pC += 8; + } + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) + { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 8; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); pC += 2; @@ -10109,35 +10238,33 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + uint16x4_t _c; if (c_elempack == 1) { - uint16x4_t _c = uint16x4_t(); + _c = uint16x4_t(); _c = vset_lane_u16(pC[0], _c, 0); _c = vset_lane_u16(pC[c_hstep], _c, 1); _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); - _f0 = vaddq_f32(_f0, bfloat2float(_c)); pC += 1; } else // if (c_elempack == 4) { - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _c = vld1_u16(pC); pC += 4; } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); } if (broadcast_type_C == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0])); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); pC += 1; } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); vst1_u16(p0, float2bfloat(_f0)); pp += 4; @@ -10166,7 +10293,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (broadcast_type_C == 0) { - c0 = bfloat16_to_float32(pC[0]); + c0 = bfloat16_to_float32(pC[0]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); #endif @@ -10174,8 +10301,8 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const unsigned short*)C + i + ii; - c0 = bfloat16_to_float32(pC[0]); - c1 = bfloat16_to_float32(pC[1]); + c0 = bfloat16_to_float32(pC[0]) * beta; + c1 = bfloat16_to_float32(pC[1]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); _c1 = vdupq_n_f32(c1); @@ -10234,10 +10361,21 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 8; } if (broadcast_type_C == 4) @@ -10245,6 +10383,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); + } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c0); @@ -10297,13 +10441,23 @@ static void 
transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma // c_elempack == 1 _c0 = bfloat2float(vld1_u16(pC)); _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 4; } if (broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c0); pC += 4; @@ -10376,10 +10530,25 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x4x2_t _c02 = vzip_u16(vget_low_u16(_c01), vget_low_u16(_c23)); uint16x4x2_t _c13 = vzip_u16(vget_high_u16(_c01), vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, bfloat2float(_c02.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_c02.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_c13.val[0])); - _f3 = vaddq_f32(_f3, bfloat2float(_c13.val[1])); + _c0 = bfloat2float(_c02.val[0]); + float32x4_t _c1 = bfloat2float(_c02.val[1]); + float32x4_t _c2 = bfloat2float(_c13.val[0]); + float32x4_t _c3 = bfloat2float(_c13.val[1]); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 8; } if (broadcast_type_C == 4) @@ -10387,10 +10556,25 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c01 = vld1q_u16(pC); uint16x4x2_t _cc0 = vzip_u16(vget_low_u16(_c01), vget_low_u16(_c01)); uint16x4x2_t _cc1 = vzip_u16(vget_high_u16(_c01), vget_high_u16(_c01)); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_cc1.val[0])); - _f3 = vaddq_f32(_f3, bfloat2float(_cc1.val[1])); + _c0 = bfloat2float(_cc0.val[0]); + float32x4_t _c1 = bfloat2float(_cc0.val[1]); + float32x4_t _c2 = bfloat2float(_cc1.val[0]); + float32x4_t _c3 = bfloat2float(_cc1.val[1]); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 8; } } @@ -10464,16 +10648,38 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _cc0 = vld1_u16(pC); uint16x4_t _cc1 = vld1_u16(pC + c_hstep); uint16x4x2_t _c01 = vzip_u16(_cc0, _cc1); - _f0 = vaddq_f32(_f0, bfloat2float(_c01.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_c01.val[1])); + _c0 = bfloat2float(_c01.val[0]); + float32x4_t _c1 = bfloat2float(_c01.val[1]); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 4; } if (broadcast_type_C == 4) { uint16x4_t _c = vld1_u16(pC); uint16x4x2_t _cc = vzip_u16(_c, _c); - _f0 = vaddq_f32(_f0, bfloat2float(_cc.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_cc.val[1])); + _c0 = bfloat2float(_cc.val[0]); + float32x4_t 
_c1 = bfloat2float(_cc.val[1]); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 4; } } @@ -10529,16 +10735,16 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 3) { // c_elempack == 1 - f00 += bfloat16_to_float32(pC[0]); - f01 += bfloat16_to_float32(pC[c_hstep]); - f10 += bfloat16_to_float32(pC[1]); - f11 += bfloat16_to_float32(pC[c_hstep + 1]); + f00 += bfloat16_to_float32(pC[0]) * beta; + f01 += bfloat16_to_float32(pC[c_hstep]) * beta; + f10 += bfloat16_to_float32(pC[1]) * beta; + f11 += bfloat16_to_float32(pC[c_hstep + 1]) * beta; pC += 2; } if (broadcast_type_C == 4) { - float c0 = bfloat16_to_float32(pC[0]); - float c1 = bfloat16_to_float32(pC[1]); + float c0 = bfloat16_to_float32(pC[0]) * beta; + float c1 = bfloat16_to_float32(pC[1]) * beta; f00 += c0; f01 += c0; f10 += c1; @@ -10584,13 +10790,13 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 3) { // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]); - f1 += bfloat16_to_float32(pC[c_hstep]); + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[c_hstep]) * beta; pC += 1; } if (broadcast_type_C == 4) { - c0 = bfloat16_to_float32(pC[0]); + c0 = bfloat16_to_float32(pC[0]) * beta; f0 += c0; f1 += c0; pC += 1; @@ -10627,7 +10833,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (broadcast_type_C == 0) { - c0 = bfloat16_to_float32(pC[0]); + c0 = bfloat16_to_float32(pC[0]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); #endif @@ -10635,7 +10841,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 1 || broadcast_type_C == 2) { pC = (const unsigned short*)C + i + ii; - c0 = bfloat16_to_float32(pC[0]); + c0 = bfloat16_to_float32(pC[0]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); #endif @@ -10685,10 +10891,21 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 16; } } @@ -10730,8 +10947,17 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 8; } } @@ -10761,15 +10987,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 3 || broadcast_type_C == 4) { _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _f0 = 
vmlaq_n_f32(_f0, _c0, beta); pC += 4; } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); vst1_u16(p0, float2bfloat(_f0)); pp += 4; @@ -10811,10 +11034,21 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } pC += 16; } } @@ -10876,6 +11110,17 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + } pC += 8; } } @@ -10916,15 +11161,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { // out_elempack == 1 _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vaddq_f32(_f0, _c0); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 4; } } - if (alpha != 1.f) - { - _f0 = vmulq_n_f32(_f0, alpha); - } + _f0 = vmulq_n_f32(_f0, alpha); uint16x4_t _bf0 = float2bfloat(_f0); @@ -10952,15 +11194,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x2_t _c = float32x2_t(); _c = vset_lane_f32(bfloat16_to_float32(pC[0]), _c, 0); _c = vset_lane_f32(bfloat16_to_float32(pC[1]), _c, 1); - _f0 = vadd_f32(_f0, _c); + _f0 = vmla_n_f32(_f0, _c, beta); pC += 2; } } - if (alpha != 1.f) - { - _f0 = vmul_n_f32(_f0, alpha); - } + _f0 = vmul_n_f32(_f0, alpha); p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); p0[out_hstep] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); @@ -10982,7 +11221,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 3 || broadcast_type_C == 4) { // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]); + f0 += bfloat16_to_float32(pC[0]) * beta; pC += 1; } } From 07f675594739d2005b625ab7daab5aab9d55becd Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 26 Sep 2024 17:23:46 +0800 Subject: [PATCH 12/55] fix int8 bf16s --- src/layer/arm/gemm_arm.cpp | 11 ++- src/layer/arm/gemm_int8_bf16s.h | 127 ++++++++++++++++++++++---------- 2 files changed, 97 insertions(+), 41 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 522ffd7d7049..04f775f44793 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -47,7 +47,7 @@ Gemm_arm::Gemm_arm() #endif // __ARM_NEON #if NCNN_BF16 - // support_bf16_storage = true; + support_bf16_storage = true; #endif nT = 0; @@ -6037,6 +6037,15 @@ int Gemm_arm::create_pipeline_int8(const Option& opt) } #endif +#if __ARM_NEON + if (constant_broadcast_type_C == 3 && opt.use_packing_layout && CT_data.h % 4 == 0) + { + Mat C2; + ncnn::convert_packing(CT_data, C2, 4, opt); + CT_data = C2; + } +#endif + if (opt.lightmode) C_data.release(); } diff --git a/src/layer/arm/gemm_int8_bf16s.h 
b/src/layer/arm/gemm_int8_bf16s.h index b3d156759860..786b3310368c 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -4914,11 +4914,11 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { - float32x4_t _c2; - float32x4_t _c3; + uint16x8_t _c01; + uint16x8_t _c23; if (c_elempack == 1) { - uint16x8_t _c01 = uint16x8_t(); + _c01 = uint16x8_t(); _c01 = vsetq_lane_u16(pC[0], _c01, 0); _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); @@ -4927,7 +4927,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); - uint16x8_t _c23 = uint16x8_t(); + _c23 = uint16x8_t(); _c23 = vsetq_lane_u16(pC[1], _c23, 0); _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); @@ -4936,22 +4936,18 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); pC += 2; } else // if (c_elempack == 4) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep * 4); pC += 8; } + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -8733,22 +8729,45 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c2); - _f2 = vaddq_f32(_f2, _c4); - _f3 = vaddq_f32(_f3, _c6); - _f4 = vaddq_f32(_f4, _c8); - _f5 = vaddq_f32(_f5, _ca); - _f6 = vaddq_f32(_f6, _cc); - _f7 = vaddq_f32(_f7, _ce); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c3); - _fa = vaddq_f32(_fa, _c5); - _fb = vaddq_f32(_fb, _c7); - _fc = vaddq_f32(_fc, _c9); - _fd = vaddq_f32(_fd, _cb); - _fe = vaddq_f32(_fe, _cd); - _ff = vaddq_f32(_ff, _cf); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c2); + _f2 = vaddq_f32(_f2, _c4); + _f3 = vaddq_f32(_f3, _c6); + _f4 = vaddq_f32(_f4, _c8); + _f5 = vaddq_f32(_f5, _ca); + _f6 = vaddq_f32(_f6, _cc); + _f7 = vaddq_f32(_f7, _ce); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c3); + _fa = vaddq_f32(_fa, _c5); + _fb = vaddq_f32(_fb, _c7); + _fc = vaddq_f32(_fc, _c9); + _fd = vaddq_f32(_fd, _cb); + _fe = vaddq_f32(_fe, _cd); + _ff = vaddq_f32(_ff, _cf); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c2, _beta); + _f2 = vmlaq_f32(_f2, _c4, _beta); + _f3 = vmlaq_f32(_f3, _c6, _beta); + _f4 = vmlaq_f32(_f4, _c8, _beta); + _f5 = vmlaq_f32(_f5, _ca, _beta); + 
_f6 = vmlaq_f32(_f6, _cc, _beta); + _f7 = vmlaq_f32(_f7, _ce, _beta); + _f8 = vmlaq_f32(_f8, _c1, _beta); + _f9 = vmlaq_f32(_f9, _c3, _beta); + _fa = vmlaq_f32(_fa, _c5, _beta); + _fb = vmlaq_f32(_fb, _c7, _beta); + _fc = vmlaq_f32(_fc, _c9, _beta); + _fd = vmlaq_f32(_fd, _cb, _beta); + _fe = vmlaq_f32(_fe, _cd, _beta); + _ff = vmlaq_f32(_ff, _cf, _beta); + } pC += 8; } else // if (c_elempack == 4) @@ -9017,14 +9036,44 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6); uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7); transpose4x8_u16(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7); - _f0 = vaddq_f32(_f0, bfloat2float(_cc0)); - _f1 = vaddq_f32(_f1, bfloat2float(_cc2)); - _f2 = vaddq_f32(_f2, bfloat2float(_cc4)); - _f3 = vaddq_f32(_f3, bfloat2float(_cc6)); - _f4 = vaddq_f32(_f4, bfloat2float(_cc1)); - _f5 = vaddq_f32(_f5, bfloat2float(_cc3)); - _f6 = vaddq_f32(_f6, bfloat2float(_cc5)); - _f7 = vaddq_f32(_f7, bfloat2float(_cc7)); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc2); + float32x4_t _c2 = bfloat2float(_cc4); + float32x4_t _c3 = bfloat2float(_cc6); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = bfloat2float(_cc1); + _c1 = bfloat2float(_cc3); + _c2 = bfloat2float(_cc5); + _c3 = bfloat2float(_cc7); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } pC += 4; } else // if (c_elempack == 4) @@ -11108,8 +11157,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c = vld1q_u16(pC); _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); From 6302017678428acd799e44f0abf50bc2a3d372fc Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 26 Sep 2024 18:39:42 +0800 Subject: [PATCH 13/55] build --- src/layer/arm/gemm_int8_bf16s.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 786b3310368c..3f58bd644385 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -10143,10 +10143,17 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { float32x4_t _c = bfloat2float(vld1_u16(pC)); _c = vmulq_n_f32(_c, beta); +#if __aarch64__ _c0 = vdupq_laneq_f32(_c, 0); float32x4_t _c1 = vdupq_laneq_f32(_c, 1); float32x4_t _c2 = vdupq_laneq_f32(_c, 2); float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); From dcd063636a33b0b677f0f1e301f7496735b9bd8e Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 26 Sep 2024 19:04:51 +0800 Subject: [PATCH 
14/55] opt++ --- src/layer/arm/gemm_int8.h | 134 +++++++++++++++++-------------- src/layer/arm/gemm_int8_bf16s.h | 137 ++++++++++++++++---------------- 2 files changed, 143 insertions(+), 128 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 10045d2f7221..fd8f56985b78 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -10137,7 +10137,6 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 8x8 to 8x4 and 8x4 _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + 4); float32x4_t _c2 = vld1q_f32(pC + c_hstep); @@ -10146,53 +10145,61 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - float32x4_t _c8 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c9 = vld1q_f32(pC + c_hstep * 4 + 4); - float32x4_t _ca = vld1q_f32(pC + c_hstep * 5); - float32x4_t _cb = vld1q_f32(pC + c_hstep * 5 + 4); - float32x4_t _cc = vld1q_f32(pC + c_hstep * 6); - float32x4_t _cd = vld1q_f32(pC + c_hstep * 6 + 4); - float32x4_t _ce = vld1q_f32(pC + c_hstep * 7); - float32x4_t _cf = vld1q_f32(pC + c_hstep * 7 + 4); - transpose8x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7, _c8, _c9, _ca, _cb, _cc, _cd, _ce, _cf); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c2); - _f2 = vaddq_f32(_f2, _c4); - _f3 = vaddq_f32(_f3, _c6); - _f4 = vaddq_f32(_f4, _c8); - _f5 = vaddq_f32(_f5, _ca); - _f6 = vaddq_f32(_f6, _cc); - _f7 = vaddq_f32(_f7, _ce); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c3); - _fa = vaddq_f32(_fa, _c5); - _fb = vaddq_f32(_fb, _c7); - _fc = vaddq_f32(_fc, _c9); - _fd = vaddq_f32(_fd, _cb); - _fe = vaddq_f32(_fe, _cd); - _ff = vaddq_f32(_ff, _cf); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c2, _beta); - _f2 = vmlaq_f32(_f2, _c4, _beta); - _f3 = vmlaq_f32(_f3, _c6, _beta); - _f4 = vmlaq_f32(_f4, _c8, _beta); - _f5 = vmlaq_f32(_f5, _ca, _beta); - _f6 = vmlaq_f32(_f6, _cc, _beta); - _f7 = vmlaq_f32(_f7, _ce, _beta); - _f8 = vmlaq_f32(_f8, _c1, _beta); - _f9 = vmlaq_f32(_f9, _c3, _beta); - _fa = vmlaq_f32(_fa, _c5, _beta); - _fb = vmlaq_f32(_fb, _c7, _beta); - _fc = vmlaq_f32(_fc, _c9, _beta); - _fd = vmlaq_f32(_fd, _cb, _beta); - _fe = vmlaq_f32(_fe, _cd, _beta); - _ff = vmlaq_f32(_ff, _cf, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, 
_c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } pC += 8; } @@ -10451,38 +10458,45 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 4x8 to 4x4 and 4x4 _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + c_hstep); float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 5); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 6); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 7); - transpose4x8_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + transpose4x4_ps(_c0, _c1, _c2, _c3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c2); - _f2 = vaddq_f32(_f2, _c4); - _f3 = vaddq_f32(_f3, _c6); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c3); - _f6 = vaddq_f32(_f6, _c5); - _f7 = vaddq_f32(_f7, _c7); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c2, _beta); - _f2 = vmlaq_f32(_f2, _c4, _beta); - _f3 = vmlaq_f32(_f3, _c6, _beta); - _f4 = vmlaq_f32(_f4, _c1, _beta); - _f5 = vmlaq_f32(_f5, _c3, _beta); - _f6 = vmlaq_f32(_f6, _c5, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } pC += 4; } diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 3f58bd644385..989cdf6866ac 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -5870,19 +5870,13 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } else // if (c_elempack == 4) { - // TODO optimize uint16x8_t _cc0 = vld1q_u16(pC); uint16x8_t _cc1 = vld1q_u16(pC + c_hstep * 4); - _c0 = bfloat2float(vget_low_u16(_cc0)); - _c1 = bfloat2float(vget_high_u16(_cc0)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_cc1)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_cc1)); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - float32x4x2_t _c23 = vzipq_f32(_c2, _c3); - _c0 = vcombine_f32(vget_low_f32(_c01.val[0]), vget_low_f32(_c01.val[1])); - _c1 = vcombine_f32(vget_high_f32(_c01.val[0]), vget_high_f32(_c01.val[1])); - _c2 = vcombine_f32(vget_low_f32(_c23.val[0]), vget_low_f32(_c23.val[1])); - _c3 = vcombine_f32(vget_high_f32(_c23.val[0]), vget_high_f32(_c23.val[1])); + uint16x8x2_t _cc = vzipq_u16(vcombine_u16(vget_low_u16(_cc0), vget_low_u16(_cc1)), vcombine_u16(vget_high_u16(_cc0), vget_high_u16(_cc1))); + _c0 = 
bfloat2float(vget_low_u16(_cc.val[0])); + _c1 = bfloat2float(vget_high_u16(_cc.val[0])); + _c2 = bfloat2float(vget_low_u16(_cc.val[1])); + _c3 = bfloat2float(vget_high_u16(_cc.val[1])); pC += 8; } if (beta == 1.f) @@ -8703,16 +8697,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 8x8 to 8x4 and 8x4 uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + c_hstep); uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - uint16x8_t _c89 = vld1q_u16(pC + c_hstep * 4); - uint16x8_t _cab = vld1q_u16(pC + c_hstep * 5); - uint16x8_t _ccd = vld1q_u16(pC + c_hstep * 6); - uint16x8_t _cef = vld1q_u16(pC + c_hstep * 7); - transpose8x8_u16(_c01, _c23, _c45, _c67, _c89, _cab, _ccd, _cef); + transpose8x4_u16(_c01, _c23, _c45, _c67); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -8721,52 +8710,64 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - float32x4_t _c8 = bfloat2float(vget_low_u16(_c89)); - float32x4_t _c9 = bfloat2float(vget_high_u16(_c89)); - float32x4_t _ca = bfloat2float(vget_low_u16(_cab)); - float32x4_t _cb = bfloat2float(vget_high_u16(_cab)); - float32x4_t _cc = bfloat2float(vget_low_u16(_ccd)); - float32x4_t _cd = bfloat2float(vget_high_u16(_ccd)); - float32x4_t _ce = bfloat2float(vget_low_u16(_cef)); - float32x4_t _cf = bfloat2float(vget_high_u16(_cef)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c2); - _f2 = vaddq_f32(_f2, _c4); - _f3 = vaddq_f32(_f3, _c6); - _f4 = vaddq_f32(_f4, _c8); - _f5 = vaddq_f32(_f5, _ca); - _f6 = vaddq_f32(_f6, _cc); - _f7 = vaddq_f32(_f7, _ce); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c3); - _fa = vaddq_f32(_fa, _c5); - _fb = vaddq_f32(_fb, _c7); - _fc = vaddq_f32(_fc, _c9); - _fd = vaddq_f32(_fd, _cb); - _fe = vaddq_f32(_fe, _cd); - _ff = vaddq_f32(_ff, _cf); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c2, _beta); - _f2 = vmlaq_f32(_f2, _c4, _beta); - _f3 = vmlaq_f32(_f3, _c6, _beta); - _f4 = vmlaq_f32(_f4, _c8, _beta); - _f5 = vmlaq_f32(_f5, _ca, _beta); - _f6 = vmlaq_f32(_f6, _cc, _beta); - _f7 = vmlaq_f32(_f7, _ce, _beta); - _f8 = vmlaq_f32(_f8, _c1, _beta); - _f9 = vmlaq_f32(_f9, _c3, _beta); - _fa = vmlaq_f32(_fa, _c5, _beta); - _fb = vmlaq_f32(_fb, _c7, _beta); - _fc = vmlaq_f32(_fc, _c9, _beta); - _fd = vmlaq_f32(_fd, _cb, _beta); - _fe = vmlaq_f32(_fe, _cd, _beta); - _ff = vmlaq_f32(_ff, _cf, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = 
bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } pC += 8; } @@ -9026,20 +9027,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { if (c_elempack == 1) { - // TODO decompose 4x8 to 4x4 and 4x4 uint16x4_t _cc0 = vld1_u16(pC); uint16x4_t _cc1 = vld1_u16(pC + c_hstep); uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - uint16x4_t _cc4 = vld1_u16(pC + c_hstep * 4); - uint16x4_t _cc5 = vld1_u16(pC + c_hstep * 5); - uint16x4_t _cc6 = vld1_u16(pC + c_hstep * 6); - uint16x4_t _cc7 = vld1_u16(pC + c_hstep * 7); - transpose4x8_u16(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc2); - float32x4_t _c2 = bfloat2float(_cc4); - float32x4_t _c3 = bfloat2float(_cc6); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -9055,10 +9051,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _c0 = bfloat2float(_cc1); - _c1 = bfloat2float(_cc3); - _c2 = bfloat2float(_cc5); - _c3 = bfloat2float(_cc7); + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0); From 09af8476e644d0fafd99aea2eb18e7e4d0bb983c Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 26 Sep 2024 19:12:05 +0800 Subject: [PATCH 15/55] fix --- src/layer/arm/gemm_arm.cpp | 11 ----------- src/layer/arm/gemm_int8_bf16s.h | 16 ++++++++-------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 04f775f44793..676e4492ca39 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -58,8 +58,6 @@ void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk) const int elempack = A.elempack; const int A_hstep = A.dims == 3 ? 
(int)A.cstep : A.w; - NCNN_LOGE("pack_A_tile %d %d %d %d %d %d", i, max_ii, k, max_kk, elempack, A_hstep); - float* pp = AT; int ii = 0; @@ -6037,15 +6035,6 @@ int Gemm_arm::create_pipeline_int8(const Option& opt) } #endif -#if __ARM_NEON - if (constant_broadcast_type_C == 3 && opt.use_packing_layout && CT_data.h % 4 == 0) - { - Mat C2; - ncnn::convert_packing(CT_data, C2, 4, opt); - CT_data = C2; - } -#endif - if (opt.lightmode) C_data.release(); } diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 989cdf6866ac..77e6f790bb9c 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -4923,15 +4923,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); - _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); - _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); - _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); - _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c01 = vsetq_lane_u16(pC[1], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep + 1], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c01, 7); _c23 = uint16x8_t(); - _c23 = vsetq_lane_u16(pC[1], _c23, 0); - _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); - _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); - _c23 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep * 5], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 6], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 7], _c23, 3); _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); From 8068e6cd62da10f3de47d532fb6532843c7624c0 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 27 Sep 2024 11:01:22 +0800 Subject: [PATCH 16/55] test++ --- tests/test_gemm_3.cpp | 99 +++++++++++++++++-------------------------- 1 file changed, 40 insertions(+), 59 deletions(-) diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index 271d1a6bdaab..0d2f257c555f 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -14,7 +14,7 @@ #include "testutil.h" -static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M = 0) +static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M) { ncnn::ParamDict pd; pd.set(0, alpha); @@ -138,25 +138,35 @@ static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float al static int test_gemm_0(int M, int N, int K) { return 0 - || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0) - || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 0, 0, 0) - || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 0, 0, 0) - || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 0, 0, 0) - - || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0, 1) - || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 0, 0, 0, 1) - || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 0, 0, 0, 1) - || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 0, 0, 0, 1) - - || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 1, 0, 0) - || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 1, 0, 0) - || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 1, 0, 0) - || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 
1, 0, 0) - - || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 1, 0, 0, 1) - || test_gemm_int8(M, N, K, 2.1f, 1, 1, 0, 1, 0, 0, 1) - || test_gemm_int8(M, N, K, 2.1f, 0, 0, 0, 1, 0, 0, 1) - || test_gemm_int8(M, N, K, 2.1f, 1, 0, 0, 1, 0, 0, 1); + || test_gemm_int8(M, N, K, 2.1f, 0, 1, 0, 0, 0, 0, 0) + || test_gemm_int8(M, N, K, 3.1f, 1, 1, 0, 0, 0, 0, 0) + || test_gemm_int8(M, N, K, 4.1f, 0, 0, 0, 0, 0, 0, 1) + || test_gemm_int8(M, N, K, 5.1f, 1, 0, 0, 0, 0, 0, 1) + + || test_gemm_int8(M, N, K, 0.2f, 0, 1, 0, 0, 1, 0, 1) + || test_gemm_int8(M, N, K, 0.3f, 1, 1, 0, 0, 1, 0, 1) + || test_gemm_int8(M, N, K, 0.4f, 0, 0, 0, 0, 0, 1, 0) + || test_gemm_int8(M, N, K, 0.5f, 0, 1, 0, 0, 0, 1, 0) + + || test_gemm_int8(M, N, K, 1.2f, 0, 1, 0, 0, 1, 1, 0) + || test_gemm_int8(M, N, K, 1.3f, 1, 1, 0, 0, 1, 1, 1) + || test_gemm_int8(M, N, K, 1.4f, 0, 0, 0, 0, 1, 1, 0) + || test_gemm_int8(M, N, K, 1.5f, 1, 0, 0, 0, 1, 1, 1) + + || test_gemm_int8(M, N, K, -1.2f, 0, 1, 0, 1, 0, 0, 0) + || test_gemm_int8(M, N, K, -1.3f, 1, 1, 0, 1, 0, 0, 0) + || test_gemm_int8(M, N, K, -1.4f, 0, 0, 0, 1, 0, 0, 1) + || test_gemm_int8(M, N, K, -1.5f, 1, 0, 0, 1, 0, 0, 1) + + || test_gemm_int8(M, N, K, -2.0f, 0, 1, 0, 1, 1, 0, 1) + || test_gemm_int8(M, N, K, -3.0f, 1, 1, 0, 1, 1, 0, 1) + || test_gemm_int8(M, N, K, -4.0f, 0, 0, 0, 1, 0, 1, 0) + || test_gemm_int8(M, N, K, -5.0f, 0, 1, 0, 1, 0, 1, 0) + + || test_gemm_int8(M, N, K, -2.1f, 0, 1, 0, 1, 1, 1, 0) + || test_gemm_int8(M, N, K, -3.1f, 1, 1, 0, 1, 1, 1, 1) + || test_gemm_int8(M, N, K, -4.1f, 0, 0, 0, 1, 1, 1, 0) + || test_gemm_int8(M, N, K, -5.1f, 1, 0, 0, 1, 1, 1, 1); } static int test_gemm_1(int M, int N, int K) @@ -165,16 +175,16 @@ static int test_gemm_1(int M, int N, int K) || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 1, 0, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 2, 1, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 3, 1, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 1, 0, 0, 0, 0) - - || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 2, 0, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 3, 0, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 0, 1, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, 0.8f, 1, 1, 1, 1, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, 0.5f, 0, 0, 2, 0, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, 0.6f, 0, 1, 3, 0, 1, 1, 1); + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 3, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 1, 0, 0, 0, 0) + + || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 3, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 0, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 0, 1, 1, 1); } int main() @@ -232,34 +242,5 @@ int main() return 0; } - for (int M = 1; M <= 15; M++) - { 
- for (int N = 1; N <= 15; N++) - { - for (int K = 1; K <= 15; K++) - { - // int ret = 0 - // || test_gemm_0(M, N, K) - // || test_gemm_1(M, N, K); - - // int ret = test_gemm_0(M, N, K); - - ncnn::Mat C(N, M); - Randomize(C, -100.f, 100.f); - - // fprintf(stderr, "C %f\n", C[0]); - - int ret = test_gemm_int8_bias(M, N, K, C, 1.f, 1.f, 0, 0, 0, 0, 0, 0, 1); - - if (ret != 0) - return 0; - - ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); - if (ret != 0) - return 0; - } - } - } - return 0; } From a881bd9bf3c8f750fb31fea294909ffe2e5f19b3 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 27 Sep 2024 11:57:43 +0800 Subject: [PATCH 17/55] less openmp args --- src/layer/arm/gemm_arm.cpp | 61 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 676e4492ca39..7ca50b56e758 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -5301,6 +5301,18 @@ int Gemm_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector Date: Fri, 27 Sep 2024 03:51:21 +0000 Subject: [PATCH 18/55] apply code-format changes --- src/layer/arm/gemm_arm.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 7ca50b56e758..27dbc93423aa 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -5469,7 +5469,7 @@ static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob if (topT.empty()) return -100; - const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta }; + const struct gemm_arm_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta}; #pragma omp parallel for num_threads(nT) for (int ppi = 0; ppi < nn_M; ppi++) @@ -5720,7 +5720,7 @@ static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat& if (topT.empty()) return -100; - const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, alpha, beta }; + const struct gemm_arm_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, alpha, beta}; #pragma omp parallel for num_threads(nT) for (int ppi = 0; ppi < nn_M; ppi++) @@ -5812,7 +5812,7 @@ static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, con if (topT.empty()) return -100; - const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta }; + const struct gemm_arm_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta}; #pragma omp parallel for num_threads(nT) for (int ppi = 0; ppi < nn_M; ppi++) @@ -5938,7 +5938,7 @@ static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Ma if (topT.empty()) return -100; - const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, alpha, beta }; + const struct gemm_arm_int8_omp_args args = {TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, alpha, beta}; #pragma omp parallel for num_threads(nT) for (int ppi = 0; ppi < nn_M; ppi++) From 4798d10bb25121a0ba3206ab31c40bed47024fea Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 27 Sep 2024 14:44:19 +0800 Subject: [PATCH 19/55] revert cpu runtime off build behavior --- cmake/ncnn_add_layer.cmake | 38 +++++++++++++++++++------------------- 1 file changed, 19 
insertions(+), 19 deletions(-) diff --git a/cmake/ncnn_add_layer.cmake b/cmake/ncnn_add_layer.cmake index 6ce5feadbf31..0b0fb3233334 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -133,15 +133,15 @@ macro(ncnn_add_layer class) set(layer_registry_vulkan "${layer_registry_vulkan}#if NCNN_STRING\n{\"${class}\", 0},\n#else\n{0},\n#endif\n") endif() - if(NCNN_TARGET_ARCH STREQUAL "x86") + if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "x86") if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - if(NCNN_RUNTIME_CPU AND NCNN_AVX512) + if(NCNN_AVX512) ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_RUNTIME_CPU AND NCNN_FMA) + if(NCNN_FMA) ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_RUNTIME_CPU AND NCNN_AVX) + if(NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__") endif() if(NCNN_AVX512VNNI) @@ -166,13 +166,13 @@ macro(ncnn_add_layer class) ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__F16C__") endif() elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") - if(NCNN_RUNTIME_CPU AND NCNN_AVX512) + if(NCNN_AVX512) ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_RUNTIME_CPU AND NCNN_FMA) + if(NCNN_FMA) ncnn_add_arch_opt_layer(${class} fma "/arch:AVX -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_RUNTIME_CPU AND NCNN_AVX) + if(NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__") endif() if(NCNN_AVX512VNNI) @@ -197,13 +197,13 @@ macro(ncnn_add_layer class) ncnn_add_arch_opt_source(${class} f16c "/arch:AVX -mf16c /D__SSSE3__ /D__SSE4_1__ /D__F16C__") endif() else() - if(NCNN_RUNTIME_CPU AND NCNN_AVX512) + if(NCNN_AVX512) ncnn_add_arch_opt_layer(${class} avx512 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c") endif() - if(NCNN_RUNTIME_CPU AND NCNN_FMA) + if(NCNN_FMA) ncnn_add_arch_opt_layer(${class} fma "-mavx -mfma -mf16c") endif() - if(NCNN_RUNTIME_CPU AND NCNN_AVX) + if(NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "-mavx") endif() if(NCNN_AVX512VNNI) @@ -230,7 +230,7 @@ macro(ncnn_add_layer class) endif() endif() - if(NCNN_TARGET_ARCH STREQUAL "arm" AND (CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)) + if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "arm" AND (CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) if(NCNN_VFPV4) ncnn_add_arch_opt_source(${class} vfpv4 "/arch:VFPv4 /D__ARM_FP=0x0E") @@ -246,7 +246,7 @@ macro(ncnn_add_layer class) endif() endif() - if(NCNN_TARGET_ARCH STREQUAL "arm" AND (CMAKE_SIZEOF_VOID_P EQUAL 8 OR NCNN_TARGET_ILP32)) + if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "arm" AND (CMAKE_SIZEOF_VOID_P EQUAL 8 OR NCNN_TARGET_ILP32)) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") if(NCNN_VFPV4) ncnn_add_arch_opt_source(${class} vfpv4 " ") @@ -344,8 +344,8 @@ macro(ncnn_add_layer class) endif() endif() - if(NCNN_TARGET_ARCH STREQUAL "mips") - if(NCNN_RUNTIME_CPU AND NCNN_MSA) + if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "mips") + if(NCNN_MSA) 
ncnn_add_arch_opt_layer(${class} msa "-mmsa") endif() if(NCNN_MMI) @@ -353,17 +353,17 @@ macro(ncnn_add_layer class) endif() endif() - if(NCNN_TARGET_ARCH STREQUAL "loongarch") - if(NCNN_RUNTIME_CPU AND NCNN_LASX) + if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "loongarch") + if(NCNN_LASX) ncnn_add_arch_opt_layer(${class} lasx "-mlasx -mlsx") endif() - if(NCNN_RUNTIME_CPU AND NCNN_LSX) + if(NCNN_LSX) ncnn_add_arch_opt_layer(${class} lsx "-mlsx") endif() endif() - if(NCNN_TARGET_ARCH STREQUAL "riscv" AND CMAKE_SIZEOF_VOID_P EQUAL 8) - if(NCNN_RUNTIME_CPU AND NCNN_RVV) + if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "riscv" AND CMAKE_SIZEOF_VOID_P EQUAL 8) + if(NCNN_RVV) if(NCNN_COMPILER_SUPPORT_RVV_ZFH) ncnn_add_arch_opt_layer(${class} rvv "-march=rv64gcv_zfh") elseif(NCNN_COMPILER_SUPPORT_RVV_ZVFH) From 392e38b2404fa1ea54ec1d1649ed2a70eab1cc3b Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 27 Sep 2024 15:14:32 +0800 Subject: [PATCH 20/55] revert 2nd --- cmake/ncnn_add_layer.cmake | 134 ++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 67 deletions(-) diff --git a/cmake/ncnn_add_layer.cmake b/cmake/ncnn_add_layer.cmake index 0b0fb3233334..7f334fb0b68d 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -133,104 +133,104 @@ macro(ncnn_add_layer class) set(layer_registry_vulkan "${layer_registry_vulkan}#if NCNN_STRING\n{\"${class}\", 0},\n#else\n{0},\n#endif\n") endif() - if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "x86") + if(NCNN_TARGET_ARCH STREQUAL "x86") if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - if(NCNN_AVX512) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512) ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_FMA) + if(NCNN_RUNTIME_CPU AND NCNN_FMA) ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_AVX) + if(NCNN_RUNTIME_CPU AND NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__") endif() - if(NCNN_AVX512VNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI) ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") endif() - if(NCNN_AVX512BF16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16) ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__") endif() - if(NCNN_AVX512FP16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16) ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__") endif() - if(NCNN_AVXVNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() - if(NCNN_AVX2) + if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_XOP) + if(NCNN_RUNTIME_CPU AND NCNN_XOP) ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__XOP__") endif() - if(NCNN_F16C) + if(NCNN_RUNTIME_CPU AND NCNN_F16C) ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__F16C__") endif() elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") - if(NCNN_AVX512) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512) ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq 
-mavx512vl -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_FMA) + if(NCNN_RUNTIME_CPU AND NCNN_FMA) ncnn_add_arch_opt_layer(${class} fma "/arch:AVX -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_AVX) + if(NCNN_RUNTIME_CPU AND NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__") endif() - if(NCNN_AVX512VNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI) ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") endif() - if(NCNN_AVX512BF16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16) ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__") endif() - if(NCNN_AVX512FP16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16) ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__") endif() - if(NCNN_AVXVNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() - if(NCNN_AVX2) + if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() - if(NCNN_XOP) + if(NCNN_RUNTIME_CPU AND NCNN_XOP) ncnn_add_arch_opt_source(${class} xop "/arch:AVX -mxop /D__SSSE3__ /D__SSE4_1__ /D__XOP__") endif() - if(NCNN_F16C) + if(NCNN_RUNTIME_CPU AND NCNN_F16C) ncnn_add_arch_opt_source(${class} f16c "/arch:AVX -mf16c /D__SSSE3__ /D__SSE4_1__ /D__F16C__") endif() else() - if(NCNN_AVX512) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512) ncnn_add_arch_opt_layer(${class} avx512 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c") endif() - if(NCNN_FMA) + if(NCNN_RUNTIME_CPU AND NCNN_FMA) ncnn_add_arch_opt_layer(${class} fma "-mavx -mfma -mf16c") endif() - if(NCNN_AVX) + if(NCNN_RUNTIME_CPU AND NCNN_AVX) ncnn_add_arch_opt_layer(${class} avx "-mavx") endif() - if(NCNN_AVX512VNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI) ncnn_add_arch_opt_source(${class} avx512vnni "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni") endif() - if(NCNN_AVX512BF16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16) ncnn_add_arch_opt_source(${class} avx512bf16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16") endif() - if(NCNN_AVX512FP16) + if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16) ncnn_add_arch_opt_source(${class} avx512fp16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16") endif() - if(NCNN_AVXVNNI) + if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI) ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni") endif() - if(NCNN_AVX2) + if(NCNN_RUNTIME_CPU AND NCNN_AVX2) ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c") endif() - if(NCNN_XOP) + if(NCNN_RUNTIME_CPU AND NCNN_XOP) ncnn_add_arch_opt_source(${class} xop "-mavx -mxop") endif() - if(NCNN_F16C) + if(NCNN_RUNTIME_CPU AND NCNN_F16C) ncnn_add_arch_opt_source(${class} f16c "-mavx -mf16c") endif() endif() endif() - if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "arm" AND (CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)) + if(NCNN_TARGET_ARCH STREQUAL "arm" AND 
(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) if(NCNN_VFPV4) ncnn_add_arch_opt_source(${class} vfpv4 "/arch:VFPv4 /D__ARM_FP=0x0E") @@ -246,7 +246,7 @@ macro(ncnn_add_layer class) endif() endif() - if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "arm" AND (CMAKE_SIZEOF_VOID_P EQUAL 8 OR NCNN_TARGET_ILP32)) + if(NCNN_TARGET_ARCH STREQUAL "arm" AND (CMAKE_SIZEOF_VOID_P EQUAL 8 OR NCNN_TARGET_ILP32)) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") if(NCNN_VFPV4) ncnn_add_arch_opt_source(${class} vfpv4 " ") @@ -254,28 +254,28 @@ macro(ncnn_add_layer class) if(NCNN_ARM82) ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM82DOT) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT) ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD") endif() - if(NCNN_ARM82FP16FML) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML) ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML") endif() - if(NCNN_ARM84BF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16) ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM84I8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM) ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8") endif() # TODO add support for sve family - if(NCNN_ARM86SVE) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE) endif() - if(NCNN_ARM86SVE2) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2) endif() - if(NCNN_ARM86SVEBF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16) endif() - if(NCNN_ARM86SVEI8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM) endif() - if(NCNN_ARM86SVEF32MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM) endif() elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") if(NCNN_VFPV4) @@ -284,28 +284,28 @@ macro(ncnn_add_layer class) if(NCNN_ARM82) ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 -march=armv8.2-a+fp16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM82DOT) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT) ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 -march=armv8.2-a+fp16+dotprod /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD") endif() - if(NCNN_ARM82FP16FML) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML) ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 -march=armv8.2-a+fp16+fp16fml /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML") endif() - if(NCNN_ARM84BF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16) ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+bf16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC") endif() - if(NCNN_ARM84I8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM) ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+i8mm /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8") endif() # TODO add support for 
sve family - if(NCNN_ARM86SVE) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE) endif() - if(NCNN_ARM86SVE2) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2) endif() - if(NCNN_ARM86SVEBF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16) endif() - if(NCNN_ARM86SVEI8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM) endif() - if(NCNN_ARM86SVEF32MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM) endif() else() if(NCNN_VFPV4) @@ -314,38 +314,38 @@ macro(ncnn_add_layer class) if(NCNN_ARM82) ncnn_add_arch_opt_source(${class} asimdhp "-march=armv8.2-a+fp16") endif() - if(NCNN_ARM82DOT) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT) ncnn_add_arch_opt_source(${class} asimddp "-march=armv8.2-a+fp16+dotprod") endif() - if(NCNN_ARM82FP16FML) + if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML) ncnn_add_arch_opt_source(${class} asimdfhm "-march=armv8.2-a+fp16+fp16fml") endif() - if(NCNN_ARM84BF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16) ncnn_add_arch_opt_source(${class} bf16 "-march=armv8.4-a+fp16+dotprod+bf16") endif() - if(NCNN_ARM84I8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM) ncnn_add_arch_opt_source(${class} i8mm "-march=armv8.4-a+fp16+dotprod+i8mm") endif() - if(NCNN_ARM86SVE) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE) ncnn_add_arch_opt_source(${class} sve "-march=armv8.6-a+fp16+dotprod+sve") endif() - if(NCNN_ARM86SVE2) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2) ncnn_add_arch_opt_source(${class} sve2 "-march=armv8.6-a+fp16+dotprod+sve2") endif() - if(NCNN_ARM86SVEBF16) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16) ncnn_add_arch_opt_source(${class} svebf16 "-march=armv8.6-a+fp16+dotprod+sve+bf16") endif() - if(NCNN_ARM86SVEI8MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM) ncnn_add_arch_opt_source(${class} svei8mm "-march=armv8.6-a+fp16+dotprod+sve+i8mm") endif() - if(NCNN_ARM86SVEF32MM) + if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM) ncnn_add_arch_opt_source(${class} svef32mm "-march=armv8.6-a+fp16+dotprod+sve+f32mm") endif() endif() endif() - if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "mips") - if(NCNN_MSA) + if(NCNN_TARGET_ARCH STREQUAL "mips") + if(NCNN_RUNTIME_CPU AND NCNN_MSA) ncnn_add_arch_opt_layer(${class} msa "-mmsa") endif() if(NCNN_MMI) @@ -353,17 +353,17 @@ macro(ncnn_add_layer class) endif() endif() - if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "loongarch") - if(NCNN_LASX) + if(NCNN_TARGET_ARCH STREQUAL "loongarch") + if(NCNN_RUNTIME_CPU AND NCNN_LASX) ncnn_add_arch_opt_layer(${class} lasx "-mlasx -mlsx") endif() - if(NCNN_LSX) + if(NCNN_RUNTIME_CPU AND NCNN_LSX) ncnn_add_arch_opt_layer(${class} lsx "-mlsx") endif() endif() - if(NCNN_RUNTIME_CPU AND NCNN_TARGET_ARCH STREQUAL "riscv" AND CMAKE_SIZEOF_VOID_P EQUAL 8) - if(NCNN_RVV) + if(NCNN_TARGET_ARCH STREQUAL "riscv" AND CMAKE_SIZEOF_VOID_P EQUAL 8) + if(NCNN_RUNTIME_CPU AND NCNN_RVV) if(NCNN_COMPILER_SUPPORT_RVV_ZFH) ncnn_add_arch_opt_layer(${class} rvv "-march=rv64gcv_zfh") elseif(NCNN_COMPILER_SUPPORT_RVV_ZVFH) From 207e166d187d10c501455730ede8be622e0723ef Mon Sep 17 00:00:00 2001 From: nihuini Date: Sun, 29 Sep 2024 19:10:51 +0800 Subject: [PATCH 21/55] vfpv4 fp16 --- src/layer/arm/gemm_arm.cpp | 602 +- src/layer/arm/gemm_arm.h | 2 + src/layer/arm/gemm_arm_asimddp.cpp | 31 + src/layer/arm/gemm_arm_i8mm.cpp | 21 + src/layer/arm/gemm_arm_vfpv4.cpp | 51 + src/layer/arm/gemm_int8.h | 36 + src/layer/arm/gemm_int8_bf16s.h | 83 +- src/layer/arm/gemm_int8_fp16s.h | 12834 +++++++++++++++++++++++++++ 8 files changed, 13360 insertions(+), 300 deletions(-) create mode 100644 src/layer/arm/gemm_int8_fp16s.h diff --git a/src/layer/arm/gemm_arm.cpp 
b/src/layer/arm/gemm_arm.cpp index 27dbc93423aa..cb25ee6a0a49 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -4250,7 +4250,7 @@ int Gemm_arm::create_pipeline(const Option& opt) if (int8_scale_term) { // support_packing = false; - support_fp16_storage = false; + // support_fp16_storage = false; // support_bf16_storage = false; return create_pipeline_int8(opt); // return 0; @@ -5301,6 +5301,208 @@ int Gemm_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector 4 ? 4 : out_elempack; + if (output_elempack) out_elempack = output_elempack; size_t out_elemsize = 4u * out_elempack; + // FIXME use output_elemtype instead of input_elemtype + int output_elemtype = input_elemtype; + + // TODO use output_elemtype if (opt.use_bf16_storage) { out_elemsize = 2u * out_elempack; } +#if NCNN_VFPV4 + else if (support_fp16_storage && opt.use_fp16_storage) + { + out_elemsize = 2u * out_elempack; + } +#endif Mat& top_blob = top_blobs[0]; if (output_transpose) @@ -6241,23 +6259,23 @@ int Gemm_arm::forward_int8(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector Date: Sun, 29 Sep 2024 11:04:53 +0000 Subject: [PATCH 22/55] apply code-format changes --- src/layer/arm/gemm_arm.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index cb25ee6a0a49..1630d7db5c07 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -5976,16 +5976,16 @@ int Gemm_arm::create_pipeline_int8(const Option& opt) else #endif #if NCNN_VFPV4 - if (ncnn::cpu_support_arm_vfpv4()) - { - use_bf16 = opt.use_bf16_storage; - use_fp16 = opt.use_fp16_storage && !opt.use_bf16_storage; - } - else + if (ncnn::cpu_support_arm_vfpv4()) + { + use_bf16 = opt.use_bf16_storage; + use_fp16 = opt.use_fp16_storage && !opt.use_bf16_storage; + } + else #endif - { - input_elemtype = 1; // fp32 - } + { + input_elemtype = 1; // fp32 + } if (use_fp16) input_elemtype = 2; if (use_bf16) input_elemtype = 3; From 93fad8d41b46fb39aa0f2b3eee1e6196646af964 Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 29 Sep 2024 11:45:07 +0000 Subject: [PATCH 23/55] apply code-format changes --- src/layer/arm/gemm_int8_fp16s.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index a84c229d3276..c38d203c8fd5 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -78,7 +78,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s } float32x4_t _absmax0 = vcvt_f32_f16(vget_low_f16(_absmax)); float32x4_t _absmax1 = vcvt_f32_f16(vget_high_f16(_absmax)); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -179,7 +179,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s p0 += 4; } float32x4_t _absmax0 = vcvt_f32_f16(_absmax); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -1474,7 +1474,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, } float16x4_t _aa = vmax_f16(vget_low_f16(_absmax), vget_high_f16(_absmax)); float absmax = vmaxvq_f32(vcvt_f32_f16(_aa)); -#else // 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 8; float32x4_t _absmax0 = vdupq_n_f32(0.f); From eb0c83336a70f743f99be4a032e1615bfe4cafbe Mon Sep 17 00:00:00 2001 From: nihuini Date: Sun, 29 Sep 2024 19:51:20 +0800 Subject: [PATCH 24/55] stash --- src/layer/arm/gemm_arm_asimdhp.cpp | 21 +++++++ src/layer/arm/gemm_int8_fp16s.h | 99 ++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/src/layer/arm/gemm_arm_asimdhp.cpp b/src/layer/arm/gemm_arm_asimdhp.cpp index cb0aa87e4add..dd5d4e6f8460 100644 --- a/src/layer/arm/gemm_arm_asimdhp.cpp +++ b/src/layer/arm/gemm_arm_asimdhp.cpp @@ -27,6 +27,10 @@ namespace ncnn { #include "gemm_bf16s_fp16s.h" #include "gemm_fp16s.h" +#if NCNN_INT8 +#include "gemm_int8_fp16s.h" +#endif + static void gemm_transB_packed_tile_fp16sa(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end) { const int out_elempack = top_blob.elempack; @@ -3026,4 +3030,21 @@ int Gemm_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector Date: Mon, 30 Sep 2024 15:02:27 +0800 Subject: [PATCH 25/55] stash --- src/layer/arm/gemm_int8_fp16s.h | 105 +++++++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 10 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index a84c229d3276..a6886b4735e0 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -68,16 +68,42 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; - float16x8_t _absmax = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + 16); + float16x8_t _p3 = vld1q_f16(p0 + 24); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(_p2)); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(_p3)); + p0 += 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + p0 += 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); for (; kk < K; kk++) { float16x8_t _p = vld1q_f16(p0); - _absmax = vmaxq_f16(_absmax, vabsq_f16(_p)); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); p0 += 8; } - float32x4_t _absmax0 = vcvt_f32_f16(vget_low_f16(_absmax)); - float32x4_t _absmax1 = vcvt_f32_f16(vget_high_f16(_absmax)); + float32x4_t _absmax0 = vcvt_f32_f16(vget_low_f16(_amax0)); + float32x4_t _absmax1 = vcvt_f32_f16(vget_high_f16(_amax0)); #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; @@ -170,15 +196,48 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; - float16x4_t _absmax = 
vdup_n_f16((__fp16)0.f); + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); int kk = 0; + for (; kk + 7 < K; kk += 8) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + 16); + float16x8_t _p3 = vld1q_f16(p0 + 24); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(_p2)); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(_p3)); + p0 += 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 3 < K; kk += 4) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + p0 += 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p = vld1q_f16(p0); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); + p0 += 8; + } + float16x4_t _amax = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); for (; kk < K; kk++) { float16x4_t _p = vld1_f16(p0); - _absmax = vmax_f16(_absmax, vabs_f16(_p)); + _amax = vmax_f16(_amax, vabs_f16(_p)); p0 += 4; } - float32x4_t _absmax0 = vcvt_f32_f16(_absmax); + float32x4_t _absmax0 = vcvt_f32_f16(_amax); #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; @@ -262,15 +321,41 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; __fp16 absmax = 0.f; + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); int kk = 0; - float16x8_t _absmax = vdupq_n_f16((__fp16)0.f); + for (; kk + 31 < K; kk += 32) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + 16); + float16x8_t _p3 = vld1q_f16(p0 + 24); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(_p2)); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(_p3)); + p0 += 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 15 < K; kk += 16) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p0)); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(_p1)); + p0 += 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); for (; kk + 7 < K; kk += 8) { float16x8_t _p = vld1q_f16(p0); - _absmax = vmaxq_f16(_absmax, vabsq_f16(_p)); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); p0 += 8; } - float16x4_t _aa = vmax_f16(vget_low_f16(_absmax), vget_high_f16(_absmax)); + float16x4_t _aa = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); absmax = (__fp16)vmaxvq_f32(vcvt_f32_f16(_aa)); for (; kk < K; kk++) { From 1e50e8892d784aaca91b796fa345500957b80ec9 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 30 Sep 2024 15:27:44 +0800 Subject: [PATCH 26/55] build --- src/layer/arm/gemm_int8_fp16s.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index a6709c520616..9fb1481bb57f 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -102,8 +102,8 @@ static void 
compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); p0 += 8; } - float32x4_t _absmax0 = vcvt_f32_f16(vget_low_f16(_absmax)); - float32x4_t _absmax1 = vcvt_f32_f16(vget_high_f16(_absmax)); + float32x4_t _absmax0 = vcvt_f32_f16(vget_low_f16(_amax0)); + float32x4_t _absmax1 = vcvt_f32_f16(vget_high_f16(_amax0)); #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; @@ -237,7 +237,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s _amax = vmax_f16(_amax, vabs_f16(_p)); p0 += 4; } - float32x4_t _absmax0 = vcvt_f32_f16(_absmax); + float32x4_t _absmax0 = vcvt_f32_f16(_amax); #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; From d9e9b389a90c977b41bc09706e3bb0cd8425e9dc Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 15:57:52 +0800 Subject: [PATCH 27/55] opt++ --- src/layer/arm/gemm_int8_fp16s.h | 444 +++++++++++++++++++++++++++++++- 1 file changed, 432 insertions(+), 12 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 9fb1481bb57f..6acfc519bbcd 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -1543,22 +1543,123 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, if (elempack == 8) { int ii = 0; - // TODO unroll 2 + for (; ii + 1 < max_ii; ii += 2) + { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii) * 8; + + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + A_hstep * 8); + float16x8_t _p3 = vld1q_f16(p0 + A_hstep * 8 + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + for (; kk < K; kk++) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + p0 += A_hstep * 8; + } + float absmax0 = (float)vmaxvq_f16(_absmax0); + float absmax1 = (float)vmaxvq_f16(_absmax1); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 8; + + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); + int kk = 0; + for (; kk < K; kk++) + { + uint16x8_t _p01 = vld1q_u16(p0); + uint16x8_t _p23 = vld1q_u16(p0 + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, 
vabsq_f32(_p3)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = vmaxq_f32(_absmax1, _absmax3); +#if __aarch64__ + float absmax0 = vmaxvq_f32(_absmax0); + float absmax1 = vmaxvq_f32(_absmax1); +#else + float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); + float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); + float absmax0 = vget_lane_f32(_aa01, 0); + float absmax1 = vget_lane_f32(_aa01, 1); +#endif +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + + ps[0] = 127.f / absmax0; + ps[1] = 127.f / absmax1; + pods[0] = absmax0 / v127_B_scale; + pods[1] = absmax1 / v127_B_scale; + ps += 2; + pods += 2; + } for (; ii < max_ii; ii++) { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const __fp16* p0 = (const __fp16*)A + (i + ii) * 8; - float16x8_t _absmax = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); int kk = 0; + for (; kk + 3 < K; kk += 4) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + A_hstep * 8); + float16x8_t _p2 = vld1q_f16(p0 + A_hstep * 16); + float16x8_t _p3 = vld1q_f16(p0 + A_hstep * 24); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + p0 += A_hstep * 32; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + A_hstep * 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + p0 += A_hstep * 16; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax1); for (; kk < K; kk++) { float16x8_t _p = vld1q_f16(p0); - _absmax = vmaxq_f16(_absmax, vabsq_f16(_p)); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p)); p0 += A_hstep * 8; } - float16x4_t _aa = vmax_f16(vget_low_f16(_absmax), vget_high_f16(_absmax)); - float absmax = vmaxvq_f32(vcvt_f32_f16(_aa)); + float absmax = (float)vmaxvq_f16(_absmax0); #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 8; @@ -1593,8 +1694,12 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, p0 += A_hstep * 8; } _absmax0 = vmaxq_f32(_absmax0, _absmax1); +#if __aarch64__ + float absmax = vmaxvq_f32(_absmax0); +#else float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); +#endif #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ps[0] = 127.f / absmax; @@ -1608,6 +1713,39 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, int ii = 0; for (; ii + 3 < max_ii; ii += 4) { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii) * 4; + + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 1 < K; kk += 2) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + float16x8_t _p2 = vld1q_f16(p0 + A_hstep * 4); + float16x8_t _p3 = vld1q_f16(p0 + A_hstep * 4 + 8); + _absmax0 = 
vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + for (; kk < K; kk++) + { + float16x8_t _p0 = vld1q_f16(p0); + float16x8_t _p1 = vld1q_f16(p0 + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + p0 += A_hstep * 4; + } + float16x8_t _aa0123 = vpmaxq_f16(_absmax0, _absmax1); + float32x4_t _absmax = vcvt_f32_f16(vpmax_f16(vget_low_f16(_aa0123), vget_high_f16(_aa0123))); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -1628,6 +1766,11 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); p0 += A_hstep * 4; } +#if __aarch64__ + float32x4_t _aa01 = vpmaxq_f32(_absmax0, _absmax1); + float32x4_t _aa23 = vpmaxq_f32(_absmax2, _absmax3); + float32x4_t _absmax = vpmaxq_f32(_aa01, _aa23); +#else float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); float32x2_t _aa2 = vmax_f32(vget_low_f32(_absmax2), vget_high_f32(_absmax2)); @@ -1635,6 +1778,8 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); float32x2_t _aa23 = vpmax_f32(_aa2, _aa3); float32x4_t _absmax = vcombine_f32(_aa01, _aa23); +#endif +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __aarch64__ float32x4_t _scale = vdivq_f32(_v127, _absmax); @@ -1669,6 +1814,59 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, } for (; ii < max_ii; ii++) { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii) * 4; + + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 7 < K; kk += 8) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep * 4); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 8); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 12); + float16x4_t _p4 = vld1_f16(p0 + A_hstep * 16); + float16x4_t _p5 = vld1_f16(p0 + A_hstep * 20); + float16x4_t _p6 = vld1_f16(p0 + A_hstep * 24); + float16x4_t _p7 = vld1_f16(p0 + A_hstep * 28); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(vcombine_f16(_p4, _p5))); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(vcombine_f16(_p6, _p7))); + p0 += A_hstep * 32; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 3 < K; kk += 4) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep * 4); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 8); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 12); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + p0 += A_hstep * 16; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk + 1 < K; kk += 2) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep * 4); + _amax0 = vmaxq_f16(_amax0, 
vabsq_f16(vcombine_f16(_p0, _p1))); + p0 += A_hstep * 8; + } + float16x4_t _amax01 = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); + for (; kk < K; kk++) + { + float16x4_t _p = vld1_f16(p0); + _amax01 = vmax_f16(_amax01, vabs_f16(_p)); + p0 += A_hstep * 4; + } + float absmax = (float)vmaxv_f16(_amax01); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 4; float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -1705,8 +1903,13 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); p0 += A_hstep * 4; } +#if __aarch64__ + float absmax = vmaxvq_f32(_absmax0); +#else float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); +#endif +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ps[0] = 127.f / absmax; pods[0] = absmax / v127_B_scale; @@ -1721,6 +1924,59 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, #if __ARM_NEON for (; ii + 3 < max_ii; ii += 4) { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii); + + float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _amax3 = vdupq_n_f16((__fp16)0.f); + int kk = 0; + for (; kk + 7 < K; kk += 8) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 2); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 3); + float16x4_t _p4 = vld1_f16(p0 + A_hstep * 4); + float16x4_t _p5 = vld1_f16(p0 + A_hstep * 5); + float16x4_t _p6 = vld1_f16(p0 + A_hstep * 6); + float16x4_t _p7 = vld1_f16(p0 + A_hstep * 7); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + _amax2 = vmaxq_f16(_amax2, vabsq_f16(vcombine_f16(_p4, _p5))); + _amax3 = vmaxq_f16(_amax3, vabsq_f16(vcombine_f16(_p6, _p7))); + p0 += A_hstep * 8; + } + _amax0 = vmaxq_f16(_amax0, _amax2); + _amax1 = vmaxq_f16(_amax1, _amax3); + for (; kk + 3 < K; kk += 4) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep); + float16x4_t _p2 = vld1_f16(p0 + A_hstep * 2); + float16x4_t _p3 = vld1_f16(p0 + A_hstep * 3); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + _amax1 = vmaxq_f16(_amax1, vabsq_f16(vcombine_f16(_p2, _p3))); + p0 += A_hstep * 4; + } + _amax0 = vmaxq_f16(_amax0, _amax1); + for (; kk + 1 < K; kk += 2) + { + float16x4_t _p0 = vld1_f16(p0); + float16x4_t _p1 = vld1_f16(p0 + A_hstep); + _amax0 = vmaxq_f16(_amax0, vabsq_f16(vcombine_f16(_p0, _p1))); + p0 += A_hstep * 2; + } + float16x4_t _amax = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); + for (; kk < K; kk++) + { + float16x4_t _p = vld1_f16(p0); + _amax = vmax_f16(_amax, vabs_f16(_p)); + p0 += A_hstep; + } + float32x4_t _absmax0 = vcvt_f32_f16(_amax); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii); float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -1757,6 +2013,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); p0 += A_hstep; } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __aarch64__ float32x4_t _scale = vdivq_f32(_v127, _absmax0); @@ -1792,14 +2049,93 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& 
A, Mat& scales, #endif // __ARM_NEON for (; ii < max_ii; ii++) { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* p0 = (const __fp16*)A + (i + ii); + + float absmax = 0.f; + int kk = 0; + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + for (; kk + 7 < K; kk += 8) + { + float16x8_t _p = float16x8_t(); + _p = vsetq_lane_f16(p0[0], _p, 0); + _p = vsetq_lane_f16(p0[A_hstep], _p, 1); + _p = vsetq_lane_f16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_f16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_f16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_f16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_f16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_f16(p0[A_hstep * 7], _p, 7); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p)); + p0 += A_hstep * 8; + } + float16x4_t _amax0 = vmax_f16(vget_low_f16(_absmax0), vget_high_f16(_absmax0)); + for (; kk + 3 < K; kk += 4) + { + float16x4_t _p = float16x4_t(); + _p = vset_lane_f16(p0[0], _p, 0); + _p = vset_lane_f16(p0[A_hstep], _p, 1); + _p = vset_lane_f16(p0[A_hstep * 2], _p, 2); + _p = vset_lane_f16(p0[A_hstep * 3], _p, 3); + _amax0 = vmax_f16(_amax0, vabs_f16(_p)); + p0 += A_hstep * 4; + } + absmax = (float)vmaxv_f16(_amax0); + for (; kk < K; kk++) + { + absmax = std::max(absmax, fabs((float)p0[0])); + p0 += A_hstep; + } +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii); float absmax = 0.f; - for (int kk = 0; kk < K; kk++) + int kk = 0; +#if __ARM_NEON + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + for (; kk + 7 < K; kk += 8) + { + uint16x8_t _p = uint16x8_t(); + _p = vsetq_lane_u16(p0[0], _p, 0); + _p = vsetq_lane_u16(p0[A_hstep], _p, 1); + _p = vsetq_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vsetq_lane_u16(p0[A_hstep * 3], _p, 3); + _p = vsetq_lane_u16(p0[A_hstep * 4], _p, 4); + _p = vsetq_lane_u16(p0[A_hstep * 5], _p, 5); + _p = vsetq_lane_u16(p0[A_hstep * 6], _p, 6); + _p = vsetq_lane_u16(p0[A_hstep * 7], _p, 7); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + p0 += A_hstep * 8; + } + _absmax0 = vmaxq_f32(_absmax0, _absmax1); + for (; kk + 3 < K; kk += 4) + { + uint16x4_t _p = uint16x4_t(); + _p = vset_lane_u16(p0[0], _p, 0); + _p = vset_lane_u16(p0[A_hstep], _p, 1); + _p = vset_lane_u16(p0[A_hstep * 2], _p, 2); + _p = vset_lane_u16(p0[A_hstep * 3], _p, 3); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)_p); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + p0 += A_hstep * 4; + } +#if __aarch64__ + absmax = vmaxvq_f32(_absmax0); +#else + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); + absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); +#endif +#endif // __ARM_NEON + for (; kk < K; kk++) { absmax = std::max(absmax, (float)fabs(float16_to_float32(p0[0]))); p0 += A_hstep; } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ps[0] = 127.f / absmax; pods[0] = absmax / v127_B_scale; @@ -3080,43 +3416,127 @@ static void compute_B_fp16_int8_scale(const Mat& B, float& scale) float absmax = 0.f; #if __ARM_NEON - float32x4_t _absmax = vdupq_n_f32(0.f); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); + float16x4_t _amax = vdup_n_f16((__fp16)0.f); 
+#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + float32x4_t _absmax0 = vdupq_n_f32(0.f); + float32x4_t _absmax1 = vdupq_n_f32(0.f); + float32x4_t _absmax2 = vdupq_n_f32(0.f); + float32x4_t _absmax3 = vdupq_n_f32(0.f); +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #endif for (int i = 0; i < (B.dims == 3 ? B.c : B.h); i++) { const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w; - const unsigned short* ptr = (const unsigned short*)B + i * B_hstep * B.elempack; const int size = B.w * B.elempack; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const __fp16* ptr = (const __fp16*)B + i * B_hstep * B.elempack; + + int j = 0; + for (; j + 31 < size; j += 32) + { + float16x8_t _p0 = vld1q_f16(ptr); + float16x8_t _p1 = vld1q_f16(ptr + 8); + float16x8_t _p2 = vld1q_f16(ptr + 16); + float16x8_t _p3 = vld1q_f16(ptr + 24); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + _absmax2 = vmaxq_f16(_absmax2, vabsq_f16(_p2)); + _absmax3 = vmaxq_f16(_absmax3, vabsq_f16(_p3)); + ptr += 32; + } + for (; j + 15 < size; j += 16) + { + float16x8_t _p0 = vld1q_f16(ptr); + float16x8_t _p1 = vld1q_f16(ptr + 8); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p0)); + _absmax1 = vmaxq_f16(_absmax1, vabsq_f16(_p1)); + ptr += 16; + } + for (; j + 7 < size; j += 8) + { + float16x8_t _p = vld1q_f16(ptr); + _absmax0 = vmaxq_f16(_absmax0, vabsq_f16(_p)); + ptr += 8; + } + for (; j + 3 < size; j += 4) + { + float16x4_t _p = vld1_f16(ptr); + _amax = vmax_f16(_amax, vabs_f16(_p)); + ptr += 4; + } + for (; j < size; j++) + { + absmax = std::max(absmax, fabs((float)ptr[0])); + ptr++; + } +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + const unsigned short* ptr = (const unsigned short*)B + i * B_hstep * B.elempack; + int j = 0; #if __ARM_NEON + for (; j + 15 < size; j += 16) + { + uint16x8_t _p = vld1q_u16(ptr); + uint16x8_t _q = vld1q_u16(ptr + 8); + float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); + float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); + float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_q)); + float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_q)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); + _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); + _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); + ptr += 16; + } for (; j + 7 < size; j += 8) { uint16x8_t _p = vld1q_u16(ptr); float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p)); float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p)); - _absmax = vmaxq_f32(_absmax, vabsq_f32(_p0)); - _absmax = vmaxq_f32(_absmax, vabsq_f32(_p1)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); + _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); ptr += 8; } for (; j + 3 < size; j += 4) { float32x4_t _p = vcvt_f32_f16((float16x4_t)vld1_u16(ptr)); - _absmax = vmaxq_f32(_absmax, vabsq_f32(_p)); + _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p)); ptr += 4; } -#endif +#endif // __ARM_NEON for (; j < size; j++) { absmax = std::max(absmax, (float)fabs(float16_to_float32(ptr[0]))); ptr++; } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC } #if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + _absmax0 = vmaxq_f16(_absmax0, _absmax2); + _absmax1 = vmaxq_f16(_absmax1, _absmax3); + _absmax0 = vmaxq_f16(_absmax0, _absmax1); + absmax = std::max(absmax, (float)vmaxvq_f16(_absmax0)); + absmax = std::max(absmax, (float)vmaxv_f16(_amax)); +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + _absmax0 = vmaxq_f32(_absmax0, _absmax2); + _absmax1 = 
vmaxq_f32(_absmax1, _absmax3); + _absmax0 = vmaxq_f32(_absmax0, _absmax1); +#if __aarch64__ + absmax = std::max(absmax, vmaxvq_f32(_absmax0)); +#else float32x2_t _aa = vmax_f32(vget_low_f32(_absmax), vget_high_f32(_absmax)); absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); #endif +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __ARM_NEON scale = absmax == 0.f ? 1.f : 127.f / absmax; } From 5566752075dc27fa72a7b75db7a5c9a61054f46e Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 16:28:05 +0800 Subject: [PATCH 28/55] opt++ --- src/layer/arm/gemm_int8.h | 120 +++++---------- src/layer/arm/gemm_int8_bf16s.h | 124 +++++++-------- src/layer/arm/gemm_int8_fp16s.h | 260 ++++++++++++++++---------------- 3 files changed, 220 insertions(+), 284 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 4dc97c7a014d..e54c56874633 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -8828,79 +8828,44 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } for (; jj + 1 < max_jj; jj += 2) { - // TODO neon optimize - float f00 = pp[0] * descale0; - float f01 = pp[1] * descale0; - float f10 = pp[2] * descale1; - float f11 = pp[3] * descale1; + int32x4_t _sum0 = vld1q_s32(pp); + + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); if (pC) { if (broadcast_type_C == 0) { - f00 += c0; - f01 += c0; - f10 += c0; - f11 += c0; + _f0 = vaddq_f32(_f0, _c0); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { - f00 += c0; - f01 += c0; - f10 += c1; - f11 += c1; + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); } if (broadcast_type_C == 3) { // c_elempack == 1 - if (beta == 1.f) - { - f00 += pC[0]; - f01 += pC[1]; - f10 += pC[c_hstep]; - f11 += pC[c_hstep + 1]; - } - else - { - f00 += pC[0] * beta; - f01 += pC[1] * beta; - f10 += pC[c_hstep] * beta; - f11 += pC[c_hstep + 1] * beta; - } + _c0 = vcombine_f32(vld1_f32(pC), vld1_f32(pC + c_hstep)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } if (broadcast_type_C == 4) { - if (beta == 1.f) - { - f00 += pC[0]; - f01 += pC[1]; - f10 += pC[0]; - f11 += pC[1]; - } - else - { - f00 += pC[0] * beta; - f01 += pC[1] * beta; - f10 += pC[0] * beta; - f11 += pC[1] * beta; - } + float32x2_t _c = vld1_f32(pC); + _c0 = vcombine_f32(_c, _c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } } - if (alpha != 1.f) - { - f00 *= alpha; - f01 *= alpha; - f10 *= alpha; - f11 *= alpha; - } + _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = f00; - p0[1] = f01; - p0[out_hstep] = f10; - p0[out_hstep + 1] = f11; + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); pp += 4; p0 += 2; @@ -12196,61 +12161,48 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } for (; jj + 1 < max_jj; jj += 2) { - // TODO neon optimize // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); - float f00 = pp[0] * descale0; - float f01 = pp[2] * descale1; - float f10 = pp[1] * descale0; - float f11 = pp[3] * descale1; + float32x4_t _descale = vcombine_f32(_descale01, _descale01); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); if (pC) { if (broadcast_type_C == 0) { - f00 += c0; - f01 += c0; - f10 += c0; - f11 += c0; + _f0 = vaddq_f32(_f0, _c0); } if (broadcast_type_C == 1 
|| broadcast_type_C == 2) { - f00 += c0; - f01 += c1; - f10 += c0; - f11 += c1; + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); } if (broadcast_type_C == 3) { // c_elempack == 1 - f00 += pC[0] * beta; - f01 += pC[c_hstep] * beta; - f10 += pC[1] * beta; - f11 += pC[c_hstep + 1] * beta; + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2x2_t _c01 = vzip_f32(_cc0, _cc1); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } if (broadcast_type_C == 4) { - f00 += pC[0] * beta; - f01 += pC[0] * beta; - f10 += pC[1] * beta; - f11 += pC[1] * beta; + float32x2_t _cc = vld1_f32(pC); + float32x2x2_t _c01 = vzip_f32(_cc, _cc); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } } - if (alpha != 1.f) - { - f00 *= alpha; - f01 *= alpha; - f10 *= alpha; - f11 *= alpha; - } + _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = f00; - p0[1] = f01; - p0[out_hstep] = f10; - p0[out_hstep + 1] = f11; + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); pp += 4; p0 += out_hstep * 2; diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index d27ced6ca9df..298e156b0dbe 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -7421,59 +7421,57 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } for (; jj + 1 < max_jj; jj += 2) { - // TODO neon optimize - float f00 = pp[0] * descale0; - float f01 = pp[1] * descale0; - float f10 = pp[2] * descale1; - float f11 = pp[3] * descale1; + int32x4_t _sum0 = vld1q_s32(pp); + + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); if (pC) { if (broadcast_type_C == 0) { - f00 += c0; - f01 += c0; - f10 += c0; - f11 += c0; + _f0 = vaddq_f32(_f0, _c0); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { - f00 += c0; - f01 += c0; - f10 += c1; - f11 += c1; + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); } if (broadcast_type_C == 3) { // c_elempack == 1 - f00 += bfloat16_to_float32(pC[0]) * beta; - f01 += bfloat16_to_float32(pC[1]) * beta; - f10 += bfloat16_to_float32(pC[c_hstep]) * beta; - f11 += bfloat16_to_float32(pC[c_hstep + 1]) * beta; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[c_hstep], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } if (broadcast_type_C == 4) { - f00 += bfloat16_to_float32(pC[0]) * beta; - f01 += bfloat16_to_float32(pC[1]) * beta; - f10 += bfloat16_to_float32(pC[0]) * beta; - f11 += bfloat16_to_float32(pC[1]) * beta; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[0], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } } - if (alpha != 1.f) - { - f00 *= alpha; - f01 *= alpha; - f10 *= alpha; - f11 *= alpha; - } + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _bf0 = float2bfloat(_f0); - p0[0] = float32_to_bfloat16(f00); - p0[1] = float32_to_bfloat16(f01); - p0[out_hstep] = float32_to_bfloat16(f10); - p0[out_hstep + 1] = float32_to_bfloat16(f11); + p0[0] = vget_lane_u16(_bf0, 0); + 
p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); pp += 4; p0 += 2; @@ -10832,63 +10830,57 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } for (; jj + 1 < max_jj; jj += 2) { - // TODO neon optimize // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); - float f00 = pp[0] * descale0; - float f01 = pp[2] * descale1; - float f10 = pp[1] * descale0; - float f11 = pp[3] * descale1; + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); if (pC) { if (broadcast_type_C == 0) { - f00 += c0; - f01 += c0; - f10 += c0; - f11 += c0; + _f0 = vaddq_f32(_f0, _c0); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { - f00 += c0; - f01 += c1; - f10 += c0; - f11 += c1; + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); } if (broadcast_type_C == 3) { // c_elempack == 1 - f00 += bfloat16_to_float32(pC[0]) * beta; - f01 += bfloat16_to_float32(pC[c_hstep]) * beta; - f10 += bfloat16_to_float32(pC[1]) * beta; - f11 += bfloat16_to_float32(pC[c_hstep + 1]) * beta; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } if (broadcast_type_C == 4) { - float c0 = bfloat16_to_float32(pC[0]) * beta; - float c1 = bfloat16_to_float32(pC[1]) * beta; - f00 += c0; - f01 += c0; - f10 += c1; - f11 += c1; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[0], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } } - if (alpha != 1.f) - { - f00 *= alpha; - f01 *= alpha; - f10 *= alpha; - f11 *= alpha; - } + _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = float32_to_bfloat16(f00); - p0[1] = float32_to_bfloat16(f01); - p0[out_hstep] = float32_to_bfloat16(f10); - p0[out_hstep + 1] = float32_to_bfloat16(f11); + uint16x4_t _bf0 = float2bfloat(_f0); + + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); pp += 4; p0 += out_hstep * 2; diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 6acfc519bbcd..74c4e9f3caeb 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -3532,7 +3532,7 @@ static void compute_B_fp16_int8_scale(const Mat& B, float& scale) #if __aarch64__ absmax = std::max(absmax, vmaxvq_f32(_absmax0)); #else - float32x2_t _aa = vmax_f32(vget_low_f32(_absmax), vget_high_f32(_absmax)); + float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1))); #endif #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -9253,59 +9253,57 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } for (; jj + 1 < max_jj; jj += 2) { - // TODO neon optimize - float f00 = pp[0] * descale0; - float f01 = pp[1] * descale0; - float f10 = pp[2] * descale1; - float f11 = pp[3] * descale1; + int32x4_t _sum0 = vld1q_s32(pp); + + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = 
vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); if (pC) { if (broadcast_type_C == 0) { - f00 += c0; - f01 += c0; - f10 += c0; - f11 += c0; + _f0 = vaddq_f32(_f0, _c0); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { - f00 += c0; - f01 += c0; - f10 += c1; - f11 += c1; + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); } if (broadcast_type_C == 3) { // c_elempack == 1 - f00 += float16_to_float32(pC[0]) * beta; - f01 += float16_to_float32(pC[1]) * beta; - f10 += float16_to_float32(pC[c_hstep]) * beta; - f11 += float16_to_float32(pC[c_hstep + 1]) * beta; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[c_hstep], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } if (broadcast_type_C == 4) { - f00 += float16_to_float32(pC[0]) * beta; - f01 += float16_to_float32(pC[1]) * beta; - f10 += float16_to_float32(pC[0]) * beta; - f11 += float16_to_float32(pC[1]) * beta; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[0], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } } - if (alpha != 1.f) - { - f00 *= alpha; - f01 *= alpha; - f10 *= alpha; - f11 *= alpha; - } + _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = float32_to_float16(f00); - p0[1] = float32_to_float16(f01); - p0[out_hstep] = float32_to_float16(f10); - p0[out_hstep + 1] = float32_to_float16(f11); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf0, 1); + p0[out_hstep] = vget_lane_u16(_hf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_hf0, 3); pp += 4; p0 += 2; @@ -12790,27 +12788,27 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - uint16x4_t _bf0 = (uint16x4_t)vcvt_f16_f32(_f0); - uint16x4_t _bf1 = (uint16x4_t)vcvt_f16_f32(_f1); - uint16x4_t _bf2 = (uint16x4_t)vcvt_f16_f32(_f2); - uint16x4_t _bf3 = (uint16x4_t)vcvt_f16_f32(_f3); - - p0[0] = vget_lane_u16(_bf0, 0); - p0[1] = vget_lane_u16(_bf0, 1); - p0[out_hstep] = vget_lane_u16(_bf0, 2); - p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 2] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 3] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf2, 0); - p0[out_hstep * 4 + 1] = vget_lane_u16(_bf2, 1); - p0[out_hstep * 5] = vget_lane_u16(_bf2, 2); - p0[out_hstep * 5 + 1] = vget_lane_u16(_bf2, 3); - p0[out_hstep * 6] = vget_lane_u16(_bf3, 0); - p0[out_hstep * 6 + 1] = vget_lane_u16(_bf3, 1); - p0[out_hstep * 7] = vget_lane_u16(_bf3, 2); - p0[out_hstep * 7 + 1] = vget_lane_u16(_bf3, 3); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf0, 1); + p0[out_hstep] = vget_lane_u16(_hf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 2] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 2 + 1] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 3] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 3 + 1] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf2, 0); 
+ p0[out_hstep * 4 + 1] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 5] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 5 + 1] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 6] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 6 + 1] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 7] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 7 + 1] = vget_lane_u16(_hf3, 3); pp += 16; p0 += out_hstep * 8; @@ -12893,80 +12891,74 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - uint16x4_t _bf0 = (uint16x4_t)vcvt_f16_f32(_f0); - uint16x4_t _bf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); - p0[0] = vget_lane_u16(_bf0, 0); - p0[1] = vget_lane_u16(_bf0, 1); - p0[out_hstep] = vget_lane_u16(_bf0, 2); - p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 2] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 3] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf0, 1); + p0[out_hstep] = vget_lane_u16(_hf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 2] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 2 + 1] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 3] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 3 + 1] = vget_lane_u16(_hf1, 3); pp += 8; p0 += out_hstep * 4; } for (; jj + 1 < max_jj; jj += 2) { - // TODO neon optimize // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); - float f00 = pp[0] * descale0; - float f01 = pp[2] * descale1; - float f10 = pp[1] * descale0; - float f11 = pp[3] * descale1; + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); if (pC) { if (broadcast_type_C == 0) { - f00 += c0; - f01 += c0; - f10 += c0; - f11 += c0; + _f0 = vaddq_f32(_f0, _c0); } if (broadcast_type_C == 1 || broadcast_type_C == 2) { - f00 += c0; - f01 += c1; - f10 += c0; - f11 += c1; + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); } if (broadcast_type_C == 3) { // c_elempack == 1 - f00 += float16_to_float32(pC[0]) * beta; - f01 += float16_to_float32(pC[c_hstep]) * beta; - f10 += float16_to_float32(pC[1]) * beta; - f11 += float16_to_float32(pC[c_hstep + 1]) * beta; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } if (broadcast_type_C == 4) { - float c0 = float16_to_float32(pC[0]) * beta; - float c1 = float16_to_float32(pC[1]) * beta; - f00 += c0; - f01 += c0; - f10 += c1; - f11 += c1; + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[0], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = vcvt_f32_f16((float16x4_t)_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); pC += 2; } } - if (alpha != 1.f) - { - f00 *= alpha; - f01 *= alpha; - f10 *= alpha; - f11 *= alpha; - } + _f0 = vmulq_n_f32(_f0, alpha); + + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); - p0[0] = float32_to_float16(f00); - p0[1] = float32_to_float16(f01); - p0[out_hstep] = float32_to_float16(f10); - p0[out_hstep + 1] = float32_to_float16(f11); + p0[0] = vget_lane_u16(_hf0, 0); + p0[1] = vget_lane_u16(_hf0, 1); + p0[out_hstep] = vget_lane_u16(_hf0, 2); + 
p0[out_hstep + 1] = vget_lane_u16(_hf0, 3); pp += 4; p0 += out_hstep * 2; @@ -13264,27 +13256,27 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - uint16x4_t _bf0 = (uint16x4_t)vcvt_f16_f32(_f0); - uint16x4_t _bf1 = (uint16x4_t)vcvt_f16_f32(_f1); - uint16x4_t _bf2 = (uint16x4_t)vcvt_f16_f32(_f2); - uint16x4_t _bf3 = (uint16x4_t)vcvt_f16_f32(_f3); - - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); pp += 16; p0 += out_hstep * 16; @@ -13332,17 +13324,17 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - uint16x4_t _bf0 = (uint16x4_t)vcvt_f16_f32(_f0); - uint16x4_t _bf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); pp += 8; p0 += out_hstep * 8; @@ -13368,12 +13360,12 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f0 = vmulq_n_f32(_f0, alpha); - uint16x4_t _bf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - 
p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); pp += 4; p0 += out_hstep * 4; From 03168d7694a3c7d489a629820e8af4a531805c3b Mon Sep 17 00:00:00 2001 From: nihui Date: Tue, 8 Oct 2024 08:20:38 +0000 Subject: [PATCH 29/55] apply code-format changes --- src/layer/arm/gemm_int8_fp16s.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 74c4e9f3caeb..0cfb7bb84909 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -1577,7 +1577,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, } float absmax0 = (float)vmaxvq_f16(_absmax0); float absmax1 = (float)vmaxvq_f16(_absmax1); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 8; float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -1660,7 +1660,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, p0 += A_hstep * 8; } float absmax = (float)vmaxvq_f16(_absmax0); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 8; float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -1976,7 +1976,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, p0 += A_hstep; } float32x4_t _absmax0 = vcvt_f32_f16(_amax); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const unsigned short* p0 = (const unsigned short*)A + (i + ii); float32x4_t _absmax0 = vdupq_n_f32(0.f); @@ -3422,7 +3422,7 @@ static void compute_B_fp16_int8_scale(const Mat& B, float& scale) float16x8_t _absmax2 = vdupq_n_f16((__fp16)0.f); float16x8_t _absmax3 = vdupq_n_f16((__fp16)0.f); float16x4_t _amax = vdup_n_f16((__fp16)0.f); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC float32x4_t _absmax0 = vdupq_n_f32(0.f); float32x4_t _absmax1 = vdupq_n_f32(0.f); float32x4_t _absmax2 = vdupq_n_f32(0.f); From a280bfb33c269a2ed54371f75fb762355adfdb62 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 17:12:48 +0800 Subject: [PATCH 30/55] fast path --- src/layer/arm/gemm_int8.h | 124 ++++++++++++++++++++++---------- src/layer/arm/gemm_int8_bf16s.h | 115 ++++++++++++++++++++--------- src/layer/arm/gemm_int8_fp16s.h | 115 ++++++++++++++++++++--------- 3 files changed, 250 insertions(+), 104 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index e54c56874633..49c946d8fc00 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -12352,10 +12352,21 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 8, _f2); - vst1q_f32(p0 + out_hstep * 12, _f3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + } + else + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 8, _f2); + vst1q_f32(p0 + out_hstep * 12, _f3); + } + pp += 16; p0 += out_hstep * 16; } @@ -12401,8 +12412,17 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& 
topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + } + else + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + } + pp += 8; p0 += out_hstep * 8; } @@ -12493,22 +12513,32 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); - p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); - p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); - p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); - p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); - p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0); - p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1); - p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2); - p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3); - p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0); - p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1); - p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); - p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + } + else + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0); + p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1); + p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2); + p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3); + p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0); + p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1); + p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); + p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); + } pp += 16; p0 += out_hstep * 16; @@ -12555,14 +12585,22 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); - p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); - p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); - p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); - p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + } + else + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + } pp += 8; p0 += out_hstep * 8; @@ -12588,10 +12626,17 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + } + else + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = 
vgetq_lane_f32(_f0, 3); + } pp += 4; p0 += out_hstep * 4; @@ -12617,8 +12662,15 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f0 = vmul_n_f32(_f0, alpha); - p0[0] = vget_lane_f32(_f0, 0); - p0[out_hstep] = vget_lane_f32(_f0, 1); + if (out_hstep == 1) + { + vst1_f32(p0, _f0); + } + else + { + p0[0] = vget_lane_f32(_f0, 0); + p0[out_hstep] = vget_lane_f32(_f0, 1); + } pp += 2; p0 += out_hstep * 2; diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 298e156b0dbe..36a4e423031c 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -11035,10 +11035,24 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); - vst1_u16(p0 + out_hstep * 8, float2bfloat(_f2)); - vst1_u16(p0 + out_hstep * 12, float2bfloat(_f3)); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + } + else + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + vst1_u16(p0 + out_hstep * 8, _bf2); + vst1_u16(p0 + out_hstep * 12, _bf3); + } + pp += 16; p0 += out_hstep * 16; } @@ -11085,8 +11099,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + } + else + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + } + pp += 8; p0 += out_hstep * 8; } @@ -11183,22 +11208,30 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _bf2 = float2bfloat(_f2); uint16x4_t _bf3 = float2bfloat(_f3); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + } + else + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 
13] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + } pp += 16; p0 += out_hstep * 16; @@ -11249,14 +11282,21 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _bf0 = float2bfloat(_f0); uint16x4_t _bf1 = float2bfloat(_f1); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + } + else + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + } pp += 8; p0 += out_hstep * 8; @@ -11284,10 +11324,17 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _bf0 = float2bfloat(_f0); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + if (out_hstep == 1) + { + vst1_u16(p0, _bf0); + } + else + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + } pp += 4; p0 += out_hstep * 4; diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 0cfb7bb84909..629bec91ad51 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -13113,10 +13113,24 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); - vst1_u16(p0 + out_hstep * 4, (uint16x4_t)vcvt_f16_f32(_f1)); - vst1_u16(p0 + out_hstep * 8, (uint16x4_t)vcvt_f16_f32(_f2)); - vst1_u16(p0 + out_hstep * 12, (uint16x4_t)vcvt_f16_f32(_f3)); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + } + else + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + vst1_u16(p0 + out_hstep * 8, _hf2); + vst1_u16(p0 + out_hstep * 12, _hf3); + } + pp += 16; p0 += out_hstep * 16; } @@ -13163,8 +13177,19 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); - vst1_u16(p0 + out_hstep * 4, (uint16x4_t)vcvt_f16_f32(_f1)); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } + else + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + } + pp += 8; p0 += out_hstep * 8; } @@ -13261,22 +13286,30 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); uint16x4_t _hf3 = 
(uint16x4_t)vcvt_f16_f32(_f3); - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + } + else + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); + } pp += 16; p0 += out_hstep * 16; @@ -13327,14 +13360,21 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } + else + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + } pp += 8; p0 += out_hstep * 8; @@ -13362,10 +13402,17 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + if (out_hstep == 1) + { + vst1_u16(p0, _hf0); + } + else + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + } pp += 4; p0 += out_hstep * 4; From 3748dcc15f16a723e24c1011e466bcd38d63ba31 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 17:15:01 +0800 Subject: [PATCH 31/55] cc --- src/layer/arm/gemm_arm.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp 
index 1630d7db5c07..eca41be82e2f 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -4249,11 +4249,7 @@ int Gemm_arm::create_pipeline(const Option& opt) #if NCNN_INT8 if (int8_scale_term) { - // support_packing = false; - // support_fp16_storage = false; - // support_bf16_storage = false; return create_pipeline_int8(opt); - // return 0; } #endif From 9fa25321e353c0127a0d078a28323f7f8c3ffb4b Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 19:20:06 +0800 Subject: [PATCH 32/55] fabsf --- src/layer/arm/gemm_int8.h | 6 +++--- src/layer/arm/gemm_int8_bf16s.h | 6 +++--- src/layer/arm/gemm_int8_fp16s.h | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 49c946d8fc00..09c94d5226f2 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -1854,7 +1854,7 @@ static void compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float B_s #endif // __ARM_NEON for (; kk < K; kk++) { - absmax = std::max(absmax, (float)fabs(p0[0])); + absmax = std::max(absmax, (float)fabsf(p0[0])); p0++; } @@ -3012,7 +3012,7 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, float absmax = 0.f; for (int kk = 0; kk < K; kk++) { - absmax = std::max(absmax, (float)fabs(p0[0])); + absmax = std::max(absmax, (float)fabsf(p0[0])); p0 += A_hstep; } @@ -3956,7 +3956,7 @@ static void compute_B_fp32_int8_scale(const Mat& B, float& scale) #endif for (; j < size; j++) { - absmax = std::max(absmax, (float)fabs(ptr[0])); + absmax = std::max(absmax, (float)fabsf(ptr[0])); ptr++; } } diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 36a4e423031c..3ec9f15000a4 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -174,7 +174,7 @@ static void compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float B_s #endif // __ARM_NEON for (; kk < K; kk++) { - absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(p0[0]))); + absmax = std::max(absmax, (float)fabsf(bfloat16_to_float32(p0[0]))); p0++; } @@ -1319,7 +1319,7 @@ static void transpose_compute_A_tile_bf16_int8_scales(const Mat& A, Mat& scales, float absmax = 0.f; for (int kk = 0; kk < K; kk++) { - absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(p0[0]))); + absmax = std::max(absmax, (float)fabsf(bfloat16_to_float32(p0[0]))); p0 += A_hstep; } @@ -2350,7 +2350,7 @@ static void compute_B_bf16_int8_scale(const Mat& B, float& scale) #endif for (; j < size; j++) { - absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(ptr[0]))); + absmax = std::max(absmax, (float)fabsf(bfloat16_to_float32(ptr[0]))); ptr++; } } diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 629bec91ad51..cfcd798f8033 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -359,7 +359,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s absmax = (__fp16)vmaxvq_f32(vcvt_f32_f16(_aa)); for (; kk < K; kk++) { - absmax = std::max(absmax, (__fp16)fabs(p0[0])); + absmax = std::max(absmax, (__fp16)fabsf(p0[0])); p0++; } #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -413,7 +413,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s #endif // __ARM_NEON for (; kk < K; kk++) { - absmax = std::max(absmax, (float)fabs(float16_to_float32(p0[0]))); + absmax = std::max(absmax, (float)fabsf(float16_to_float32(p0[0]))); p0++; } #endif // 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -2083,7 +2083,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, absmax = (float)vmaxv_f16(_amax0); for (; kk < K; kk++) { - absmax = std::max(absmax, fabs((float)p0[0])); + absmax = std::max(absmax, (float)fabsf((float)p0[0])); p0 += A_hstep; } #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -2132,7 +2132,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, #endif // __ARM_NEON for (; kk < K; kk++) { - absmax = std::max(absmax, (float)fabs(float16_to_float32(p0[0]))); + absmax = std::max(absmax, (float)fabsf(float16_to_float32(p0[0]))); p0 += A_hstep; } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -3473,7 +3473,7 @@ static void compute_B_fp16_int8_scale(const Mat& B, float& scale) } for (; j < size; j++) { - absmax = std::max(absmax, fabs((float)ptr[0])); + absmax = std::max(absmax, (float)fabsf((float)ptr[0])); ptr++; } #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -3513,7 +3513,7 @@ static void compute_B_fp16_int8_scale(const Mat& B, float& scale) #endif // __ARM_NEON for (; j < size; j++) { - absmax = std::max(absmax, (float)fabs(float16_to_float32(ptr[0]))); + absmax = std::max(absmax, (float)fabsf(float16_to_float32(ptr[0]))); ptr++; } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC From 6a1c346f57951ccb9143df058770a8df521eb8b8 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 19:26:03 +0800 Subject: [PATCH 33/55] fix build --- src/layer/arm/gemm_int8_fp16s.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index cfcd798f8033..eb532e63663a 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -320,7 +320,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; - __fp16 absmax = 0.f; + float absmax = 0.f; float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); float16x8_t _amax1 = vdupq_n_f16((__fp16)0.f); float16x8_t _amax2 = vdupq_n_f16((__fp16)0.f); @@ -356,10 +356,10 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s p0 += 8; } float16x4_t _aa = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); - absmax = (__fp16)vmaxvq_f32(vcvt_f32_f16(_aa)); + absmax = vmaxvq_f32(vcvt_f32_f16(_aa)); for (; kk < K; kk++) { - absmax = std::max(absmax, (__fp16)fabsf(p0[0])); + absmax = std::max(absmax, (float)fabsf(p0[0])); p0++; } #else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC From 878f7236b8830fb473b248a2ce988c08317f12f0 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 19:27:17 +0800 Subject: [PATCH 34/55] fix build --- src/layer/arm/gemm_int8_fp16s.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index eb532e63663a..bca6f9271cfa 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -355,8 +355,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s _amax0 = vmaxq_f16(_amax0, vabsq_f16(_p)); p0 += 8; } - float16x4_t _aa = vmax_f16(vget_low_f16(_amax0), vget_high_f16(_amax0)); - absmax = vmaxvq_f32(vcvt_f32_f16(_aa)); + absmax = (float)vmaxvq_f16(_amax0); for (; kk < K; kk++) { absmax = std::max(absmax, (float)fabsf(p0[0])); From 8ab3c4d8b9ba0f8d6f95f603076f0a1998624f24 Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 9 Oct 2024 11:38:25 
+0800 Subject: [PATCH 35/55] fix tests --- tests/test_gemm.cpp | 2 +- tests/test_gemm_1.cpp | 4 ++-- tests/test_gemm_3.cpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_gemm.cpp b/tests/test_gemm.cpp index c2900e9ac611..95f5436f6843 100644 --- a/tests/test_gemm.cpp +++ b/tests/test_gemm.cpp @@ -300,7 +300,7 @@ int main() || test_gemm_1(M, N, K); if (ret != 0) - return 0; + return ret; } return 0; diff --git a/tests/test_gemm_1.cpp b/tests/test_gemm_1.cpp index 59a0c8256278..7179bf8d26f9 100644 --- a/tests/test_gemm_1.cpp +++ b/tests/test_gemm_1.cpp @@ -120,13 +120,13 @@ int main() int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K); if (ret != 0) - return 0; + return ret; } // test no tiling int ret = test_gemm_0(M, N, K, 100, 100, 100); if (ret != 0) - return 0; + return ret; } return 0; diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index 0d2f257c555f..c70b26ccefaa 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -239,7 +239,7 @@ int main() int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); if (ret != 0) - return 0; + return ret; } return 0; From 24d2c15531c19cc8d11dc37dcd11593134fa3050 Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 9 Oct 2024 11:45:14 +0800 Subject: [PATCH 36/55] x86 riscv fallback --- src/layer/riscv/gemm_riscv.cpp | 15 +++++++++++++++ src/layer/x86/gemm_x86.cpp | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/layer/riscv/gemm_riscv.cpp b/src/layer/riscv/gemm_riscv.cpp index fa25a058cb1c..8dee572548ed 100644 --- a/src/layer/riscv/gemm_riscv.cpp +++ b/src/layer/riscv/gemm_riscv.cpp @@ -3947,6 +3947,14 @@ int Gemm_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt int Gemm_riscv::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + support_packing = false; + return 0; + } +#endif + if (constantA) { const int M = constantM; @@ -4070,6 +4078,13 @@ int Gemm_riscv::create_pipeline(const Option& opt) int Gemm_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + int M; int N; if (constantA && constantB) diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index d9b11f7ea8fa..268f85f332d8 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -7222,6 +7222,14 @@ static int gemm_AT_BT_x86(const Mat& AT, const Mat& BT, const Mat& C, Mat& top_b int Gemm_x86::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + support_packing = false; + return 0; + } +#endif + if (constantA) { const int M = constantM; @@ -7355,6 +7363,13 @@ int Gemm_x86::create_pipeline(const Option& opt) int Gemm_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return Gemm::forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + int M; int N; if (constantA && constantB) From 20a221f76ed13dfc6c5e48216018b5f8f3a1cc4f Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 9 Oct 2024 15:06:45 +0800 Subject: [PATCH 37/55] skip gemm vulkan int8 --- src/layer/vulkan/gemm_vulkan.cpp | 13 +++++++++++++ src/layer/vulkan/gemm_vulkan.h | 2 ++ 2 files changed, 15 insertions(+) diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index 0d403a5288b9..4bfe1f7d2e87 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp 
@@ -26,6 +26,19 @@ Gemm_vulkan::Gemm_vulkan() pipeline_gemm = 0; } +int Gemm_vulkan::load_param(const ParamDict& pd) +{ + int ret = Gemm::load_param(pd); + + if (int8_scale_term) + { + support_vulkan = false; + support_image_storage = false; + } + + return ret; +} + int Gemm_vulkan::create_pipeline(const Option& opt) { // const Mat& shape = top_shapes.empty() ? Mat() : top_shapes[0]; diff --git a/src/layer/vulkan/gemm_vulkan.h b/src/layer/vulkan/gemm_vulkan.h index d9fa92018e42..b1b37927b400 100644 --- a/src/layer/vulkan/gemm_vulkan.h +++ b/src/layer/vulkan/gemm_vulkan.h @@ -24,6 +24,8 @@ class Gemm_vulkan : public Gemm public: Gemm_vulkan(); + virtual int load_param(const ParamDict& pd); + virtual int create_pipeline(const Option& opt); virtual int destroy_pipeline(const Option& opt); From 5b099e4e6a6363be1d8ad4d36bd2a6ceccbe4f6a Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 9 Oct 2024 19:19:06 +0800 Subject: [PATCH 38/55] fix --- src/layer/arm/gemm_arm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index eca41be82e2f..79c551ae5a58 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -5980,9 +5980,9 @@ int Gemm_arm::create_pipeline_int8(const Option& opt) else #endif { - input_elemtype = 1; // fp32 } + input_elemtype = 1; // fp32 if (use_fp16) input_elemtype = 2; if (use_bf16) input_elemtype = 3; From 55e2de740325037a31b22481750090147b35fdc8 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 10 Oct 2024 14:23:14 +0800 Subject: [PATCH 39/55] fix --- src/layer/arm/gemm_int8_fp16s.h | 48 ++++++++++++++++----------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index bca6f9271cfa..34dde2629de7 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -2260,14 +2260,14 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #else // __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = float2int8(_p0, _p8); - int8x8_t _r1 = float2int8(_p1, _p9); - int8x8_t _r2 = float2int8(_p2, _pa); - int8x8_t _r3 = float2int8(_p3, _pb); - int8x8_t _r4 = float2int8(_p4, _pc); - int8x8_t _r5 = float2int8(_p5, _pd); - int8x8_t _r6 = float2int8(_p6, _pe); - int8x8_t _r7 = float2int8(_p7, _pf); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); @@ -2714,10 +2714,10 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); #else // __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = float2int8(_p0, _p4); - int8x8_t _r1 = float2int8(_p1, _p5); - int8x8_t _r2 = float2int8(_p2, _p6); - int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4)); @@ -4663,14 +4663,14 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& 
B, Mat& BT, int j, int vst1q_s8(pp + 32, vcombine_s8(_r4, _r5)); vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #else // __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = float2int8(_p0, _p8); - int8x8_t _r1 = float2int8(_p1, _p9); - int8x8_t _r2 = float2int8(_p2, _pa); - int8x8_t _r3 = float2int8(_p3, _pb); - int8x8_t _r4 = float2int8(_p4, _pc); - int8x8_t _r5 = float2int8(_p5, _pd); - int8x8_t _r6 = float2int8(_p6, _pe); - int8x8_t _r7 = float2int8(_p7, _pf); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); vst1q_s8(pp, vcombine_s8(_r0, _r1)); vst1q_s8(pp + 16, vcombine_s8(_r2, _r3)); @@ -5074,10 +5074,10 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int int8x8_t _r2 = float2int8(_p4, _p5); int8x8_t _r3 = float2int8(_p6, _p7); #else // __ARM_FEATURE_MATMUL_INT8 - int8x8_t _r0 = float2int8(_p0, _p4); - int8x8_t _r1 = float2int8(_p1, _p5); - int8x8_t _r2 = float2int8(_p2, _p6); - int8x8_t _r3 = float2int8(_p3, _p7); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p1, _p3); + int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4)); From 061205aa2e20eaa705954be51039ed5d7d913165 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 10 Oct 2024 14:53:40 +0800 Subject: [PATCH 40/55] fix noint8 test, fix arm bf16 test --- src/layer/arm/gemm_arm.cpp | 1 + tests/test_gemm_3.cpp | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 79c551ae5a58..baf0da4e6504 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -5980,6 +5980,7 @@ int Gemm_arm::create_pipeline_int8(const Option& opt) else #endif { + use_bf16 = opt.use_bf16_storage; } input_elemtype = 1; // fp32 diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index c70b26ccefaa..5348f8d674e1 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -14,6 +14,7 @@ #include "testutil.h" +#if NCNN_INT8 static int test_gemm_int8(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M) { ncnn::ParamDict pd; @@ -186,11 +187,13 @@ static int test_gemm_1(int M, int N, int K) || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1) || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 0, 1, 1, 1); } +#endif // NCNN_INT8 int main() { SRAND(7767517); +#if NCNN_INT8 int mnk[][3] = { {1, 1, 1}, {2, 2, 2}, @@ -241,6 +244,9 @@ int main() if (ret != 0) return ret; } +#else + // test nothing for non-int8 build +#endif return 0; } From 73e73640c5dc677085c4578f1f97156ebea99578 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 10 Oct 2024 14:54:54 +0800 Subject: [PATCH 41/55] enable vfpv4 on neon build only --- CMakeLists.txt | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f32a80c86ee..875a8d06598f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -162,21 +162,25 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm") endif() if(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32) - if(CMAKE_CXX_COMPILER_ID MATCHES 
"MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) - set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + check_cxx_source_compiles("#include \nint main() { float32x4_t _s, _a, _b; _s = vmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM_NEON) - unset(CMAKE_REQUIRED_FLAGS) - else() - set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + if(NCNN_COMPILER_SUPPORT_ARM_NEON) + if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) + set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4") + check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) - if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4) - set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) - endif() + unset(CMAKE_REQUIRED_FLAGS) + else() + set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4") + check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) - unset(CMAKE_REQUIRED_FLAGS) + if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4) + set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee") + check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) + endif() + + unset(CMAKE_REQUIRED_FLAGS) + endif() endif() if(NCNN_COMPILER_SUPPORT_ARM_VFPV4 OR NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) From 9c3057bb02f6dbe0bc9d8040c9d6394d47a6ea74 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 10 Oct 2024 15:30:47 +0800 Subject: [PATCH 42/55] fix test, test++ --- src/layer/arm/gemm_int8_fp16s.h | 48 ++++++++++++++++----------------- tests/test_gemm_3.cpp | 7 +++++ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 34dde2629de7..05dd8975da0a 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -2275,14 +2275,14 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int8x8_t _r0 = float2int8(_p0, _p1); - int8x8_t _r1 = float2int8(_p2, _p3); - int8x8_t _r2 = float2int8(_p4, _p5); - int8x8_t _r3 = float2int8(_p6, _p7); - int8x8_t _r4 = float2int8(_p8, _p9); - int8x8_t _r5 = float2int8(_pa, _pb); - int8x8_t _r6 = float2int8(_pc, _pd); - int8x8_t _r7 = float2int8(_pe, _pf); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); @@ 
-2720,10 +2720,10 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4)); - int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p5)); - int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p2, _p6)); - int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p3, _p7)); + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); int16x4x2_t _t23 = vuzp_s16(_t2, _t3); int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); @@ -4678,14 +4678,14 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int vst1q_s8(pp + 48, vcombine_s8(_r6, _r7)); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int8x8_t _r0 = float2int8(_p0, _p1); - int8x8_t _r1 = float2int8(_p2, _p3); - int8x8_t _r2 = float2int8(_p4, _p5); - int8x8_t _r3 = float2int8(_p6, _p7); - int8x8_t _r4 = float2int8(_p8, _p9); - int8x8_t _r5 = float2int8(_pa, _pb); - int8x8_t _r6 = float2int8(_pc, _pd); - int8x8_t _r7 = float2int8(_pe, _pf); + int8x8_t _r0 = float2int8(_p0, _p2); + int8x8_t _r1 = float2int8(_p4, _p6); + int8x8_t _r2 = float2int8(_p8, _pa); + int8x8_t _r3 = float2int8(_pc, _pe); + int8x8_t _r4 = float2int8(_p1, _p3); + int8x8_t _r5 = float2int8(_p5, _p7); + int8x8_t _r6 = float2int8(_p9, _pb); + int8x8_t _r7 = float2int8(_pd, _pf); int16x8_t _r01 = vreinterpretq_s16_s8(vcombine_s8(_r0, _r1)); int16x8_t _r23 = vreinterpretq_s16_s8(vcombine_s8(_r2, _r3)); @@ -5080,10 +5080,10 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int int8x8_t _r3 = float2int8(_p5, _p7); #endif // __ARM_FEATURE_MATMUL_INT8 #else // __ARM_FEATURE_DOTPROD - int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p4)); - int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p1, _p5)); - int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p2, _p6)); - int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p3, _p7)); + int16x4_t _t0 = vreinterpret_s16_s8(float2int8(_p0, _p2)); + int16x4_t _t1 = vreinterpret_s16_s8(float2int8(_p4, _p6)); + int16x4_t _t2 = vreinterpret_s16_s8(float2int8(_p1, _p3)); + int16x4_t _t3 = vreinterpret_s16_s8(float2int8(_p5, _p7)); int16x4x2_t _t01 = vuzp_s16(_t0, _t1); int16x4x2_t _t23 = vuzp_s16(_t2, _t3); int8x8_t _r0 = vreinterpret_s8_s16(_t01.val[0]); diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index 5348f8d674e1..f753a6594fab 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -243,6 +243,13 @@ int main() int ret = test_gemm_0(M, N, K) || test_gemm_1(M, N, K); if (ret != 0) return ret; + + if (M != N) + { + int ret = test_gemm_0(N, M, K) || test_gemm_1(N, M, K); + if (ret != 0) + return ret; + } } #else // test nothing for non-int8 build From 0b7755dce30d64a4883793ba94fd2c6983536db4 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 10 Oct 2024 16:59:38 +0800 Subject: [PATCH 43/55] fix gemm vulkan without C --- src/layer/vulkan/gemm_vulkan.cpp | 144 ++++++++++++++++++++----------- 1 file changed, 94 insertions(+), 50 deletions(-) diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index 4bfe1f7d2e87..81c92aeb6f17 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -182,56 +182,78 @@ int 
Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vectorconvert_packing(A0, A, 1, cmd, opt); vkdev->convert_packing(B0, B, 1, cmd, opt); - vkdev->convert_packing(C0, C, 1, cmd, opt); const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; - int broadcast_type_C; + VkMat C; + int broadcast_type_C = -1; if (constantC) { + vkdev->convert_packing(C_data_gpu, C, 1, cmd, opt); broadcast_type_C = constant_broadcast_type_C; } else { - if (C.dims == 1 && C.w == 1) - { - // scalar - broadcast_type_C = 0; - } - if (C.dims == 1 && C.w == M) + VkMat C0; + if (constantA && constantB) { - // M - // auto broadcast from h to w is the ncnn-style convention - broadcast_type_C = 1; + C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkMat(); } - if (C.dims == 1 && C.w == N) + else if (constantA) { - // N - broadcast_type_C = 4; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat(); } - if (C.dims == 2 && C.w == 1 && C.h == M) + else if (constantB) { - // Mx1 - broadcast_type_C = 2; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkMat(); } - if (C.dims == 2 && C.w == N && C.h == M) + else { - // MxN - broadcast_type_C = 3; + C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkMat(); } - if (C.dims == 2 && C.w == N && C.h == 1) + + if (!C0.empty()) { - // 1xN - broadcast_type_C = 4; + vkdev->convert_packing(C0, C, 1, cmd, opt); + + if (C0.dims == 1 && C0.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C0.dims == 1 && C0.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C0.dims == 1 && C0.w == N) + { + // N + broadcast_type_C = 4; + } + if (C0.dims == 2 && C0.w == 1 && C0.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C0.dims == 2 && C0.w == N && C0.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C0.dims == 2 && C0.w == N && C0.h == 1) + { + // 1xN + broadcast_type_C = 4; + } } } @@ -314,56 +336,78 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vecto { const VkImageMat& A0 = constantA ? A_data_gpu_image : bottom_blobs[0]; const VkImageMat& B0 = constantB ? B_data_gpu_image : constantA ? bottom_blobs[0] : bottom_blobs[1]; - const VkImageMat& C0 = constantC ? C_data_gpu_image : bottom_blobs[bottom_blobs.size() - 1]; VkImageMat A; VkImageMat B; - VkImageMat C; vkdev->convert_packing(A0, A, 1, cmd, opt); vkdev->convert_packing(B0, B, 1, cmd, opt); - vkdev->convert_packing(C0, C, 1, cmd, opt); const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; - int broadcast_type_C; + VkImageMat C; + int broadcast_type_C = -1; if (constantC) { + vkdev->convert_packing(C_data_gpu_image, C, 1, cmd, opt); broadcast_type_C = constant_broadcast_type_C; } else { - if (C.dims == 1 && C.w == 1) - { - // scalar - broadcast_type_C = 0; - } - if (C.dims == 1 && C.w == M) + VkImageMat C0; + if (constantA && constantB) { - // M - // auto broadcast from h to w is the ncnn-style convention - broadcast_type_C = 1; + C0 = bottom_blobs.size() == 1 ? bottom_blobs[0] : VkImageMat(); } - if (C.dims == 1 && C.w == N) + else if (constantA) { - // N - broadcast_type_C = 4; + C0 = bottom_blobs.size() == 2 ? 
bottom_blobs[1] : VkImageMat(); } - if (C.dims == 2 && C.w == 1 && C.h == M) + else if (constantB) { - // Mx1 - broadcast_type_C = 2; + C0 = bottom_blobs.size() == 2 ? bottom_blobs[1] : VkImageMat(); } - if (C.dims == 2 && C.w == N && C.h == M) + else { - // MxN - broadcast_type_C = 3; + C0 = bottom_blobs.size() == 3 ? bottom_blobs[2] : VkImageMat(); } - if (C.dims == 2 && C.w == N && C.h == 1) + + if (!C0.empty()) { - // 1xN - broadcast_type_C = 4; + vkdev->convert_packing(C0, C, 1, cmd, opt); + + if (C.dims == 1 && C.w == 1) + { + // scalar + broadcast_type_C = 0; + } + if (C.dims == 1 && C.w == M) + { + // M + // auto broadcast from h to w is the ncnn-style convention + broadcast_type_C = 1; + } + if (C.dims == 1 && C.w == N) + { + // N + broadcast_type_C = 4; + } + if (C.dims == 2 && C.w == 1 && C.h == M) + { + // Mx1 + broadcast_type_C = 2; + } + if (C.dims == 2 && C.w == N && C.h == M) + { + // MxN + broadcast_type_C = 3; + } + if (C.dims == 2 && C.w == N && C.h == 1) + { + // 1xN + broadcast_type_C = 4; + } } } From f5a828b675f442bfe64c5b35169885658db13f83 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 10 Oct 2024 17:18:05 +0800 Subject: [PATCH 44/55] test++ --- tests/test_gemm_3.cpp | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index f753a6594fab..0ede5ab4232e 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -174,18 +174,30 @@ static int test_gemm_1(int M, int N, int K) { return 0 || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 0, 0, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 1, 0, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 2, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(1), 2.1f, 0.5f, 0, 0, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 2, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(M), 3.1f, 0.6f, 0, 1, 3, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 2, 0, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 3, 1, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 0, 0, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 1, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 2, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 3, 1, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 2, 0, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 3, 0, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 0, 1, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 1, 1, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 0, 1, 1, 1); + || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(1), -2.1f, 0.5f, 0, 0, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(M), -3.1f, 0.6f, 0, 1, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, 
RandomMat(M), -3.1f, 0.6f, 0, 1, 3, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 3, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 1, 1, 1, 1); } #endif // NCNN_INT8 From 5e4fcfae80c46bb278b8bd62f80614b220da084b Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 10 Oct 2024 18:46:29 +0800 Subject: [PATCH 45/55] test++ --- src/layer/vulkan/gemm_vulkan.cpp | 12 +++---- tests/test_gemm_3.cpp | 60 ++++++++++++++++++-------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index 81c92aeb6f17..eed5dd357fd6 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -223,33 +223,33 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vectorconvert_packing(C0, C, 1, cmd, opt); - if (C0.dims == 1 && C0.w == 1) + if (C.dims == 1 && C.w == 1) { // scalar broadcast_type_C = 0; } - if (C0.dims == 1 && C0.w == M) + if (C.dims == 1 && C.w == M) { // M // auto broadcast from h to w is the ncnn-style convention broadcast_type_C = 1; } - if (C0.dims == 1 && C0.w == N) + if (C.dims == 1 && C.w == N) { // N broadcast_type_C = 4; } - if (C0.dims == 2 && C0.w == 1 && C0.h == M) + if (C.dims == 2 && C.w == 1 && C.h == M) { // Mx1 broadcast_type_C = 2; } - if (C0.dims == 2 && C0.w == N && C0.h == M) + if (C.dims == 2 && C.w == N && C.h == M) { // MxN broadcast_type_C = 3; } - if (C0.dims == 2 && C0.w == N && C0.h == 1) + if (C.dims == 2 && C.w == N && C.h == 1) { // 1xN broadcast_type_C = 4; diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index 0ede5ab4232e..f8433d18b215 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -181,8 +181,12 @@ static int test_gemm_1(int M, int N, int K) || test_gemm_int8_bias(M, N, K, RandomMat(1, M), 4.1f, 0.7f, 1, 0, 1, 1, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 2, 0, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 5.1f, -0.8f, 1, 1, 3, 1, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 0, 0, 0, 0, 0) - || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 2, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 2.1f, -0.5f, 0, 0, 3, 1, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 0, 0, 0) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 2, 0, 0, 0, 0) || test_gemm_int8_bias(M, N, K, RandomMat(N), 3.1f, -0.6f, 0, 1, 3, 1, 0, 0, 0) @@ -194,8 +198,12 @@ static int test_gemm_1(int M, int N, int K) || test_gemm_int8_bias(M, N, K, RandomMat(1, M), -4.1f, 0.7f, 1, 0, 1, 1, 1, 1, 1) || 
test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 2, 0, 1, 1, 1) || test_gemm_int8_bias(M, N, K, RandomMat(N, M), -5.1f, -0.8f, 1, 1, 3, 1, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 0, 0, 1, 1, 1) - || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, M), 1.f, 1.f, 1, 1, 1, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 2, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), -2.1f, -0.5f, 0, 0, 3, 1, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N, 1), 0.8f, 1.f, 0, 0, 0, 0, 1, 1, 1) + || test_gemm_int8_bias(M, N, K, RandomMat(N), 0.8f, 1.f, 0, 0, 1, 1, 1, 1, 1) || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 2, 0, 1, 1, 1) || test_gemm_int8_bias(M, N, K, RandomMat(N), -3.1f, -0.6f, 0, 1, 3, 1, 1, 1, 1); } @@ -208,40 +216,40 @@ int main() #if NCNN_INT8 int mnk[][3] = { {1, 1, 1}, + {1, 1, 23}, + {1, 1, 47}, + {1, 23, 1}, + {1, 23, 23}, + {1, 31, 1}, + {1, 35, 1}, + {1, 35, 47}, + {1, 47, 1}, {2, 2, 2}, {3, 3, 3}, {4, 4, 4}, {5, 5, 5}, {6, 6, 6}, {7, 7, 7}, + {7, 31, 3}, {8, 8, 8}, - {15, 15, 15}, - {16, 16, 16}, - {31, 31, 31}, - {40, 40, 40}, - {1, 1, 23}, - {1, 31, 1}, - {23, 1, 1}, {12, 12, 23}, + {12, 23, 12}, {12, 31, 12}, - {23, 12, 12}, - {1, 1, 47}, - {1, 35, 1}, - {47, 1, 1}, - {24, 24, 47}, - {24, 35, 24}, - {47, 24, 24}, - {1, 35, 47}, + {15, 15, 15}, + {16, 16, 16}, + {19, 44, 7}, + {20, 28, 7}, {23, 31, 1}, - {23, 1, 23}, {23, 31, 23}, - {31, 7, 3}, - {28, 20, 7}, + {24, 24, 47}, + {24, 35, 24}, + {24, 47, 24}, + {31, 31, 31}, {32, 32, 9}, - {44, 19, 7}, - {47, 35, 48}, - {47, 48, 47}, - {48, 35, 47} + {35, 47, 48}, + {35, 48, 47}, + {40, 40, 40}, + {47, 48, 47} }; int mnk_count = sizeof(mnk) / sizeof(int) / 3; From 545f075ddb55f0b9a47414b8c4d14577ed3d1271 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 11 Oct 2024 16:22:37 +0800 Subject: [PATCH 46/55] fp16 pack8 output, cc --- src/layer/arm/gemm_arm.cpp | 10 +- src/layer/arm/gemm_int8_fp16s.h | 10685 +++++++++++------------------- 2 files changed, 3848 insertions(+), 6847 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index baf0da4e6504..200158f0911d 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -6201,7 +6201,15 @@ int Gemm_arm::forward_int8(const std::vector& bottom_blobs, std::vector Date: Fri, 11 Oct 2024 16:48:13 +0800 Subject: [PATCH 47/55] fix --- src/layer/arm/gemm_arm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 200158f0911d..aee3d9d3d28b 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -6201,7 +6201,7 @@ int Gemm_arm::forward_int8(const std::vector& bottom_blobs, std::vector Date: Fri, 11 Oct 2024 17:33:06 +0800 Subject: [PATCH 48/55] cc --- src/layer/arm/gemm_int8_bf16s.h | 9884 +++++++++++-------------------- 1 file changed, 3519 insertions(+), 6365 deletions(-) diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 3ec9f15000a4..220e2a696aa3 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -4199,800 +4199,567 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < 
max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); #if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 #else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 
c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - { - _sum8 = vrev64q_s32(_sum8); - _sum9 = vrev64q_s32(_sum9); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sumc = vrev64q_s32(_sumc); - _sumd = vrev64q_s32(_sumd); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - _sum8 = vextq_s32(_sum8, _sum8, 2); - _sum9 = vextq_s32(_sum9, _sum9, 2); - _suma = vextq_s32(_suma, _suma, 2); - _sumb = vextq_s32(_sumb, _sumb, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - _sum9 = vrev64q_s32(_sum9); - _sumb = vrev64q_s32(_sumb); - _sumd = vrev64q_s32(_sumd); - _sumf = vrev64q_s32(_sumf); - } - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = 
vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } #endif - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = 
vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c1); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c1); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c1); - _ff = vaddq_f32(_ff, _c1); - } - if (broadcast_type_C == 3) + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) { - if (c_elempack == 1) + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, 
_c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 5); - _c45 = vld1q_u16(pC + c_hstep * 6); - _c67 = vld1q_u16(pC + c_hstep * 7); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - _c4 = bfloat2float(vget_low_u16(_c45)); - _c5 = bfloat2float(vget_high_u16(_c45)); - _c6 = bfloat2float(vget_low_u16(_c67)); - _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - else // if (c_elempack == 4) + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + 16); - uint16x8_t _c67 = vld1q_u16(pC + 24); - _c0 = 
bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 4 + 8); - _c45 = vld1q_u16(pC + c_hstep * 4 + 16); - _c67 = vld1q_u16(pC + c_hstep * 4 + 24); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - _c4 = bfloat2float(vget_low_u16(_c45)); - _c5 = bfloat2float(vget_high_u16(_c45)); - _c6 = bfloat2float(vget_low_u16(_c67)); - _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); } - } - if (broadcast_type_C == 4) - { - uint16x8_t _cc = vld1q_u16(pC); - float32x4_t _cc0 = bfloat2float(vget_low_u16(_cc)); - float32x4_t _cc1 = bfloat2float(vget_high_u16(_cc)); - if (beta != 1.f) + else { float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } - _c0 = vdupq_laneq_f32(_cc0, 0); - _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = 
vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); pC += 8; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); - vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f8), float2bfloat(_f9))); - vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_fa), float2bfloat(_fb))); - vst1q_u16(p0 + out_hstep * 4 + 16, vcombine_u16(float2bfloat(_fc), float2bfloat(_fd))); - vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(float2bfloat(_fe), float2bfloat(_ff))); - - pp += 64; + else // if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = 
vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _cc = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_cc)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_cc)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + uint16x4_t _bf8 = float2bfloat(_f8); + uint16x4_t _bf9 = float2bfloat(_f9); + uint16x4_t _bfa = float2bfloat(_fa); + uint16x4_t _bfb = float2bfloat(_fb); + uint16x4_t _bfc = float2bfloat(_fc); + uint16x4_t _bfd = float2bfloat(_fd); + uint16x4_t _bfe = float2bfloat(_fe); + uint16x4_t _bff = float2bfloat(_ff); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + vst1q_u16(p0 + 16, vcombine_u16(_bf4, _bf5)); + vst1q_u16(p0 + 24, vcombine_u16(_bf6, _bf7)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf8, _bf9)); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_bfa, _bfb)); + vst1q_u16(p0 + out_hstep * 4 + 16, vcombine_u16(_bfc, _bfd)); + vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(_bfe, _bff)); p0 += 32; } + else // if (out_elempack == 1) + { + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); + transpose4x4_u16(_bf8, _bf9, _bfa, _bfb); + transpose4x4_u16(_bfc, _bfd, _bfe, _bff); + vst1q_u16(p0, vcombine_u16(_bf0, _bf4)); + vst1q_u16(p0 + out_hstep, vcombine_u16(_bf1, _bf5)); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(_bf2, _bf6)); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(_bf3, _bf7)); + vst1q_u16(p0 + out_hstep * 
4, vcombine_u16(_bf8, _bfc)); + vst1q_u16(p0 + out_hstep * 5, vcombine_u16(_bf9, _bfd)); + vst1q_u16(p0 + out_hstep * 6, vcombine_u16(_bfa, _bfe)); + vst1q_u16(p0 + out_hstep * 7, vcombine_u16(_bfb, _bff)); + p0 += 8; + } + + pp += 64; + } #endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 #else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = 
vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c1); - _f7 = vaddq_f32(_f7, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - float32x4_t _c2 = bfloat2float(_cc2); - float32x4_t _c3 = bfloat2float(_cc3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _cc0 = vld1_u16(pC + c_hstep * 4); - _cc1 = vld1_u16(pC + c_hstep * 5); - _cc2 = vld1_u16(pC + c_hstep * 6); - _cc3 = vld1_u16(pC + c_hstep * 7); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - _c2 = bfloat2float(_cc2); - _c3 = bfloat2float(_cc3); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = 
bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 4 + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 16; - } - } - if (broadcast_type_C == 4) - { - float32x4_t _c = bfloat2float(vld1_u16(pC)); - _c = vmulq_n_f32(_c, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - float32x4_t _c3 = vdupq_laneq_f32(_c, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); -#endif - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - pC += 4; - } - } + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - if (alpha != 1.f) + if (pC) + { + if (broadcast_type_C == 0) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); - vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); - - pp += 32; - p0 += 16; - } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 -#else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - 
// a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - uint16x8_t _c01; - uint16x8_t _c23; - if (c_elempack == 1) - { - _c01 = uint16x8_t(); - _c01 = vsetq_lane_u16(pC[0], _c01, 0); - _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); - _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); - _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); - _c01 = vsetq_lane_u16(pC[1], _c01, 4); - _c01 = vsetq_lane_u16(pC[c_hstep + 1], _c01, 5); - _c01 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c01, 6); - _c01 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c01, 7); - _c23 = uint16x8_t(); - _c23 = vsetq_lane_u16(pC[c_hstep * 4], _c23, 0); - _c23 = vsetq_lane_u16(pC[c_hstep * 5], _c23, 1); - _c23 = vsetq_lane_u16(pC[c_hstep * 6], _c23, 2); - _c23 = vsetq_lane_u16(pC[c_hstep * 7], _c23, 3); - _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); - _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); - _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); - _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); - pC += 2; - } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + c_hstep * 4); - pC += 8; - } - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -5008,1097 +4775,414 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * 
beta); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 2; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - - pp += 16; - p0 += 8; - } - for (; jj < max_jj; jj++) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + if (beta == 1.f) { - uint16x8_t _c01 = uint16x8_t(); - _c01 = vsetq_lane_u16(pC[0], _c01, 0); - _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); - _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); - _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); - _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); - _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); - _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); - _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - pC += 1; + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); } - else // if (c_elempack == 4) + else { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } + pC += 4; + } + else // if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, 
_c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 16; } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); - - pp += 8; + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf4, _bf5)); + vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_bf6, _bf7)); + p0 += 16; + } + else // if (out_elempack == 1) + { + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + vst1_u16(p0 + out_hstep * 4, _bf4); + vst1_u16(p0 + out_hstep * 5, _bf5); + vst1_u16(p0 + out_hstep * 6, _bf6); + vst1_u16(p0 + out_hstep * 7, _bf7); p0 += 4; } + + pp += 32; } - if (out_elempack == 1) + for (; jj + 1 < max_jj; jj += 2) { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 
b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); - int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); - int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); - _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); - } + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 #else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = 
vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - } + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); - float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); - float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); - float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); - float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); - float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); - float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); - float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); - float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); - - if (pC) - { - if (broadcast_type_C == 0) + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = 
vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + if (c_elempack == 1) + { + _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[1], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep + 1], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c01, 7); + _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[c_hstep * 4], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep * 5], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 6], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 7], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); + pC += 2; + } + else // if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep * 4); + pC += 8; } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - _f8 = vaddq_f32(_f8, _cc4); - _f9 = vaddq_f32(_f9, _cc4); - _fa = vaddq_f32(_fa, _cc5); - _fb = vaddq_f32(_fb, _cc5); - _fc = vaddq_f32(_fc, _cc6); - _fd = vaddq_f32(_fd, _cc6); - _fe = vaddq_f32(_fe, _cc7); - _ff = vaddq_f32(_ff, _cc7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3) + else { - if (c_elempack == 1) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); 
- float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 5); - _c45 = vld1q_u16(pC + c_hstep * 6); - _c67 = vld1q_u16(pC + c_hstep * 7); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - _c4 = bfloat2float(vget_low_u16(_c45)); - _c5 = bfloat2float(vget_high_u16(_c45)); - _c6 = bfloat2float(vget_low_u16(_c67)); - _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; - } - else // if (c_elempack == 4) - { - uint16x8x4_t _cc0 = vld4q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); - _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); - float32x4_t _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); - float32x4_t _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); - float32x4_t _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); - float32x4_t _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); - float32x4_t _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); - float32x4_t _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _cc0 = vld4q_u16(pC + c_hstep * 4); - _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); - _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); - _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); - _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); - _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); - _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); - _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); - _c7 = 
bfloat2float(vget_high_u16(_cc0.val[3])); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; - } - } - if (broadcast_type_C == 4) - { - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - _c1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c1); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); - vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f8), float2bfloat(_f9))); - vst1q_u16(p0 + out_hstep * 5, vcombine_u16(float2bfloat(_fa), float2bfloat(_fb))); - vst1q_u16(p0 + out_hstep * 6, vcombine_u16(float2bfloat(_fc), float2bfloat(_fd))); - vst1q_u16(p0 + out_hstep * 7, vcombine_u16(float2bfloat(_fe), float2bfloat(_ff))); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); - pp += 64; + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, 
_bf1)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf2, _bf3)); p0 += 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + else // if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - } -#else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 - { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); - 
float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); - float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); - float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); - float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); - float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); -#endif - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); - float32x4_t _cc4 = vdupq_lane_f32(vget_low_f32(_c1), 0); - float32x4_t _cc5 = vdupq_lane_f32(vget_low_f32(_c1), 1); - float32x4_t _cc6 = vdupq_lane_f32(vget_high_f32(_c1), 0); - float32x4_t _cc7 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); - _f4 = vaddq_f32(_f4, _cc4); - _f5 = vaddq_f32(_f5, _cc5); - _f6 = vaddq_f32(_f6, _cc6); - _f7 = vaddq_f32(_f7, _cc7); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 1)); - float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 5)); - _c2 = bfloat2float(vld1_u16(pC + c_hstep * 6)); - _c3 = bfloat2float(vld1_u16(pC + c_hstep * 7)); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - uint16x4x4_t _cc0 = vld4_u16(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, 
bfloat2float(_cc0.val[0])); - _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1])); - _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2])); - _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3])); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, bfloat2float(_cc0.val[0]), _beta); - _f1 = vmlaq_f32(_f1, bfloat2float(_cc0.val[1]), _beta); - _f2 = vmlaq_f32(_f2, bfloat2float(_cc0.val[2]), _beta); - _f3 = vmlaq_f32(_f3, bfloat2float(_cc0.val[3]), _beta); - } - _cc0 = vld4_u16(pC + c_hstep * 4); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, bfloat2float(_cc0.val[0])); - _f5 = vaddq_f32(_f5, bfloat2float(_cc0.val[1])); - _f6 = vaddq_f32(_f6, bfloat2float(_cc0.val[2])); - _f7 = vaddq_f32(_f7, bfloat2float(_cc0.val[3])); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, bfloat2float(_cc0.val[0]), _beta); - _f5 = vmlaq_f32(_f5, bfloat2float(_cc0.val[1]), _beta); - _f6 = vmlaq_f32(_f6, bfloat2float(_cc0.val[2]), _beta); - _f7 = vmlaq_f32(_f7, bfloat2float(_cc0.val[3]), _beta); - } - pC += 16; - } - } - if (broadcast_type_C == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - pC += 4; - } - } + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf1, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 4 + 1] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 5 + 1] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 6 + 1] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 7 + 1] = vget_lane_u16(_bf3, 3); + p0 += 2; + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } + pp += 16; + } + for (; jj < max_jj; jj++) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); - vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); - vst1_u16(p0 + out_hstep * 4, float2bfloat(_f4)); - vst1_u16(p0 + out_hstep * 5, float2bfloat(_f5)); - vst1_u16(p0 + out_hstep * 6, float2bfloat(_f6)); - vst1_u16(p0 + out_hstep * 7, float2bfloat(_f7)); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - pp += 32; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) + if (pC) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - // e0 e1 f0 f1 - // g0 g1 h0 h1 + if (broadcast_type_C == 0) 
{ - int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _sum13 = vzipq_s32(_sum2, _sum3); - _sum0 = _sum02.val[0]; - _sum1 = _sum02.val[1]; - _sum2 = _sum13.val[0]; - _sum3 = _sum13.val[1]; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - // e0 e1 f0 f1 - // g0 g1 h0 h1 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - int32x4x2_t _t0 = vuzpq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vuzpq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_t0.val[0], _t1.val[0]); - int32x4x2_t _t3 = vzipq_s32(_t1.val[1], _t0.val[1]); - _sum0 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4x2_t _descale01 = vzipq_f32(_descale0, _descale0); - float32x4x2_t _descale23 = vzipq_f32(_descale1, _descale1); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale23.val[0]); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale23.val[1]); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc1.val[0]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - } - if (broadcast_type_C == 3) - { - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - uint16x8_t _cc0 = uint16x8_t(); - _cc0 = vsetq_lane_u16(pC[0], _cc0, 0); - _cc0 = vsetq_lane_u16(pC[1], _cc0, 1); - _cc0 = vsetq_lane_u16(pC[c_hstep * 1], _cc0, 2); - _cc0 = vsetq_lane_u16(pC[c_hstep * 1 + 1], _cc0, 3); - _cc0 = vsetq_lane_u16(pC[c_hstep * 2], _cc0, 4); - _cc0 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _cc0, 5); - _cc0 = vsetq_lane_u16(pC[c_hstep * 3], _cc0, 6); - _cc0 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _cc0, 7); - uint16x8_t _cc1 = uint16x8_t(); - _cc1 = vsetq_lane_u16(pC[c_hstep * 4], _cc1, 0); - _cc1 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _cc1, 1); - _cc1 = vsetq_lane_u16(pC[c_hstep * 5], _cc1, 2); - _cc1 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _cc1, 3); - _cc1 = vsetq_lane_u16(pC[c_hstep * 6], _cc1, 4); - _cc1 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _cc1, 5); - _cc1 = vsetq_lane_u16(pC[c_hstep * 7], _cc1, 6); - _cc1 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _cc1, 7); - _c0 = bfloat2float(vget_low_u16(_cc0)); - _c1 = bfloat2float(vget_high_u16(_cc0)); - _c2 = bfloat2float(vget_low_u16(_cc1)); - _c3 = bfloat2float(vget_high_u16(_cc1)); - pC += 2; - } - else // if (c_elempack == 4) - { - uint16x8_t _cc0 = vld1q_u16(pC); - uint16x8_t _cc1 = vld1q_u16(pC + c_hstep * 4); - uint16x8x2_t _cc = vzipq_u16(vcombine_u16(vget_low_u16(_cc0), vget_low_u16(_cc1)), vcombine_u16(vget_high_u16(_cc0), vget_high_u16(_cc1))); - _c0 = bfloat2float(vget_low_u16(_cc.val[0])); - _c1 = bfloat2float(vget_high_u16(_cc.val[0])); - _c2 = 
bfloat2float(vget_low_u16(_cc.val[1])); - _c3 = bfloat2float(vget_high_u16(_cc.val[1])); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - } - if (broadcast_type_C == 4) + if (c_elempack == 1) { - uint16x4_t _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[1], _c, 1); - _c = vset_lane_u16(pC[0], _c, 2); - _c = vset_lane_u16(pC[1], _c, 3); - _c0 = bfloat2float(_c); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 2; + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + pC += 1; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - uint16x4_t _fb0 = float2bfloat(_f0); - uint16x4_t _fb1 = float2bfloat(_f1); - uint16x4_t _fb2 = float2bfloat(_f2); - uint16x4_t _fb3 = float2bfloat(_f3); - - p0[0] = vget_lane_u16(_fb0, 0); - p0[1] = vget_lane_u16(_fb0, 1); - p0[out_hstep] = vget_lane_u16(_fb0, 2); - p0[out_hstep + 1] = vget_lane_u16(_fb0, 3); - p0[out_hstep * 2] = vget_lane_u16(_fb1, 0); - p0[out_hstep * 2 + 1] = vget_lane_u16(_fb1, 1); - p0[out_hstep * 3] = vget_lane_u16(_fb1, 2); - p0[out_hstep * 3 + 1] = vget_lane_u16(_fb1, 3); - p0[out_hstep * 4] = vget_lane_u16(_fb2, 0); - p0[out_hstep * 4 + 1] = vget_lane_u16(_fb2, 1); - p0[out_hstep * 5] = vget_lane_u16(_fb2, 2); - p0[out_hstep * 5 + 1] = vget_lane_u16(_fb2, 3); - p0[out_hstep * 6] = vget_lane_u16(_fb3, 0); - p0[out_hstep * 6 + 1] = vget_lane_u16(_fb3, 1); - p0[out_hstep * 7] = vget_lane_u16(_fb3, 2); - p0[out_hstep * 7 + 1] = vget_lane_u16(_fb3, 3); - - pp += 16; - p0 += 2; - } - for (; jj < max_jj; jj++) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) + else // if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - uint16x8_t _c = uint16x8_t(); - _c = vsetq_lane_u16(pC[0], _c, 0); - _c = vsetq_lane_u16(pC[c_hstep], _c, 1); - _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); - _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); - _c = vsetq_lane_u16(pC[c_hstep * 4], _c, 4); - _c = vsetq_lane_u16(pC[c_hstep * 5], _c, 5); - _c = vsetq_lane_u16(pC[c_hstep * 6], _c, 6); - _c = 
vsetq_lane_u16(pC[c_hstep * 7], _c, 7); - _c0 = bfloat2float(vget_low_u16(_c)); - _c1 = bfloat2float(vget_high_u16(_c)); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - pC += 4; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - } - if (broadcast_type_C == 4) + else { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; } + } - uint16x4_t _fb0 = float2bfloat(_f0); - uint16x4_t _fb1 = float2bfloat(_f1); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - p0[0] = vget_lane_u16(_fb0, 0); - p0[out_hstep] = vget_lane_u16(_fb0, 1); - p0[out_hstep * 2] = vget_lane_u16(_fb0, 2); - p0[out_hstep * 3] = vget_lane_u16(_fb0, 3); - p0[out_hstep * 4] = vget_lane_u16(_fb1, 0); - p0[out_hstep * 5] = vget_lane_u16(_fb1, 1); - p0[out_hstep * 6] = vget_lane_u16(_fb1, 2); - p0[out_hstep * 7] = vget_lane_u16(_fb1, 3); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); - pp += 8; + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + p0 += 4; + } + else // if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); p0++; } + + pp += 8; } } for (; ii + 3 < max_ii; ii += 4) @@ -6130,186 +5214,143 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 #else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 
a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = 
vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; + if (c_elempack == 1) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + pC += 8; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - uint16x8_t _c01; - uint16x8_t _c23; - uint16x8_t _c45; - uint16x8_t _c67; - if (c_elempack == 1) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + c_hstep); - _c45 = vld1q_u16(pC + c_hstep * 2); - _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); - pC += 8; - } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + 8); - _c45 = vld1q_u16(pC + 16); - _c67 = vld1q_u16(pC + 24); - pC += 32; - } - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; } - if 
(broadcast_type_C == 4) + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) { - uint16x8_t _c = vld1q_u16(pC); - float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); - } - _c0 = vdupq_laneq_f32(_cc0, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -6318,921 +5359,752 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 8; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); - vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); - - pp += 32; + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = 
vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + vst1q_u16(p0 + 16, vcombine_u16(_bf4, _bf5)); + vst1q_u16(p0 + 24, vcombine_u16(_bf6, _bf7)); p0 += 32; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + else // if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); + vst1q_u16(p0, vcombine_u16(_bf0, _bf4)); + vst1q_u16(p0 + out_hstep, vcombine_u16(_bf1, _bf5)); + vst1q_u16(p0 + out_hstep * 2, vcombine_u16(_bf2, _bf6)); + vst1q_u16(p0 + out_hstep * 3, vcombine_u16(_bf3, _bf7)); + p0 += 8; + } + + pp += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 #else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = 
vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep * 1); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + pC += 4; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep * 1); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - _c2 = bfloat2float(_cc2); - _c3 = bfloat2float(_cc3); - pC += 4; - } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - pC += 16; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _c = bfloat2float(vld1_u16(pC)); - _c = vmulq_n_f32(_c, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - float32x4_t _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - float32x4_t _c3 = vdupq_laneq_f32(_c, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - 
float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; } + } - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - pp += 16; + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); p0 += 16; } - for (; jj + 1 < max_jj; jj += 2) + else // if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + p0 += 4; + } + + pp += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 #else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - { - _sum1 = vrev64q_s32(_sum1); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - } + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c; + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _c = 
uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + pC += 2; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - uint16x8_t _c; - if (c_elempack == 1) - { - _c = uint16x8_t(); - _c = vsetq_lane_u16(pC[0], _c, 0); - _c = vsetq_lane_u16(pC[c_hstep], _c, 1); - _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); - _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); - _c = vsetq_lane_u16(pC[1], _c, 4); - _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); - _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); - _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - pC += 2; - } - else // if (c_elempack == 4) - { - _c = vld1q_u16(pC); - pC += 8; - } - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } + _c = vld1q_u16(pC); + pC += 8; } - if (broadcast_type_C == 4) + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 2; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; } + } - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - pp += 8; + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); p0 += 8; } - for (; jj < max_jj; jj++) + else // if (out_elempack == 1) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf1, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); + p0 += 2; + } + + pp += 8; + } + for (; jj < max_jj; jj++) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + uint16x4_t _c; + if (c_elempack == 
1) { - uint16x4_t _c; - if (c_elempack == 1) - { - _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[c_hstep], _c, 1); - _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); - _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c = vld1_u16(pC); - pC += 4; - } - _c0 = bfloat2float(_c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); + _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + pC += 1; } - if (broadcast_type_C == 4) + else // if (c_elempack == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; + _c = vld1_u16(pC); + pC += 4; } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - vst1_u16(p0, float2bfloat(_f0)); + uint16x4_t _bf0 = float2bfloat(_f0); - pp += 4; + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); p0 += 4; } + else // if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0++; + } + + pp += 4; } - if (out_elempack == 1) + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; + + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) + if (broadcast_type_C == 0) + { + c0 = bfloat16_to_float32(pC[0]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const unsigned short*)C + i + ii; + c0 = bfloat16_to_float32(pC[0]) * beta; + c1 = bfloat16_to_float32(pC[1]) * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); +#endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const unsigned short*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + pC = (const unsigned short*)C + j; + } + } -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 
0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } -#else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = 
bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3) + else { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - float32x4_t _c4; - float32x4_t _c5; - float32x4_t _c6; - float32x4_t _c7; - if (c_elempack == 1) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - _c4 = bfloat2float(vget_low_u16(_c45)); - _c5 = bfloat2float(vget_high_u16(_c45)); - _c6 = bfloat2float(vget_low_u16(_c67)); - _c7 = bfloat2float(vget_high_u16(_c67)); - pC += 8; - } - else // if (c_elempack == 4) - { - uint16x8x4_t _cc = vld4q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_cc.val[0])); - _c1 = bfloat2float(vget_high_u16(_cc.val[0])); - _c2 = bfloat2float(vget_low_u16(_cc.val[1])); - _c3 = bfloat2float(vget_high_u16(_cc.val[1])); - _c4 = bfloat2float(vget_low_u16(_cc.val[2])); - _c5 = bfloat2float(vget_high_u16(_cc.val[2])); - _c6 = bfloat2float(vget_low_u16(_cc.val[3])); - _c7 = bfloat2float(vget_high_u16(_cc.val[3])); - pC += 32; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - if (broadcast_type_C == 4) + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) { - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = 
vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); - vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - pp += 32; - p0 += 8; - } + pp += 16; + p0 += 8; + } #endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum1 = vextq_s32(_sum1, _sum1, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); 
- float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); -#endif - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); - } - if (broadcast_type_C == 3) - { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - pC += 4; - } - else // if (c_elempack == 4) - { - uint16x4x4_t _c = vld4_u16(pC); - _c0 = bfloat2float(_c.val[0]); - _c1 = bfloat2float(_c.val[1]); - _c2 = bfloat2float(_c.val[2]); - _c3 = bfloat2float(_c.val[3]); - pC += 16; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - _c0 = bfloat2float(vld1_u16(pC)); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 4; } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; } - - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); - vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); - - pp += 16; - p0 += 4; } - for (; jj + 1 < max_jj; jj += 2) + + if (alpha != 1.f) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } -#if 
__ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - { - int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); - _sum0 = _sum01.val[0]; - _sum1 = _sum01.val[1]; - } -#else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - { - int32x4_t _t0 = vuzpq_s32(_sum0, _sum1).val[0]; - int32x4_t _t1 = vuzpq_s32(_sum1, _sum0).val[1]; - int32x4x2_t _t3 = vuzpq_s32(_t0, _t1); - _sum0 = _t3.val[0]; - _sum1 = _t3.val[1]; - } -#endif // __ARM_FEATURE_DOTPROD + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); - float32x4x2_t _descale01 = vzipq_f32(_descale, _descale); + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - } - if (broadcast_type_C == 3) - { - float32x4_t _c1; - if (c_elempack == 1) - { - uint16x8_t _c = uint16x8_t(); - _c = vsetq_lane_u16(pC[0], _c, 0); - _c = vsetq_lane_u16(pC[1], _c, 1); - _c = vsetq_lane_u16(pC[c_hstep], _c, 2); - _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 3); - _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 4); - _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 5); - _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 6); - _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - _c0 = bfloat2float(vget_low_u16(_c)); - _c1 = bfloat2float(vget_high_u16(_c)); - pC += 2; - } - else // if (c_elempack == 4) - { - uint16x8_t _c = vld1q_u16(pC); - uint16x4x2_t _c01 = vzip_u16(vget_low_u16(_c), vget_high_u16(_c)); - _c0 = bfloat2float(_c01.val[0]); - _c1 = bfloat2float(_c01.val[1]); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - } - if (broadcast_type_C == 4) - { - uint16x4_t _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[1], _c, 1); - _c = vset_lane_u16(pC[0], _c, 2); - _c = vset_lane_u16(pC[1], _c, 3); - _c0 = bfloat2float(_c); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 2; - } + _f0 = vaddq_f32(_f0, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[c_hstep], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 
0); + _c = vset_lane_u16(pC[1], _c, 1); + _c = vset_lane_u16(pC[0], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; } + } - uint16x4_t _fb0 = float2bfloat(_f0); - uint16x4_t _fb1 = float2bfloat(_f1); + _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = vget_lane_u16(_fb0, 0); - p0[1] = vget_lane_u16(_fb0, 1); - p0[out_hstep] = vget_lane_u16(_fb0, 2); - p0[out_hstep + 1] = vget_lane_u16(_fb0, 3); - p0[out_hstep * 2] = vget_lane_u16(_fb1, 0); - p0[out_hstep * 2 + 1] = vget_lane_u16(_fb1, 1); - p0[out_hstep * 3] = vget_lane_u16(_fb1, 2); - p0[out_hstep * 3 + 1] = vget_lane_u16(_fb1, 3); + uint16x4_t _bf0 = float2bfloat(_f0); - pp += 8; - p0 += 2; - } - for (; jj < max_jj; jj++) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); + + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) - { - uint16x4_t _c; - if (c_elempack == 1) - { - _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[c_hstep], _c, 1); - _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); - _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c = vld1_u16(pC); - pC += 4; - } - _c0 = bfloat2float(_c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; - } + f0 += c0; + f1 += c0; } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[c_hstep]) * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[0]) * beta; + pC += 1; + } + } - _f0 = vmulq_n_f32(_f0, alpha); - - uint16x4_t _fb0 = float2bfloat(_f0); + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; + } - p0[0] = vget_lane_u16(_fb0, 0); - p0[out_hstep] = vget_lane_u16(_fb0, 1); - p0[out_hstep * 2] = vget_lane_u16(_fb0, 2); - p0[out_hstep * 3] = vget_lane_u16(_fb0, 3); + p0[0] = float32_to_bfloat16(f0); + p0[out_hstep] = float32_to_bfloat16(f1); - pp += 4; - p0++; - } + pp += 2; + p0++; } } -#endif // __ARM_NEON - for (; ii + 1 < max_ii; ii += 2) + for (; ii < max_ii; ii += 1) { // out_elempack == 1 unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale = descales[ii]; #if __ARM_NEON - float32x2_t _descale = vld1_f32((const float*)descales + ii); + float32x4_t _descale = vdupq_n_f32(descale); #endif float c0; - float c1; #if __ARM_NEON float32x4_t _c0; - float32x4_t _c1; #endif if (pC) { @@ -7247,10 +6119,8 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { pC = (const unsigned short*)C + i + ii; c0 = bfloat16_to_float32(pC[0]) * beta; - c1 = bfloat16_to_float32(pC[1]) * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); - _c1 = vdupq_n_f32(c1); #endif } if 
(broadcast_type_C == 3) @@ -7264,516 +6134,212 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } } - // if (out_elempack == 1) - { - int jj = 0; + int jj = 0; #if __ARM_NEON -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 4) + else { - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } + pC += 16; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); 
- _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - pp += 16; - p0 += 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + pp += 16; + p0 += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = bfloat2float(vld1_u16(pC)); - float32x4_t _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; - } - if (broadcast_type_C == 4) + else { - _c0 = bfloat2float(vld1_u16(pC)); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 8; } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - - pp += 8; - p0 += 4; } - for (; jj + 1 < max_jj; jj += 2) + + if (alpha != 1.f) { - int32x4_t _sum0 = vld1q_s32(pp); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - float32x2x2_t _descale01 = vzip_f32(_descale, _descale); - float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || 
broadcast_type_C == 2) - { - float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); - _f0 = vaddq_f32(_f0, _c0011); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - uint16x4_t _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[1], _c, 1); - _c = vset_lane_u16(pC[c_hstep], _c, 2); - _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); - _c0 = bfloat2float(_c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } - if (broadcast_type_C == 4) - { - uint16x4_t _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[1], _c, 1); - _c = vset_lane_u16(pC[0], _c, 2); - _c = vset_lane_u16(pC[1], _c, 3); - _c0 = bfloat2float(_c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - uint16x4_t _bf0 = float2bfloat(_f0); + vst1_u16(p0, float2bfloat(_f0)); - p0[0] = vget_lane_u16(_bf0, 0); - p0[1] = vget_lane_u16(_bf0, 1); - p0[out_hstep] = vget_lane_u16(_bf0, 2); - p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - pp += 4; - p0 += 2; + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + float32x2_t _cc = float32x2_t(); + _cc = vset_lane_f32(bfloat16_to_float32(pC[0]), _cc, 0); + _cc = vset_lane_f32(bfloat16_to_float32(pC[1]), _cc, 1); + _f0 = vmla_n_f32(_f0, _cc, beta); + pC += 2; + } } + + _f0 = vmul_n_f32(_f0, alpha); + + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); + p0[1] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); + + pp += 2; + p0 += 2; + } #endif // __ARM_NEON - for (; jj < max_jj; jj++) - { - float f0 = pp[0] * descale0; - float f1 = pp[1] * descale1; + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - f0 += c0; - f1 += c0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - f1 += c1; - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]) * beta; - f1 += bfloat16_to_float32(pC[c_hstep]) * beta; - pC += 1; - } - if (broadcast_type_C == 4) - { - f0 += bfloat16_to_float32(pC[0]) * beta; - f1 += bfloat16_to_float32(pC[0]) * beta; - pC += 1; - } + f0 += c0; } - - if (alpha != 1.f) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - f0 *= alpha; - f1 *= alpha; + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + pC += 1; } + } - p0[0] = float32_to_bfloat16(f0); - p0[out_hstep] = float32_to_bfloat16(f1); + f0 *= alpha; - pp += 2; - p0++; - } + p0[0] = float32_to_bfloat16(f0); + + pp += 1; + p0++; } } - for (; ii < max_ii; ii += 1) - { - // out_elempack == 1 - unsigned short* p0 = (unsigned short*)top_blob + (i + ii) * out_hstep + j; +} - const float descale = descales[ii]; -#if __ARM_NEON - float32x4_t _descale = vdupq_n_f32(descale); -#endif - - float c0; -#if __ARM_NEON - float32x4_t _c0; -#endif - if (pC) - { - if (broadcast_type_C == 0) - { - c0 = bfloat16_to_float32(pC[0]) * beta; -#if __ARM_NEON - _c0 = vdupq_n_f32(c0); 
-#endif - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - pC = (const unsigned short*)C + i + ii; - c0 = bfloat16_to_float32(pC[0]) * beta; -#if __ARM_NEON - _c0 = vdupq_n_f32(c0); -#endif - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - pC = (const unsigned short*)C + (i + ii) * c_hstep + j; - } - if (broadcast_type_C == 4) - { - pC = (const unsigned short*)C + j; - } - } - - // if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON - for (; jj + 15 < max_jj; jj += 16) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - - pp += 16; - p0 += 16; - } - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - uint16x8_t _c01 = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - - pp += 8; - p0 += 8; - } - for (; jj + 3 < max_jj; jj += 4) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - 
if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; - } - } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1_u16(p0, float2bfloat(_f0)); - - pp += 4; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vadd_f32(_f0, vget_low_f32(_c0)); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - float32x2_t _cc = float32x2_t(); - _cc = vset_lane_f32(bfloat16_to_float32(pC[0]), _cc, 0); - _cc = vset_lane_f32(bfloat16_to_float32(pC[1]), _cc, 1); - _f0 = vmla_n_f32(_f0, _cc, beta); - pC += 2; - } - } - - _f0 = vmul_n_f32(_f0, alpha); - - p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); - p0[1] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); - - pp += 2; - p0 += 2; - } -#endif // __ARM_NEON - for (; jj < max_jj; jj++) - { - float f0 = pp[0] * descale; - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]) * beta; - pC += 1; - } - } - - f0 *= alpha; - - p0[0] = float32_to_bfloat16(f0); - - pp += 1; - p0++; - } - } - } -} - -static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) -{ -#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 - if (ncnn::cpu_support_arm_asimddp()) - { - transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); - return; - } +static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD && !__ARM_FEATURE_MATMUL_INT8 + if (ncnn::cpu_support_arm_asimddp()) + { + transpose_unpack_output_tile_int32_to_bf16_asimddp(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta); + return; + } #endif const int out_elempack = top_blob.elempack; @@ -7823,1637 +6389,929 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); 
+ int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); #if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); - int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); - int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); - _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); - } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = 
vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - } + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t 
_t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); - float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); - float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); - float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); - float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); - float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); - float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); - float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); - float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); - - if (pC) - { - if (broadcast_type_C == 0) + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), 
_descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - _f8 = vaddq_f32(_f8, _cc4); - _f9 = vaddq_f32(_f9, _cc4); - _fa = vaddq_f32(_fa, _cc5); - _fb = vaddq_f32(_fb, _cc5); - _fc = vaddq_f32(_fc, _cc6); - _fd = vaddq_f32(_fd, _cc6); - _fe = vaddq_f32(_fe, _cc7); - _ff = vaddq_f32(_ff, _cc7); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = 
vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + 4)); - float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep)); - float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep + 4)); - float32x4_t _c4 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - float32x4_t _c5 = bfloat2float(vld1_u16(pC + c_hstep * 2 + 4)); - float32x4_t _c6 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - float32x4_t _c7 = bfloat2float(vld1_u16(pC + c_hstep * 3 + 4)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4 + 4)); - _c2 = bfloat2float(vld1_u16(pC + c_hstep * 5)); - _c3 = bfloat2float(vld1_u16(pC + c_hstep * 5 + 4)); - _c4 = bfloat2float(vld1_u16(pC + c_hstep * 6)); - _c5 = bfloat2float(vld1_u16(pC + c_hstep * 6 + 4)); - _c6 = bfloat2float(vld1_u16(pC + c_hstep * 7)); - _c7 = bfloat2float(vld1_u16(pC + c_hstep * 7 + 4)); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - else // if (c_elempack == 4) + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) { - uint16x8x4_t _cc0 = vld4q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); - _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); - float32x4_t _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); - float32x4_t _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); - float32x4_t _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); - float32x4_t _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); - float32x4_t _c6 
= bfloat2float(vget_low_u16(_cc0.val[3])); - float32x4_t _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _cc0 = vld4q_u16(pC + c_hstep * 4); - _c0 = bfloat2float(vget_low_u16(_cc0.val[0])); - _c1 = bfloat2float(vget_high_u16(_cc0.val[0])); - _c2 = bfloat2float(vget_low_u16(_cc0.val[1])); - _c3 = bfloat2float(vget_high_u16(_cc0.val[1])); - _c4 = bfloat2float(vget_low_u16(_cc0.val[2])); - _c5 = bfloat2float(vget_high_u16(_cc0.val[2])); - _c6 = bfloat2float(vget_low_u16(_cc0.val[3])); - _c7 = bfloat2float(vget_high_u16(_cc0.val[3])); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); } - } - if (broadcast_type_C == 4) - { - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - _c1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) + else { float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c1); pC += 8; } - } + else // if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = 
bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + _c4 = bfloat2float(vget_low_u16(_c45)); + _c5 = bfloat2float(vget_high_u16(_c45)); + _c6 = bfloat2float(vget_low_u16(_c67)); + _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } + pC += 32; + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + uint16x8_t 
_bf0 = vcombine_u16(float2bfloat(_f0), float2bfloat(_f8)); + uint16x8_t _bf1 = vcombine_u16(float2bfloat(_f1), float2bfloat(_f9)); + uint16x8_t _bf2 = vcombine_u16(float2bfloat(_f2), float2bfloat(_fa)); + uint16x8_t _bf3 = vcombine_u16(float2bfloat(_f3), float2bfloat(_fb)); + uint16x8_t _bf4 = vcombine_u16(float2bfloat(_f4), float2bfloat(_fc)); + uint16x8_t _bf5 = vcombine_u16(float2bfloat(_f5), float2bfloat(_fd)); + uint16x8_t _bf6 = vcombine_u16(float2bfloat(_f6), float2bfloat(_fe)); + uint16x8_t _bf7 = vcombine_u16(float2bfloat(_f7), float2bfloat(_ff)); + + if (out_elempack == 4) + { + uint16x8x4_t _bfa; + uint16x8x4_t _bfb; + _bfa.val[0] = _bf0; + _bfa.val[1] = _bf1; + _bfa.val[2] = _bf2; + _bfa.val[3] = _bf3; + _bfb.val[0] = _bf4; + _bfb.val[1] = _bf5; + _bfb.val[2] = _bf6; + _bfb.val[3] = _bf7; + vst4q_u16(p0, _bfa); + vst4q_u16(p0 + out_hstep * 4, _bfb); + } + else // if (out_elempack == 1) + { + vst1q_u16(p0, _bf0); + vst1q_u16(p0 + out_hstep, _bf1); + vst1q_u16(p0 + out_hstep * 2, _bf2); + vst1q_u16(p0 + out_hstep * 3, _bf3); + vst1q_u16(p0 + out_hstep * 4, _bf4); + vst1q_u16(p0 + out_hstep * 5, _bf5); + vst1q_u16(p0 + out_hstep * 6, _bf6); + vst1q_u16(p0 + out_hstep * 7, _bf7); + } + + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6))); - vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f8), float2bfloat(_fa))); - vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_fc), float2bfloat(_fe))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); - vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_f5), float2bfloat(_f7))); - vst1q_u16(p0 + out_hstep * 4 + 16, vcombine_u16(float2bfloat(_f9), float2bfloat(_fb))); - vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(float2bfloat(_fd), float2bfloat(_ff))); - pp += 64; - p0 += out_hstep * 8; +#else + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, 
_sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); +#endif // __ARM_FEATURE_DOTPROD -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 
c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); - float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); - float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); - float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); - float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); -#endif - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - 
float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); -#if __aarch64__ - _cc0 = vdupq_laneq_f32(_c1, 0); - _cc1 = vdupq_laneq_f32(_c1, 1); - _cc2 = vdupq_laneq_f32(_c1, 2); - _cc3 = vdupq_laneq_f32(_c1, 3); -#else - _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); - _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); - _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); - _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif - _f4 = vaddq_f32(_f4, _cc0); - _f5 = vaddq_f32(_f5, _cc1); - _f6 = vaddq_f32(_f6, _cc2); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - if (c_elempack == 1) + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); + if (beta == 1.f) { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - float32x4_t _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - float32x4_t _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 5)); - _c2 = bfloat2float(vld1_u16(pC + c_hstep * 6)); - _c3 = bfloat2float(vld1_u16(pC + c_hstep * 7)); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - else // if (c_elempack == 4) + else { - uint16x4x4_t _cc0 = vld4_u16(pC); - _c0 = bfloat2float(_cc0.val[0]); - _c1 = bfloat2float(_cc0.val[1]); - float32x4_t _c2 = bfloat2float(_cc0.val[2]); - float32x4_t _c3 = bfloat2float(_cc0.val[3]); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _cc0 = vld4_u16(pC + c_hstep * 4); - _c0 = bfloat2float(_cc0.val[0]); - _c1 = bfloat2float(_cc0.val[1]); - _c2 = bfloat2float(_cc0.val[2]); - _c3 = bfloat2float(_cc0.val[3]); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = 
vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); pC += 4; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - vst1q_u16(p0 + 16, vcombine_u16(float2bfloat(_f4), float2bfloat(_f5))); - vst1q_u16(p0 + 24, vcombine_u16(float2bfloat(_f6), float2bfloat(_f7))); - pp += 32; - p0 += out_hstep * 4; - } - } - if (out_elempack == 1) - { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); - float32x4_t 
_f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); -#else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - { - _sum8 = vrev64q_s32(_sum8); - _sum9 = vrev64q_s32(_sum9); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sumc = vrev64q_s32(_sumc); - _sumd = vrev64q_s32(_sumd); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - _sum8 = vextq_s32(_sum8, _sum8, 2); - _sum9 = vextq_s32(_sum9, _sum9, 2); - _suma = vextq_s32(_suma, _suma, 2); - _sumb = vextq_s32(_sumb, _sumb, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - _sum9 = vrev64q_s32(_sum9); - _sumb = vrev64q_s32(_sumb); - _sumd = vrev64q_s32(_sumd); - _sumf = vrev64q_s32(_sumf); - } - - float32x4_t _f0 = 
vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); -#endif // __ARM_FEATURE_DOTPROD - - if (pC) - { - if (broadcast_type_C == 0) + else // if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c1); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c1); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c1); - _ff = vaddq_f32(_ff, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = 
vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 5); - _c45 = vld1q_u16(pC + c_hstep * 6); - _c67 = vld1q_u16(pC + c_hstep * 7); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - _c4 = bfloat2float(vget_low_u16(_c45)); - _c5 = bfloat2float(vget_high_u16(_c45)); - _c6 = bfloat2float(vget_low_u16(_c67)); - _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - else // if (c_elempack == 4) + else { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + 16); - uint16x8_t _c67 = vld1q_u16(pC + 24); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 4 + 8); - _c45 = vld1q_u16(pC + c_hstep * 4 + 16); - _c67 = vld1q_u16(pC + c_hstep * 4 + 24); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - _c4 = bfloat2float(vget_low_u16(_c45)); - _c5 = bfloat2float(vget_high_u16(_c45)); - _c6 = bfloat2float(vget_low_u16(_c67)); - _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = 
vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - uint16x8_t _c = vld1q_u16(pC); - float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else { float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } - _c0 = vdupq_laneq_f32(_cc0, 0); - _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - pC += 8; + pC += 16; } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 
= vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f8))); - vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f9))); - vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), float2bfloat(_fa))); - vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f3), float2bfloat(_fb))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f4), float2bfloat(_fc))); - vst1q_u16(p0 + out_hstep * 5, vcombine_u16(float2bfloat(_f5), float2bfloat(_fd))); - vst1q_u16(p0 + out_hstep * 6, vcombine_u16(float2bfloat(_f6), float2bfloat(_fe))); - vst1q_u16(p0 + out_hstep * 7, vcombine_u16(float2bfloat(_f7), float2bfloat(_ff))); + uint16x8_t _bf0 = vcombine_u16(float2bfloat(_f0), float2bfloat(_f4)); + uint16x8_t _bf1 = vcombine_u16(float2bfloat(_f1), float2bfloat(_f5)); + uint16x8_t _bf2 = vcombine_u16(float2bfloat(_f2), float2bfloat(_f6)); + uint16x8_t _bf3 = vcombine_u16(float2bfloat(_f3), float2bfloat(_f7)); - pp += 64; - p0 += out_hstep * 8; + if (out_elempack == 4) + { + uint16x8x4_t _bf; + _bf.val[0] = _bf0; + _bf.val[1] = _bf1; + _bf.val[2] = _bf2; + _bf.val[3] = _bf3; + vst4q_u16(p0, _bf); } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + else // if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + vst1q_u16(p0, _bf0); + vst1q_u16(p0 + out_hstep, _bf1); + vst1q_u16(p0 + out_hstep * 2, _bf2); + vst1q_u16(p0 + out_hstep * 3, _bf3); + } -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 + pp += 32; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 #else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), 
vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 1) + { + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + + uint16x8_t _c23 = uint16x8_t(); + _c23 = vsetq_lane_u16(pC[1], _c23, 0); + _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); + _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); + _c23 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c23, 3); + _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); + _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); + _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); + _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = 
vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c1); - _f7 = vaddq_f32(_f7, _c1); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_low_u16(_c23)); + _c2 = bfloat2float(vget_high_u16(_c01)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 2; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - if (c_elempack == 1) - { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - float32x4_t _c2 = bfloat2float(_cc2); - float32x4_t _c3 = bfloat2float(_cc3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _cc0 = vld1_u16(pC + c_hstep * 4); - _cc1 = vld1_u16(pC + c_hstep * 5); - _cc2 = vld1_u16(pC + c_hstep * 6); - _cc3 = vld1_u16(pC + c_hstep * 7); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - _c2 = bfloat2float(_cc2); - _c3 = bfloat2float(_cc3); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 4 + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 16; - } + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 8; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _c = bfloat2float(vld1_u16(pC)); - _c = vmulq_n_f32(_c, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - 
float32x4_t _c3 = vdupq_laneq_f32(_c, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f4))); - vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f5))); - vst1q_u16(p0 + out_hstep * 2, vcombine_u16(float2bfloat(_f2), float2bfloat(_f6))); - vst1q_u16(p0 + out_hstep * 3, vcombine_u16(float2bfloat(_f3), float2bfloat(_f7))); - - pp += 32; - p0 += out_hstep * 4; } - for (; jj + 1 < max_jj; jj += 2) + + if (alpha != 1.f) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 -#else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } -#endif // __ARM_FEATURE_DOTPROD + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); + vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) 
{ - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) - { - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - uint16x8_t _c01 = uint16x8_t(); - _c01 = vsetq_lane_u16(pC[0], _c01, 0); - _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); - _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); - _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); - _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); - _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); - _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); - _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); - - uint16x8_t _c23 = uint16x8_t(); - _c23 = vsetq_lane_u16(pC[1], _c23, 0); - _c23 = vsetq_lane_u16(pC[c_hstep + 1], _c23, 1); - _c23 = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c23, 2); - _c23 = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c23, 3); - _c23 = vsetq_lane_u16(pC[c_hstep * 4 + 1], _c23, 4); - _c23 = vsetq_lane_u16(pC[c_hstep * 5 + 1], _c23, 5); - _c23 = vsetq_lane_u16(pC[c_hstep * 6 + 1], _c23, 6); - _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); - - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_low_u16(_c23)); - _c2 = bfloat2float(vget_high_u16(_c01)); - _c3 = bfloat2float(vget_high_u16(_c23)); - pC += 2; - } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 2; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); - vst1q_u16(p0 + out_hstep, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); - - pp += 16; - p0 += out_hstep * 2; - } - for (; jj < max_jj; jj += 1) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); + uint16x8_t _c01 = uint16x8_t(); + _c01 = vsetq_lane_u16(pC[0], _c01, 0); + _c01 = 
vsetq_lane_u16(pC[c_hstep], _c01, 1); + _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); + _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); + _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); + _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); + _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); + _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + pC += 1; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - if (c_elempack == 1) - { - uint16x8_t _c01 = uint16x8_t(); - _c01 = vsetq_lane_u16(pC[0], _c01, 0); - _c01 = vsetq_lane_u16(pC[c_hstep], _c01, 1); - _c01 = vsetq_lane_u16(pC[c_hstep * 2], _c01, 2); - _c01 = vsetq_lane_u16(pC[c_hstep * 3], _c01, 3); - _c01 = vsetq_lane_u16(pC[c_hstep * 4], _c01, 4); - _c01 = vsetq_lane_u16(pC[c_hstep * 5], _c01, 5); - _c01 = vsetq_lane_u16(pC[c_hstep * 6], _c01, 6); - _c01 = vsetq_lane_u16(pC[c_hstep * 7], _c01, 7); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - pC += 4; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; + _f1 = vaddq_f32(_f1, _c1); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; } + } - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - pp += 8; - p0 += out_hstep; + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); } + + vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); + pp += 8; + p0 += out_hstep; } } for (; ii + 3 < max_ii; ii += 4) @@ -9485,913 +7343,558 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); #if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 
b7 c7 d7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - } + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 +#else + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), 
vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + uint16x8_t _c01; + uint16x8_t _c23; + uint16x8_t _c45; + uint16x8_t _c67; + if (c_elempack == 1) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep); + _c45 = vld1q_u16(pC + c_hstep * 2); + _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + pC += 8; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - float32x4_t _c4; - float32x4_t _c5; - float32x4_t _c6; - float32x4_t _c7; - if (c_elempack == 1) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - 
uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - _c4 = bfloat2float(vget_low_u16(_c45)); - _c5 = bfloat2float(vget_high_u16(_c45)); - _c6 = bfloat2float(vget_low_u16(_c67)); - _c7 = bfloat2float(vget_high_u16(_c67)); - pC += 8; - } - else // if (c_elempack == 4) - { - uint16x8x4_t _c = vld4q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c.val[0])); - _c1 = bfloat2float(vget_high_u16(_c.val[0])); - _c2 = bfloat2float(vget_low_u16(_c.val[1])); - _c3 = bfloat2float(vget_high_u16(_c.val[1])); - _c4 = bfloat2float(vget_low_u16(_c.val[2])); - _c5 = bfloat2float(vget_high_u16(_c.val[2])); - _c6 = bfloat2float(vget_low_u16(_c.val[3])); - _c7 = bfloat2float(vget_high_u16(_c.val[3])); - pC += 32; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; } - if (broadcast_type_C == 4) + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); + float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); + float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); + float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); + if (beta == 1.f) { - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - pC += 8; + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f4), float2bfloat(_f6))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); - vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(float2bfloat(_f5), float2bfloat(_f7))); - - pp += 32; - p0 += out_hstep * 8; - } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = 
vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _cc1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + uint16x4_t _bf4 = float2bfloat(_f4); + uint16x4_t _bf5 = float2bfloat(_f5); + uint16x4_t _bf6 = float2bfloat(_f6); + uint16x4_t _bf7 = float2bfloat(_f7); + + if (out_elempack == 4) + { + uint16x4x4_t _bfa; + uint16x4x4_t _bfb; + _bfa.val[0] = _bf0; + _bfa.val[1] = _bf1; + _bfa.val[2] = _bf2; + _bfa.val[3] = _bf3; + _bfb.val[0] = _bf4; + _bfb.val[1] = _bf5; + _bfb.val[2] = _bf6; + _bfb.val[3] = _bf7; + vst4_u16(p0, _bfa); + vst4_u16(p0 + out_hstep * 4, _bfb); + } + else // if (out_elempack == 1) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + vst1_u16(p0 + out_hstep * 4, _bf4); + vst1_u16(p0 + out_hstep * 5, _bf5); + vst1_u16(p0 + out_hstep * 6, _bf6); + vst1_u16(p0 + out_hstep * 7, _bf7); + } + + pp += 32; + p0 += out_hstep * 8; + } #endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - { - int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); - } + // from/to + // a0 b0 c0 
d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 #else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - { - _sum1 = vextq_s32(_sum1, _sum1, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); -#endif - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); - } - if (broadcast_type_C == 3) - { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - _c2 = bfloat2float(vld1_u16(pC + c_hstep * 2)); - _c3 = bfloat2float(vld1_u16(pC + c_hstep * 3)); - pC += 4; - } - else // if (c_elempack == 4) - { - uint16x4x4_t _c = vld4_u16(pC); - _c0 = bfloat2float(_c.val[0]); - _c1 = bfloat2float(_c.val[1]); - _c2 = bfloat2float(_c.val[2]); - _c3 = bfloat2float(_c.val[3]); - pC += 16; - } - if 
(beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - } - if (broadcast_type_C == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 4; - } - } + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (alpha != 1.f) + if (pC) + { + if (broadcast_type_C == 0) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f1))); - vst1q_u16(p0 + 8, vcombine_u16(float2bfloat(_f2), float2bfloat(_f3))); - - pp += 16; - p0 += out_hstep * 4; - } - } - if (out_elempack == 1) - { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 -#else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = 
vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); + pC += 4; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - uint16x8_t _c01; - uint16x8_t _c23; - uint16x8_t _c45; - uint16x8_t _c67; - if (c_elempack == 1) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + c_hstep); - _c45 = vld1q_u16(pC + c_hstep * 2); - _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); - pC += 8; - } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + 8); - _c45 = vld1q_u16(pC + 16); - _c67 = vld1q_u16(pC + 24); - pC += 32; - } + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - float32x4_t _c4 = bfloat2float(vget_low_u16(_c45)); - float32x4_t _c5 = bfloat2float(vget_high_u16(_c45)); - float32x4_t _c6 = bfloat2float(vget_low_u16(_c67)); - float32x4_t _c7 = bfloat2float(vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - uint16x8_t _c = vld1q_u16(pC); - float32x4_t _cc0 = bfloat2float(vget_low_u16(_c)); - float32x4_t 
_cc1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); - } - _c0 = vdupq_laneq_f32(_cc0, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - pC += 8; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); - vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); - vst1_u16(p0 + out_hstep * 4, float2bfloat(_f4)); - vst1_u16(p0 + out_hstep * 5, float2bfloat(_f5)); - vst1_u16(p0 + out_hstep * 6, float2bfloat(_f6)); - vst1_u16(p0 + out_hstep * 7, float2bfloat(_f7)); - - pp += 32; - p0 += out_hstep * 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 -#else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); } - if (broadcast_type_C == 3) + else { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - 
transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - _c2 = bfloat2float(_cc2); - _c3 = bfloat2float(_cc3); - pC += 4; - } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - pC += 16; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - if (broadcast_type_C == 4) - { - float32x4_t _c = bfloat2float(vld1_u16(pC)); - _c = vmulq_n_f32(_c, beta); + } + if (broadcast_type_C == 4) + { + float32x4_t _c = bfloat2float(vld1_u16(pC)); + _c = vmulq_n_f32(_c, beta); #if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - float32x4_t _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - float32x4_t _c3 = vdupq_laneq_f32(_c, 3); + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); #else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); #endif - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - pC += 4; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - vst1_u16(p0 + out_hstep * 2, float2bfloat(_f2)); - vst1_u16(p0 + out_hstep * 3, float2bfloat(_f3)); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); - pp += 16; - p0 += out_hstep * 4; + if (out_elempack == 4) + { + uint16x4x4_t _bf; + _bf.val[0] = _bf0; + _bf.val[1] = _bf1; + _bf.val[2] = _bf2; + _bf.val[3] = _bf3; + vst4_u16(p0, _bf); } - for (; jj + 1 < max_jj; jj += 2) + else // if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep, _bf1); + vst1_u16(p0 + out_hstep * 2, _bf2); + vst1_u16(p0 + out_hstep * 3, _bf3); + } + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj 
+= 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 #else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - { - _sum1 = vrev64q_s32(_sum1); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - } + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + uint16x8_t _c; + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _c = uint16x8_t(); + _c = vsetq_lane_u16(pC[0], _c, 0); + _c = vsetq_lane_u16(pC[c_hstep], _c, 1); + _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); + _c = vsetq_lane_u16(pC[1], _c, 4); + _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); + _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); + _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); + pC += 2; } - if (broadcast_type_C == 3) + else // if (c_elempack == 4) { - uint16x8_t _c; - if (c_elempack == 1) - { - _c = uint16x8_t(); - _c = vsetq_lane_u16(pC[0], _c, 0); - _c = vsetq_lane_u16(pC[c_hstep], _c, 1); - _c = vsetq_lane_u16(pC[c_hstep * 2], _c, 2); - _c = vsetq_lane_u16(pC[c_hstep * 3], _c, 3); - _c = vsetq_lane_u16(pC[1], _c, 4); - _c = vsetq_lane_u16(pC[c_hstep + 1], _c, 5); - _c = vsetq_lane_u16(pC[c_hstep * 2 + 1], _c, 6); - _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); - pC += 2; - } - else // if (c_elempack == 4) - { - _c = vld1q_u16(pC); - pC += 8; - } - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } + _c = vld1q_u16(pC); + pC += 8; } - if (broadcast_type_C == 4) + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 2; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - 
float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + float32x4_t _c1 = vdupq_n_f32(bfloat16_to_float32(pC[1]) * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; } - - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - - pp += 8; - p0 += out_hstep * 2; } - for (; jj < max_jj; jj += 1) + + if (alpha != 1.f) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1_u16(p0, float2bfloat(_f0)); + vst1_u16(p0 + out_hstep, float2bfloat(_f1)); - if (pC) + pp += 8; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + uint16x4_t _c; + if (c_elempack == 1) { - uint16x4_t _c; - if (c_elempack == 1) - { - _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[c_hstep], _c, 1); - _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); - _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c = vld1_u16(pC); - pC += 4; - } - _c0 = bfloat2float(_c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); + _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[c_hstep * 2], _c, 2); + _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); + pC += 1; } - if (broadcast_type_C == 4) + else // if (c_elempack == 4) { - _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; + _c = vld1_u16(pC); + pC += 4; } + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(bfloat16_to_float32(pC[0]) * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; + } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - vst1_u16(p0, float2bfloat(_f0)); - pp += 4; - p0 += out_hstep; - } + vst1_u16(p0, float2bfloat(_f0)); + pp += 4; + p0 += out_hstep; } } #endif // __ARM_NEON @@ -10441,495 +7944,306 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + int jj = 0; #if __ARM_NEON - if (out_elempack == 4) - { - int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 
0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; - } - if (broadcast_type_C == 4) - { - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - _c1 = bfloat2float(vget_high_u16(_c)); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 8; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); } - - vst1q_u16(p0, vcombine_u16(float2bfloat(_f0), float2bfloat(_f2))); - vst1q_u16(p0 + out_hstep * 4, vcombine_u16(float2bfloat(_f1), float2bfloat(_f3))); - - pp += 16; - p0 += out_hstep * 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - // a0 a1 a2 a3 - // b0 b1 b2 b3 - - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3) + else { - // c_elempack == 1 - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, 
_c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - if (broadcast_type_C == 4) + pC += 8; + } + if (broadcast_type_C == 4) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + _c1 = bfloat2float(vget_high_u16(_c)); + if (beta != 1.f) { - _c0 = bfloat2float(vld1_u16(pC)); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + 4, float2bfloat(_f1)); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); - pp += 8; - p0 += out_hstep * 4; + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf2)); + vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf1, _bf3)); } - } -#endif // __ARM_NEON - if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) + else // if (out_elempack == 1) { - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf2, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 3 + 1] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 4 + 1] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 5 + 1] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 6 + 1] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 7 + 1] = vget_lane_u16(_bf3, 3); + } - int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _sum13 = vzipq_s32(_sum1, _sum3); + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 - float32x4_t _descale = vcombine_f32(_descale01, _descale01); + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum02.val[0]), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum02.val[1]), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum13.val[0]), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum13.val[1]), _descale); + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - 
if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - _f1 = vaddq_f32(_f1, _cc); - _f2 = vaddq_f32(_f2, _cc); - _f3 = vaddq_f32(_f3, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x4x2_t _c02 = vzip_u16(vget_low_u16(_c01), vget_low_u16(_c23)); - uint16x4x2_t _c13 = vzip_u16(vget_high_u16(_c01), vget_high_u16(_c23)); - _c0 = bfloat2float(_c02.val[0]); - float32x4_t _c1 = bfloat2float(_c02.val[1]); - float32x4_t _c2 = bfloat2float(_c13.val[0]); - float32x4_t _c3 = bfloat2float(_c13.val[1]); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; - } - if (broadcast_type_C == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x4x2_t _cc0 = vzip_u16(vget_low_u16(_c01), vget_low_u16(_c01)); - uint16x4x2_t _cc1 = vzip_u16(vget_high_u16(_c01), vget_high_u16(_c01)); - _c0 = bfloat2float(_cc0.val[0]); - float32x4_t _c1 = bfloat2float(_cc0.val[1]); - float32x4_t _c2 = bfloat2float(_cc1.val[0]); - float32x4_t _c3 = bfloat2float(_cc1.val[1]); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - - uint16x4_t _bf0 = float2bfloat(_f0); - uint16x4_t _bf1 = float2bfloat(_f1); - uint16x4_t _bf2 = float2bfloat(_f2); - uint16x4_t _bf3 = float2bfloat(_f3); - - p0[0] = vget_lane_u16(_bf0, 0); - p0[1] = vget_lane_u16(_bf0, 1); - p0[out_hstep] = vget_lane_u16(_bf0, 2); - p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 2] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 3] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf2, 0); - p0[out_hstep * 4 + 1] = vget_lane_u16(_bf2, 1); - p0[out_hstep * 5] = vget_lane_u16(_bf2, 2); - p0[out_hstep * 5 + 1] = vget_lane_u16(_bf2, 3); - p0[out_hstep * 6] = vget_lane_u16(_bf3, 0); - p0[out_hstep * 6 + 1] = vget_lane_u16(_bf3, 1); - p0[out_hstep * 7] = vget_lane_u16(_bf3, 2); - p0[out_hstep * 7 + 1] = vget_lane_u16(_bf3, 3); - - pp += 16; - p0 += out_hstep * 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - // a0 a1 a2 a3 - // b0 b1 b2 b3 - - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); - - float32x4_t _descale = vcombine_f32(_descale01, _descale01); - - float32x4_t _f0 = 
vmulq_f32(vcvtq_f32_s32(_sum01.val[0]), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum01.val[1]), _descale); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + // c_elempack == 1 + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - _f1 = vaddq_f32(_f1, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4x2_t _c01 = vzip_u16(_cc0, _cc1); - _c0 = bfloat2float(_c01.val[0]); - float32x4_t _c1 = bfloat2float(_c01.val[1]); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - uint16x4_t _c = vld1_u16(pC); - uint16x4x2_t _cc = vzip_u16(_c, _c); - _c0 = bfloat2float(_cc.val[0]); - float32x4_t _c1 = bfloat2float(_cc.val[1]); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 4; } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = bfloat2float(vld1_u16(pC)); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - uint16x4_t _bf0 = float2bfloat(_f0); - uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + if (out_elempack == 4) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + } + else // if (out_elempack == 1) + { p0[0] = vget_lane_u16(_bf0, 0); - p0[1] = vget_lane_u16(_bf0, 1); - p0[out_hstep] = vget_lane_u16(_bf0, 2); - p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 2] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 3] = vget_lane_u16(_bf1, 2); + p0[1] = vget_lane_u16(_bf1, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep + 1] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 2 + 1] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); p0[out_hstep * 3 + 1] = vget_lane_u16(_bf1, 3); - - pp += 8; - p0 += out_hstep * 4; } - for (; jj + 1 < max_jj; jj += 2) - { - // a0 a1 b0 b1 - int32x2x2_t _sum0 = vld2_s32(pp); - float32x4_t _descale = vcombine_f32(_descale01, _descale01); + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); + + float32x4_t _descale = vcombine_f32(_descale01, _descale01); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) 
{ - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - uint16x4_t _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[c_hstep], _c, 1); - _c = vset_lane_u16(pC[1], _c, 2); - _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); - _c0 = bfloat2float(_c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } - if (broadcast_type_C == 4) - { - uint16x4_t _c = uint16x4_t(); - _c = vset_lane_u16(pC[0], _c, 0); - _c = vset_lane_u16(pC[0], _c, 1); - _c = vset_lane_u16(pC[1], _c, 2); - _c = vset_lane_u16(pC[1], _c, 3); - _c0 = bfloat2float(_c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[c_hstep], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[c_hstep + 1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + uint16x4_t _c = uint16x4_t(); + _c = vset_lane_u16(pC[0], _c, 0); + _c = vset_lane_u16(pC[0], _c, 1); + _c = vset_lane_u16(pC[1], _c, 2); + _c = vset_lane_u16(pC[1], _c, 3); + _c0 = bfloat2float(_c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf0 = float2bfloat(_f0); - p0[0] = vget_lane_u16(_bf0, 0); - p0[1] = vget_lane_u16(_bf0, 1); - p0[out_hstep] = vget_lane_u16(_bf0, 2); - p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); + p0[0] = vget_lane_u16(_bf0, 0); + p0[1] = vget_lane_u16(_bf0, 1); + p0[out_hstep] = vget_lane_u16(_bf0, 2); + p0[out_hstep + 1] = vget_lane_u16(_bf0, 3); - pp += 4; - p0 += out_hstep * 2; - } + pp += 4; + p0 += out_hstep * 2; + } #endif // __ARM_NEON - for (; jj < max_jj; jj += 1) - { - float f0 = pp[0] * descale0; - float f1 = pp[1] * descale1; + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - f0 += c0; - f1 += c0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - f1 += c1; - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]) * beta; - f1 += bfloat16_to_float32(pC[c_hstep]) * beta; - pC += 1; - } - if (broadcast_type_C == 4) - { - c0 = bfloat16_to_float32(pC[0]) * beta; - f0 += c0; - f1 += c0; - pC += 1; - } + f0 += c0; + f1 += c0; } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) { - f0 *= alpha; - f1 *= alpha; + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + f1 += bfloat16_to_float32(pC[c_hstep]) * beta; + pC += 1; } + if (broadcast_type_C == 4) + { + c0 = bfloat16_to_float32(pC[0]) * beta; + f0 += c0; + f1 += c0; + pC += 1; + } + } - p0[0] = float32_to_bfloat16(f0); - p0[1] = float32_to_bfloat16(f1); - pp += 2; - p0 += out_hstep; + if (alpha != 1.f) + { + f0 *= alpha; + f1 *= alpha; } + + p0[0] = float32_to_bfloat16(f0); + p0[1] = float32_to_bfloat16(f1); + pp += 2; + p0 += out_hstep; } } for (; ii < max_ii; ii += 
1) @@ -10973,427 +8287,267 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } } + int jj = 0; #if __ARM_NEON - if (out_elempack == 4) + for (; jj + 15 < max_jj; jj += 16) { - int jj = 0; - for (; jj + 15 < max_jj; jj += 16) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - uint16x4_t _bf0 = float2bfloat(_f0); - uint16x4_t _bf1 = float2bfloat(_f1); - uint16x4_t _bf2 = float2bfloat(_f2); - uint16x4_t _bf3 = float2bfloat(_f3); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (out_hstep == 1) - { - vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); - vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); - } - else + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - vst1_u16(p0, _bf0); - vst1_u16(p0 + out_hstep * 4, _bf1); - vst1_u16(p0 + out_hstep * 8, _bf2); - vst1_u16(p0 + out_hstep * 12, _bf3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } - - pp += 16; - p0 += out_hstep * 16; - } - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - - if (pC) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + // c_elempack == 1 + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = 
bfloat2float(vget_high_u16(_c23)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3 || broadcast_type_C == 4) + else { - // c_elempack == 1 - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } + pC += 16; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - uint16x4_t _bf0 = float2bfloat(_f0); - uint16x4_t _bf1 = float2bfloat(_f1); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - if (out_hstep == 1) - { - vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); - } - else - { - vst1_u16(p0, _bf0); - vst1_u16(p0 + out_hstep * 4, _bf1); - } + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); - pp += 8; - p0 += out_hstep * 8; + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); } - for (; jj + 3 < max_jj; jj += 4) + else if (out_elempack == 4) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; - } - } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1_u16(p0, float2bfloat(_f0)); - pp += 4; - p0 += out_hstep * 4; + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + vst1_u16(p0 + out_hstep * 8, _bf2); + vst1_u16(p0 + out_hstep * 12, _bf3); } + else if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + } + + pp += 16; + p0 += out_hstep * 16; } -#endif // __ARM_NEON - if (out_elempack == 1) + for (; jj + 7 < max_jj; jj += 8) { - int jj = 0; -#if __ARM_NEON - for (; jj + 15 < max_jj; jj += 16) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp 
+ 4); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + uint16x8_t _c = vld1q_u16(pC); + _c0 = bfloat2float(vget_low_u16(_c)); + float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 3 || broadcast_type_C == 4) + else { - // c_elempack == 1 - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - uint16x4_t _bf0 = float2bfloat(_f0); - uint16x4_t _bf1 = float2bfloat(_f1); - uint16x4_t _bf2 = float2bfloat(_f2); - uint16x4_t _bf3 = float2bfloat(_f3); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - if (out_hstep == 1) - { - vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); - vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); - } - else - { - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); - } + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); - pp += 16; - p0 += out_hstep * 16; + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); } - for (; jj + 7 < max_jj; jj += 8) + else if (out_elempack == 4) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + vst1_u16(p0, _bf0); + 
vst1_u16(p0 + out_hstep * 4, _bf1); + } + else // if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + } - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - uint16x8_t _c = vld1q_u16(pC); - _c0 = bfloat2float(vget_low_u16(_c)); - float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; - } + _f0 = vaddq_f32(_f0, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = bfloat2float(vld1_u16(pC)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; } + } - uint16x4_t _bf0 = float2bfloat(_f0); - uint16x4_t _bf1 = float2bfloat(_f1); + _f0 = vmulq_n_f32(_f0, alpha); - if (out_hstep == 1) - { - vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); - } - else - { - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); - } + uint16x4_t _bf0 = float2bfloat(_f0); - pp += 8; - p0 += out_hstep * 8; + if (out_hstep == 1 || out_elempack == 4) + { + vst1_u16(p0, _bf0); } - for (; jj + 3 < max_jj; jj += 4) + else // if (out_elempack == 1) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = bfloat2float(vld1_u16(pC)); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; - } - } - - _f0 = vmulq_n_f32(_f0, alpha); + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + } - uint16x4_t _bf0 = float2bfloat(_f0); + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - if (out_hstep == 1) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - vst1_u16(p0, _bf0); + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); } - else + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = 
vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + // c_elempack == 1 + float32x2_t _c = float32x2_t(); + _c = vset_lane_f32(bfloat16_to_float32(pC[0]), _c, 0); + _c = vset_lane_f32(bfloat16_to_float32(pC[1]), _c, 1); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; } - - pp += 4; - p0 += out_hstep * 4; } - for (; jj + 1 < max_jj; jj += 2) - { - float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vadd_f32(_f0, vget_low_f32(_c0)); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - float32x2_t _c = float32x2_t(); - _c = vset_lane_f32(bfloat16_to_float32(pC[0]), _c, 0); - _c = vset_lane_f32(bfloat16_to_float32(pC[1]), _c, 1); - _f0 = vmla_n_f32(_f0, _c, beta); - pC += 2; - } - } - _f0 = vmul_n_f32(_f0, alpha); + _f0 = vmul_n_f32(_f0, alpha); - p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); - p0[out_hstep] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); + p0[0] = float32_to_bfloat16(vget_lane_f32(_f0, 0)); + p0[out_hstep] = float32_to_bfloat16(vget_lane_f32(_f0, 1)); - pp += 2; - p0 += out_hstep * 2; - } + pp += 2; + p0 += out_hstep * 2; + } #endif // __ARM_NEON - for (; jj < max_jj; jj += 1) - { - float f0 = pp[0] * descale; + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - f0 += bfloat16_to_float32(pC[0]) * beta; - pC += 1; - } + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += bfloat16_to_float32(pC[0]) * beta; + pC += 1; } + } - f0 *= alpha; + f0 *= alpha; - p0[0] = float32_to_bfloat16(f0); + p0[0] = float32_to_bfloat16(f0); - pp += 1; - p0 += out_hstep; - } + pp += 1; + p0 += out_hstep; } } } From a1f9b2861c30fc443c0e25d4e31ecabc6c5f3f19 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 11 Oct 2024 19:13:46 +0800 Subject: [PATCH 49/55] fix --- src/layer/arm/gemm_arm.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index aee3d9d3d28b..7607d8f523e5 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -6201,15 +6201,14 @@ int Gemm_arm::forward_int8(const std::vector& bottom_blobs, std::vector Date: Fri, 11 Oct 2024 19:47:13 +0800 Subject: [PATCH 50/55] test fp16s --- tests/test_gemm_3.cpp | 59 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index f8433d18b215..35b30623c872 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -136,6 +136,60 @@ static int test_gemm_int8_bias(int M, int N, int K, const ncnn::Mat& C, float al return ret; } +static int test_gemm_int8_fp16s(int M, int N, int K, float alpha, int transA, int transB, int output_elemtype, int output_transpose, int constantA, int constantB, int output_N1M) +{ + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, 1.f); // beta + pd.set(2, transA); + pd.set(3, transB); + pd.set(4, constantA); + pd.set(5, constantB); + pd.set(6, 1); + pd.set(7, M); + pd.set(8, N); + pd.set(9, K); + pd.set(10, -1); + pd.set(11, output_N1M); + pd.set(13, 
output_elemtype); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + std::vector weights; + if (constantA) weights.push_back(transA ? (output_N1M ? RandomS8Mat(M, 1, K) : RandomS8Mat(M, K)) : (output_N1M ? RandomS8Mat(K, 1, M) : RandomS8Mat(K, M))); + if (constantB) weights.push_back(transB ? (output_N1M ? RandomS8Mat(K, 1, N) : RandomS8Mat(K, N)) : (output_N1M ? RandomS8Mat(N, 1, K) : RandomS8Mat(N, K))); + if (constantA) weights.push_back(RandomMat(M, 10.f, 20.f)); + if (constantB) weights.push_back(RandomMat(1, 10.f, 20.f)); + + std::vector a; + if (!constantA) a.push_back(transA ? (output_N1M ? ncnn::Mat(M, 1, K) : ncnn::Mat(M, K)) : (output_N1M ? ncnn::Mat(K, 1, M) : ncnn::Mat(K, M))); + if (!constantB) a.push_back(transB ? (output_N1M ? ncnn::Mat(K, 1, N) : ncnn::Mat(K, N)) : (output_N1M ? ncnn::Mat(N, 1, K) : ncnn::Mat(N, K))); + + for (size_t i = 0; i < a.size(); i++) + { + Randomize(a[i], -10.f, 10.f); + } + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = true; + opt.use_fp16_packed = false; + opt.use_fp16_storage = true; + opt.use_fp16_arithmetic = false; + opt.use_bf16_storage = false; + + float epsilon = 0.001; + + int ret = test_layer_opt("Gemm", pd, weights, opt, a, 1, epsilon); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8_fp16s failed M=%d N=%d K=%d alpha=%f transA=%d transB=%d output_elemtype=%d output_transpose=%d constantA=%d constantB=%d output_N1M=%d\n", M, N, K, alpha, transA, transB, output_elemtype, output_transpose, constantA, constantB, output_N1M); + return ret; + } + + return 0; +} + static int test_gemm_0(int M, int N, int K) { return 0 @@ -167,7 +221,10 @@ static int test_gemm_0(int M, int N, int K) || test_gemm_int8(M, N, K, -2.1f, 0, 1, 0, 1, 1, 1, 0) || test_gemm_int8(M, N, K, -3.1f, 1, 1, 0, 1, 1, 1, 1) || test_gemm_int8(M, N, K, -4.1f, 0, 0, 0, 1, 1, 1, 0) - || test_gemm_int8(M, N, K, -5.1f, 1, 0, 0, 1, 1, 1, 1); + || test_gemm_int8(M, N, K, -5.1f, 1, 0, 0, 1, 1, 1, 1) + + || test_gemm_int8_fp16s(M, N, K, 1.f, 0, 1, 0, 0, 0, 0, 0) + || test_gemm_int8_fp16s(M, N, K, 1.f, 1, 0, 0, 1, 0, 0, 0); } static int test_gemm_1(int M, int N, int K) From e078b6ef8210038aac0bf5e7e90fb9aef69ee229 Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 12 Oct 2024 10:51:48 +0800 Subject: [PATCH 51/55] enable elempack=8 only for asimdhp+ --- src/layer/arm/gemm_int8_fp16s.h | 1153 ++++++++++++++----------------- 1 file changed, 526 insertions(+), 627 deletions(-) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index f0a0f3743560..a7e1f15d5ddf 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -56,16 +56,14 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s float* pods = out_descales; #if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { -#if __aarch64__ float32x4_t _v127 = vdupq_n_f32(127.f); float32x4_t _v127_B_scale = vdupq_n_f32(v127_B_scale); -#endif for (int ii = 0; ii + 7 < max_ii; ii += 8) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const __fp16* p0 = (const __fp16*)A + (i + ii) * A_hstep; float16x8_t _amax0 = vdupq_n_f16((__fp16)0.f); @@ -104,42 +102,7 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s } float32x4_t _absmax0 = vcvt_f32_f16(vget_low_f16(_amax0)); float32x4_t _absmax1 = vcvt_f32_f16(vget_high_f16(_amax0)); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * A_hstep; - - float32x4_t _absmax0 = 
vdupq_n_f32(0.f); - float32x4_t _absmax1 = vdupq_n_f32(0.f); - float32x4_t _absmax2 = vdupq_n_f32(0.f); - float32x4_t _absmax3 = vdupq_n_f32(0.f); - int kk = 0; - for (; kk + 1 < K; kk += 2) - { - uint16x8_t _p01 = vld1q_u16(p0); - uint16x8_t _p23 = vld1q_u16(p0 + 8); - float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); - float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); - float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); - float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); - _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); - _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); - _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); - _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); - p0 += 16; - } - _absmax0 = vmaxq_f32(_absmax0, _absmax2); - _absmax1 = vmaxq_f32(_absmax1, _absmax3); - for (; kk < K; kk++) - { - uint16x8_t _p01 = vld1q_u16(p0); - float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); - float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); - _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); - _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); - p0 += 8; - } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#if __aarch64__ float32x4_t _scale0 = vdivq_f32(_v127, _absmax0); float32x4_t _scale1 = vdivq_f32(_v127, _absmax1); float32x4_t _out_descale0 = vdivq_f32(_absmax0, _v127_B_scale); @@ -149,41 +112,12 @@ static void compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float B_s vst1q_f32(ps + 4, _scale1); vst1q_f32(pods, _out_descale0); vst1q_f32(pods + 4, _out_descale1); -#else - // float32x4_t _recp_absmax = vrecpeq_f32(_absmax0); - // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); - // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); - // _recp_absmax = vmulq_f32(vrecpsq_f32(_absmax0, _recp_absmax), _recp_absmax); - // float32x4_t _scale = vmulq_f32(_v127, _recp_absmax); - // float32x4_t _out_descale = vmulq_f32(_absmax0, _recp_v127_B_scale); - - float tmp[8]; - vst1q_f32(tmp, _absmax0); - vst1q_f32(tmp + 4, _absmax1); - ps[0] = 127.f / tmp[0]; - ps[1] = 127.f / tmp[1]; - ps[2] = 127.f / tmp[2]; - ps[3] = 127.f / tmp[3]; - ps[4] = 127.f / tmp[4]; - ps[5] = 127.f / tmp[5]; - ps[6] = 127.f / tmp[6]; - ps[7] = 127.f / tmp[7]; - - pods[0] = tmp[0] / v127_B_scale; - pods[1] = tmp[1] / v127_B_scale; - pods[2] = tmp[2] / v127_B_scale; - pods[3] = tmp[3] / v127_B_scale; - pods[4] = tmp[4] / v127_B_scale; - pods[5] = tmp[5] / v127_B_scale; - pods[6] = tmp[6] / v127_B_scale; - pods[7] = tmp[7] / v127_B_scale; - -#endif ps += 8; pods += 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { #if __aarch64__ @@ -459,6 +393,7 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -626,6 +561,7 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i p0 += 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -1539,12 +1475,12 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, float* pods = out_descales; #if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int ii = 0; for (; ii + 1 < max_ii; ii += 2) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const __fp16* p0 = (const __fp16*)A + (i + ii) * 
8; float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); @@ -1576,41 +1512,6 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, } float absmax0 = (float)vmaxvq_f16(_absmax0); float absmax1 = (float)vmaxvq_f16(_absmax1); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 8; - - float32x4_t _absmax0 = vdupq_n_f32(0.f); - float32x4_t _absmax1 = vdupq_n_f32(0.f); - float32x4_t _absmax2 = vdupq_n_f32(0.f); - float32x4_t _absmax3 = vdupq_n_f32(0.f); - int kk = 0; - for (; kk < K; kk++) - { - uint16x8_t _p01 = vld1q_u16(p0); - uint16x8_t _p23 = vld1q_u16(p0 + 8); - float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); - float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); - float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); - float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); - _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); - _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); - _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); - _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); - p0 += A_hstep * 8; - } - _absmax0 = vmaxq_f32(_absmax0, _absmax2); - _absmax1 = vmaxq_f32(_absmax1, _absmax3); -#if __aarch64__ - float absmax0 = vmaxvq_f32(_absmax0); - float absmax1 = vmaxvq_f32(_absmax1); -#else - float32x2_t _aa0 = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); - float32x2_t _aa1 = vmax_f32(vget_low_f32(_absmax1), vget_high_f32(_absmax1)); - float32x2_t _aa01 = vpmax_f32(_aa0, _aa1); - float absmax0 = vget_lane_f32(_aa01, 0); - float absmax1 = vget_lane_f32(_aa01, 1); -#endif -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ps[0] = 127.f / absmax0; ps[1] = 127.f / absmax1; @@ -1621,7 +1522,6 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, } for (; ii < max_ii; ii++) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC const __fp16* p0 = (const __fp16*)A + (i + ii) * 8; float16x8_t _absmax0 = vdupq_n_f16((__fp16)0.f); @@ -1659,47 +1559,6 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, p0 += A_hstep * 8; } float absmax = (float)vmaxvq_f16(_absmax0); -#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - const unsigned short* p0 = (const unsigned short*)A + (i + ii) * 8; - - float32x4_t _absmax0 = vdupq_n_f32(0.f); - float32x4_t _absmax1 = vdupq_n_f32(0.f); - float32x4_t _absmax2 = vdupq_n_f32(0.f); - float32x4_t _absmax3 = vdupq_n_f32(0.f); - int kk = 0; - for (; kk + 1 < K; kk += 2) - { - uint16x8_t _p01 = vld1q_u16(p0); - uint16x8_t _p23 = vld1q_u16(p0 + A_hstep * 8); - float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); - float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); - float32x4_t _p2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p23)); - float32x4_t _p3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p23)); - _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); - _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); - _absmax2 = vmaxq_f32(_absmax2, vabsq_f32(_p2)); - _absmax3 = vmaxq_f32(_absmax3, vabsq_f32(_p3)); - p0 += A_hstep * 16; - } - _absmax0 = vmaxq_f32(_absmax0, _absmax2); - _absmax1 = vmaxq_f32(_absmax1, _absmax3); - for (; kk < K; kk++) - { - uint16x8_t _p01 = vld1q_u16(p0); - float32x4_t _p0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_p01)); - float32x4_t _p1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_p01)); - _absmax0 = vmaxq_f32(_absmax0, vabsq_f32(_p0)); - _absmax1 = vmaxq_f32(_absmax1, vabsq_f32(_p1)); - p0 += A_hstep * 8; - } - _absmax0 = vmaxq_f32(_absmax0, 
_absmax1); -#if __aarch64__ - float absmax = vmaxvq_f32(_absmax0); -#else - float32x2_t _aa = vmax_f32(vget_low_f32(_absmax0), vget_high_f32(_absmax0)); - float absmax = std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)); -#endif -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC ps[0] = 127.f / absmax; pods[0] = absmax / v127_B_scale; @@ -1707,6 +1566,7 @@ static void transpose_compute_A_tile_fp16_int8_scales(const Mat& A, Mat& scales, pods++; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int ii = 0; @@ -2178,6 +2038,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -2208,7 +2069,6 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _pe = vcvt_f32_f16((float16x4_t)vget_low_u16(_w)); float32x4_t _pf = vcvt_f32_f16((float16x4_t)vget_high_u16(_w)); -#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale0, 0); _p1 = vmulq_laneq_f32(_p1, _scale0, 0); _p2 = vmulq_laneq_f32(_p2, _scale0, 1); @@ -2225,24 +2085,6 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int _pd = vmulq_laneq_f32(_pd, _scale1, 2); _pe = vmulq_laneq_f32(_pe, _scale1, 3); _pf = vmulq_laneq_f32(_pf, _scale1, 3); -#else - _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale0), 0); - _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale0), 0); - _p2 = vmulq_lane_f32(_p2, vget_low_f32(_scale0), 1); - _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale0), 1); - _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale0), 0); - _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale0), 0); - _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale0), 1); - _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale0), 1); - _p8 = vmulq_lane_f32(_p8, vget_low_f32(_scale1), 0); - _p9 = vmulq_lane_f32(_p9, vget_low_f32(_scale1), 0); - _pa = vmulq_lane_f32(_pa, vget_low_f32(_scale1), 1); - _pb = vmulq_lane_f32(_pb, vget_low_f32(_scale1), 1); - _pc = vmulq_lane_f32(_pc, vget_high_f32(_scale1), 0); - _pd = vmulq_lane_f32(_pd, vget_high_f32(_scale1), 0); - _pe = vmulq_lane_f32(_pe, vget_high_f32(_scale1), 1); - _pf = vmulq_lane_f32(_pf, vget_high_f32(_scale1), 1); -#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -2301,6 +2143,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -2669,6 +2512,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _scale = vld1q_f32((const float*)scales + ii); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -2687,7 +2531,6 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _p6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_s)); float32x4_t _p7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_s)); -#if __aarch64__ _p0 = vmulq_laneq_f32(_p0, _scale, 0); _p1 = vmulq_laneq_f32(_p1, _scale, 0); _p2 = vmulq_laneq_f32(_p2, _scale, 1); @@ -2696,16 +2539,6 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int _p5 = vmulq_laneq_f32(_p5, _scale, 2); _p6 = vmulq_laneq_f32(_p6, _scale, 3); _p7 = vmulq_laneq_f32(_p7, _scale, 3); -#else - _p0 = vmulq_lane_f32(_p0, vget_low_f32(_scale), 0); - _p1 = vmulq_lane_f32(_p1, vget_low_f32(_scale), 0); - _p2 = 
vmulq_lane_f32(_p2, vget_low_f32(_scale), 1); - _p3 = vmulq_lane_f32(_p3, vget_low_f32(_scale), 1); - _p4 = vmulq_lane_f32(_p4, vget_high_f32(_scale), 0); - _p5 = vmulq_lane_f32(_p5, vget_high_f32(_scale), 0); - _p6 = vmulq_lane_f32(_p6, vget_high_f32(_scale), 1); - _p7 = vmulq_lane_f32(_p7, vget_high_f32(_scale), 1); -#endif #if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_MATMUL_INT8 @@ -2739,6 +2572,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -2975,6 +2809,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_NEON float32x4_t _scale0 = vdupq_n_f32(scale0); float32x4_t _scale1 = vdupq_n_f32(scale1); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -3014,6 +2849,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -3236,6 +3072,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_NEON float32x4_t _scale = vdupq_n_f32(scale); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -3277,6 +3114,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -3576,6 +3414,7 @@ static void pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i { const unsigned short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -3743,6 +3582,7 @@ static void pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i p0 += 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -4600,6 +4440,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int { const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -4704,6 +4545,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -5040,6 +4882,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int { const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -5099,6 +4942,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -5306,6 +5150,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; #if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -5345,6 +5190,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -5563,6 +5409,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int const unsigned short* p0 = 
(const unsigned short*)B + k * B_hstep + (j + jj) * elempack; #if __ARM_NEON +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 8) { int kk = 0; @@ -5604,6 +5451,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (elempack == 4) { int kk = 0; @@ -5973,21 +5821,25 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { - if (c_elempack == 1) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); - float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); - float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); - float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + uint16x8_t _c08 = vld1q_u16(pC); + uint16x8_t _c19 = vld1q_u16(pC + 8); + uint16x8_t _c2a = vld1q_u16(pC + 16); + uint16x8_t _c3b = vld1q_u16(pC + 24); + uint16x8_t _c4c = vld1q_u16(pC + 32); + uint16x8_t _c5d = vld1q_u16(pC + 40); + uint16x8_t _c6e = vld1q_u16(pC + 48); + uint16x8_t _c7f = vld1q_u16(pC + 56); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c08)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c19)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c2a)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c3b)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c4c)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c5d)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c6e)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c7f)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -6011,19 +5863,14 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _f6 = vmlaq_f32(_f6, _c6, _beta); _f7 = vmlaq_f32(_f7, _c7, _beta); } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 5); - _c45 = vld1q_u16(pC + c_hstep * 6); - _c67 = vld1q_u16(pC + c_hstep * 7); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); - _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); - _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); - _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + _c0 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c08)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c19)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c2a)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c3b)); + _c4 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c4c)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c5d)); + _c6 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c6e)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c7f)); if (beta == 1.f) { _f8 = vaddq_f32(_f8, _c0); @@ -6047,9 +5894,10 @@ static void 
unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _fe = vmlaq_f32(_fe, _c6, _beta); _ff = vmlaq_f32(_ff, _c7, _beta); } - pC += 8; + pC += 64; } - else if (c_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); @@ -6123,24 +5971,21 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } pC += 32; } - else // if (c_elempack == 8) + if (c_elempack == 1) { - uint16x8_t _c08 = vld1q_u16(pC); - uint16x8_t _c19 = vld1q_u16(pC + 8); - uint16x8_t _c2a = vld1q_u16(pC + 16); - uint16x8_t _c3b = vld1q_u16(pC + 24); - uint16x8_t _c4c = vld1q_u16(pC + 32); - uint16x8_t _c5d = vld1q_u16(pC + 40); - uint16x8_t _c6e = vld1q_u16(pC + 48); - uint16x8_t _c7f = vld1q_u16(pC + 56); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c08)); - _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c19)); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c2a)); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c3b)); - float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c4c)); - float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c5d)); - float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c6e)); - float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c7f)); + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -6164,14 +6009,19 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _f6 = vmlaq_f32(_f6, _c6, _beta); _f7 = vmlaq_f32(_f7, _c7, _beta); } - _c0 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c08)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c19)); - _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c2a)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c3b)); - _c4 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c4c)); - _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c5d)); - _c6 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c6e)); - _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c7f)); + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); if (beta == 1.f) { _f8 = vaddq_f32(_f8, _c0); @@ -6195,7 +6045,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _fe = vmlaq_f32(_fe, 
_c6, _beta); _ff = vmlaq_f32(_ff, _c7, _beta); } - pC += 64; + pC += 8; } } if (broadcast_type_C == 4) @@ -6275,6 +6125,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hfe = (uint16x4_t)vcvt_f16_f32(_fe); uint16x4_t _hff = (uint16x4_t)vcvt_f16_f32(_ff); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf8)); @@ -6287,7 +6138,8 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 56, vcombine_u16(_hf7, _hff)); p0 += 64; } - else if (out_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); @@ -6299,7 +6151,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(_hfe, _hff)); p0 += 32; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); transpose4x4_u16(_hf4, _hf5, _hf6, _hf7); @@ -6423,17 +6275,17 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { - if (c_elempack == 1) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = vcvt_f32_f16((float16x4_t)_cc0); - _c1 = vcvt_f32_f16((float16x4_t)_cc1); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)_cc2); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)_cc3); + uint16x8_t _c04 = vld1q_u16(pC); + uint16x8_t _c15 = vld1q_u16(pC + 8); + uint16x8_t _c26 = vld1q_u16(pC + 16); + uint16x8_t _c37 = vld1q_u16(pC + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c04)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c15)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c26)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c37)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -6449,15 +6301,10 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _cc0 = vld1_u16(pC + c_hstep * 4); - _cc1 = vld1_u16(pC + c_hstep * 5); - _cc2 = vld1_u16(pC + c_hstep * 6); - _cc3 = vld1_u16(pC + c_hstep * 7); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = vcvt_f32_f16((float16x4_t)_cc0); - _c1 = vcvt_f32_f16((float16x4_t)_cc1); - _c2 = vcvt_f32_f16((float16x4_t)_cc2); - _c3 = vcvt_f32_f16((float16x4_t)_cc3); + _c0 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c04)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c15)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c26)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c37)); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0); @@ -6473,9 +6320,10 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _f6 = vmlaq_f32(_f6, _c2, _beta); _f7 = vmlaq_f32(_f7, _c3, _beta); } - pC += 4; + pC += 32; } - else if (c_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); @@ -6521,16 +6369,17 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } pC += 16; } - else // if (c_elempack == 8) + if (c_elempack == 1) { - uint16x8_t _c04 = vld1q_u16(pC); - uint16x8_t 
_c15 = vld1q_u16(pC + 8); - uint16x8_t _c26 = vld1q_u16(pC + 16); - uint16x8_t _c37 = vld1q_u16(pC + 24); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c04)); - _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c15)); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c26)); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c37)); + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)_cc2); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)_cc3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -6546,10 +6395,15 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _c0 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c04)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c15)); - _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c26)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c37)); + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + _c2 = vcvt_f32_f16((float16x4_t)_cc2); + _c3 = vcvt_f32_f16((float16x4_t)_cc3); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0); @@ -6565,7 +6419,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _f6 = vmlaq_f32(_f6, _c2, _beta); _f7 = vmlaq_f32(_f7, _c3, _beta); } - pC += 32; + pC += 4; } } if (broadcast_type_C == 4) @@ -6617,6 +6471,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hf6 = (uint16x4_t)vcvt_f16_f32(_f6); uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf4)); @@ -6625,7 +6480,8 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 24, vcombine_u16(_hf3, _hf7)); p0 += 32; } - else if (out_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); @@ -6633,7 +6489,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_hf6, _hf7)); p0 += 16; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); transpose4x4_u16(_hf4, _hf5, _hf6, _hf7); @@ -6714,6 +6570,28 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& { float32x4_t _c2; float32x4_t _c3; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) + { + uint16x8_t _cc0 = vld1q_u16(pC); + uint16x8_t _cc1 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_cc0)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_cc1)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc0)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc1)); + pC += 16; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = 
vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 8; + } if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -6740,26 +6618,6 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); pC += 2; } - else if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - pC += 8; - } - else // if (c_elempack == 8) - { - uint16x8_t _cc0 = vld1q_u16(pC); - uint16x8_t _cc1 = vld1q_u16(pC + 8); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_cc0)); - _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_cc1)); - _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc0)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc1)); - pC += 16; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -6802,19 +6660,21 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf2)); vst1q_u16(p0 + 8, vcombine_u16(_hf1, _hf3)); p0 += 16; } - else if (out_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_hf2, _hf3)); p0 += 8; } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_hf0, 0); p0[1] = vget_lane_u16(_hf1, 0); @@ -6859,6 +6719,21 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + pC += 8; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) + { + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep * 4)); + pC += 4; + } if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -6874,19 +6749,6 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); pC += 1; } - else if (c_elempack == 4) - { - _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); - _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep * 4)); - pC += 4; - } - else // if (c_elempack == 8) - { - uint16x8_t _c = vld1q_u16(pC); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); - pC += 8; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -6918,18 +6780,20 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); p0 += 8; } - else if (out_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) { vst1_u16(p0, _hf0); vst1_u16(p0 + out_hstep * 4, _hf1); p0 += 4; } - else // if (out_elempack == 
1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_hf0, 0); p0[out_hstep] = vget_lane_u16(_hf0, 1); @@ -7084,6 +6948,14 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c23; uint16x8_t _c45; uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } if (c_elempack == 1) { _c01 = vld1q_u16(pC); @@ -7093,14 +6965,6 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& transpose8x4_u16(_c01, _c23, _c45, _c67); pC += 8; } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + 8); - _c45 = vld1q_u16(pC + 16); - _c67 = vld1q_u16(pC + 24); - pC += 32; - } _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); @@ -7194,7 +7058,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 24, vcombine_u16(_hf6, _hf7)); p0 += 32; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); transpose4x4_u16(_hf4, _hf5, _hf6, _hf7); @@ -7275,6 +7139,16 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& float32x4_t _c1; float32x4_t _c2; float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 16; + } if (c_elempack == 1) { uint16x4_t _cc0 = vld1_u16(pC); @@ -7288,16 +7162,6 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _c3 = vcvt_f32_f16((float16x4_t)_cc3); pC += 4; } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - pC += 16; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -7357,7 +7221,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); p0 += 16; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); vst1_u16(p0, _hf0); @@ -7413,6 +7277,11 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } if (c_elempack == 1) { _c = uint16x8_t(); @@ -7426,11 +7295,6 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); pC += 2; } - else // if (c_elempack == 4) - { - _c = vld1q_u16(pC); - pC += 8; - } _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); if (beta == 1.f) @@ -7470,7 +7334,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); p0 += 8; } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_hf0, 0); p0[1] = vget_lane_u16(_hf1, 0); @@ -7502,6 
+7366,11 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } if (c_elempack == 1) { _c = uint16x4_t(); @@ -7511,11 +7380,6 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); pC += 1; } - else // if (c_elempack == 4) - { - _c = vld1_u16(pC); - pC += 4; - } _c0 = vcvt_f32_f16((float16x4_t)_c); _f0 = vmlaq_n_f32(_f0, _c0, beta); } @@ -7536,7 +7400,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1_u16(p0, _hf0); p0 += 4; } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_hf0, 0); p0[out_hstep] = vget_lane_u16(_hf0, 1); @@ -8335,36 +8199,57 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { - if (c_elempack == 1) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); - float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); - float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); - float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); + uint16x8_t _c08 = vld1q_u16(pC); + uint16x8_t _c19 = vld1q_u16(pC + 8); + uint16x8_t _c2a = vld1q_u16(pC + 16); + uint16x8_t _c3b = vld1q_u16(pC + 24); + uint16x8_t _c4c = vld1q_u16(pC + 32); + uint16x8_t _c5d = vld1q_u16(pC + 40); + uint16x8_t _c6e = vld1q_u16(pC + 48); + uint16x8_t _c7f = vld1q_u16(pC + 56); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c08)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c19)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c2a)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c3b)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c4c)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c5d)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c6e)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c7f)); + float32x4_t _c8 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c08)); + float32x4_t _c9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c19)); + float32x4_t _ca = vcvt_f32_f16((float16x4_t)vget_high_u16(_c2a)); + float32x4_t _cb = vcvt_f32_f16((float16x4_t)vget_high_u16(_c3b)); + float32x4_t _cc = vcvt_f32_f16((float16x4_t)vget_high_u16(_c4c)); + float32x4_t _cd = vcvt_f32_f16((float16x4_t)vget_high_u16(_c5d)); + float32x4_t _ce = vcvt_f32_f16((float16x4_t)vget_high_u16(_c6e)); + float32x4_t _cf = vcvt_f32_f16((float16x4_t)vget_high_u16(_c7f)); + + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = 
vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c8); + _f9 = vaddq_f32(_f9, _c9); + _fa = vaddq_f32(_fa, _ca); + _fb = vaddq_f32(_fb, _cb); + _fc = vaddq_f32(_fc, _cc); + _fd = vaddq_f32(_fd, _cd); + _fe = vaddq_f32(_fe, _ce); + _ff = vaddq_f32(_ff, _cf); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); _f1 = vmlaq_f32(_f1, _c1, _beta); _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); @@ -8372,46 +8257,19 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f5 = vmlaq_f32(_f5, _c5, _beta); _f6 = vmlaq_f32(_f6, _c6, _beta); _f7 = vmlaq_f32(_f7, _c7, _beta); + _f8 = vmlaq_f32(_f8, _c8, _beta); + _f9 = vmlaq_f32(_f9, _c9, _beta); + _fa = vmlaq_f32(_fa, _ca, _beta); + _fb = vmlaq_f32(_fb, _cb, _beta); + _fc = vmlaq_f32(_fc, _cc, _beta); + _fd = vmlaq_f32(_fd, _cd, _beta); + _fe = vmlaq_f32(_fe, _ce, _beta); + _ff = vmlaq_f32(_ff, _cf, _beta); } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 5); - _c45 = vld1q_u16(pC + c_hstep * 6); - _c67 = vld1q_u16(pC + c_hstep * 7); - transpose8x4_u16(_c01, _c23, _c45, _c67); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); - _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); - _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); - _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + pC += 64; } - else if (c_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); @@ -8485,33 +8343,21 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } pC += 32; } - else // if (c_elempack == 8) + if (c_elempack == 1) { - uint16x8_t _c08 = vld1q_u16(pC); - uint16x8_t _c19 = vld1q_u16(pC + 8); - uint16x8_t _c2a = vld1q_u16(pC + 16); - uint16x8_t _c3b = vld1q_u16(pC + 24); - uint16x8_t _c4c = vld1q_u16(pC + 32); - uint16x8_t _c5d = vld1q_u16(pC + 40); - uint16x8_t _c6e = vld1q_u16(pC + 48); - uint16x8_t _c7f = vld1q_u16(pC + 56); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c08)); - _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c19)); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c2a)); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c3b)); - float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c4c)); - float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c5d)); - float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c6e)); - float32x4_t _c7 = 
vcvt_f32_f16((float16x4_t)vget_low_u16(_c7f)); - float32x4_t _c8 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c08)); - float32x4_t _c9 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c19)); - float32x4_t _ca = vcvt_f32_f16((float16x4_t)vget_high_u16(_c2a)); - float32x4_t _cb = vcvt_f32_f16((float16x4_t)vget_high_u16(_c3b)); - float32x4_t _cc = vcvt_f32_f16((float16x4_t)vget_high_u16(_c4c)); - float32x4_t _cd = vcvt_f32_f16((float16x4_t)vget_high_u16(_c5d)); - float32x4_t _ce = vcvt_f32_f16((float16x4_t)vget_high_u16(_c6e)); - float32x4_t _cf = vcvt_f32_f16((float16x4_t)vget_high_u16(_c7f)); - + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -8522,14 +8368,6 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c8); - _f9 = vaddq_f32(_f9, _c9); - _fa = vaddq_f32(_fa, _ca); - _fb = vaddq_f32(_fb, _cb); - _fc = vaddq_f32(_fc, _cc); - _fd = vaddq_f32(_fd, _cd); - _fe = vaddq_f32(_fe, _ce); - _ff = vaddq_f32(_ff, _cf); } else { @@ -8542,16 +8380,44 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f5 = vmlaq_f32(_f5, _c5, _beta); _f6 = vmlaq_f32(_f6, _c6, _beta); _f7 = vmlaq_f32(_f7, _c7, _beta); - _f8 = vmlaq_f32(_f8, _c8, _beta); - _f9 = vmlaq_f32(_f9, _c9, _beta); - _fa = vmlaq_f32(_fa, _ca, _beta); - _fb = vmlaq_f32(_fb, _cb, _beta); - _fc = vmlaq_f32(_fc, _cc, _beta); - _fd = vmlaq_f32(_fd, _cd, _beta); - _fe = vmlaq_f32(_fe, _ce, _beta); - _ff = vmlaq_f32(_ff, _cf, _beta); } - pC += 64; + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + _c4 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c45)); + _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c45)); + _c6 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c67)); + _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c67)); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); + } 
+ pC += 8; } } if (broadcast_type_C == 4) @@ -8623,6 +8489,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x8_t _hf6 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f6), (uint16x4_t)vcvt_f16_f32(_fe)); uint16x8_t _hf7 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f7), (uint16x4_t)vcvt_f16_f32(_ff)); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (out_elempack == 8) { transpose8x8_u16(_hf0, _hf1, _hf2, _hf3, _hf4, _hf5, _hf6, _hf7); @@ -8635,7 +8502,8 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma vst1q_u16(p0 + 48, _hf6); vst1q_u16(p0 + 56, _hf7); } - else if (out_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) { uint16x8x4_t _hfa; uint16x8x4_t _hfb; @@ -8650,7 +8518,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma vst4q_u16(p0, _hfa); vst4q_u16(p0 + out_hstep * 4, _hfb); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1q_u16(p0, _hf0); vst1q_u16(p0 + out_hstep, _hf1); @@ -8772,23 +8640,31 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { - if (c_elempack == 1) +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = vcvt_f32_f16((float16x4_t)_cc0); - _c1 = vcvt_f32_f16((float16x4_t)_cc1); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)_cc2); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)_cc3); + uint16x8_t _c04 = vld1q_u16(pC); + uint16x8_t _c15 = vld1q_u16(pC + 8); + uint16x8_t _c26 = vld1q_u16(pC + 16); + uint16x8_t _c37 = vld1q_u16(pC + 24); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c04)); + _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c15)); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c26)); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c37)); + float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c04)); + float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c15)); + float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c26)); + float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c37)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } else { @@ -8797,34 +8673,15 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f1 = vmlaq_f32(_f1, _c1, _beta); _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - _cc0 = vld1_u16(pC + c_hstep * 4); - _cc1 = vld1_u16(pC + c_hstep * 5); - _cc2 = vld1_u16(pC + c_hstep * 6); - _cc3 = vld1_u16(pC + c_hstep * 7); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = vcvt_f32_f16((float16x4_t)_cc0); - _c1 = vcvt_f32_f16((float16x4_t)_cc1); - _c2 = vcvt_f32_f16((float16x4_t)_cc2); - _c3 = vcvt_f32_f16((float16x4_t)_cc3); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = 
vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; + pC += 32; } - else if (c_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); uint16x8_t _c23 = vld1q_u16(pC + 8); @@ -8870,30 +8727,23 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } pC += 16; } - else // if (c_elempack == 8) + if (c_elempack == 1) { - uint16x8_t _c04 = vld1q_u16(pC); - uint16x8_t _c15 = vld1q_u16(pC + 8); - uint16x8_t _c26 = vld1q_u16(pC + 16); - uint16x8_t _c37 = vld1q_u16(pC + 24); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c04)); - _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c15)); - float32x4_t _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c26)); - float32x4_t _c3 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c37)); - float32x4_t _c4 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c04)); - float32x4_t _c5 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c15)); - float32x4_t _c6 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c26)); - float32x4_t _c7 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c37)); + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + float32x4_t _c2 = vcvt_f32_f16((float16x4_t)_cc2); + float32x4_t _c3 = vcvt_f32_f16((float16x4_t)_cc3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); } else { @@ -8902,12 +8752,32 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f1 = vmlaq_f32(_f1, _c1, _beta); _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); } - pC += 32; + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = vcvt_f32_f16((float16x4_t)_cc0); + _c1 = vcvt_f32_f16((float16x4_t)_cc1); + _c2 = vcvt_f32_f16((float16x4_t)_cc2); + _c3 = vcvt_f32_f16((float16x4_t)_cc3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; } } if (broadcast_type_C == 4) @@ -8964,7 +8834,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _hf.val[3] = _hf3; vst4q_u16(p0, _hf); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1q_u16(p0, _hf0); vst1q_u16(p0 + out_hstep, _hf1); @@ -9039,6 +8909,28 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma { float32x4_t _c2; float32x4_t _c3; +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) + { + uint16x8_t _c02 = vld1q_u16(pC); + uint16x8_t _c13 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c02)); + _c1 = 
vcvt_f32_f16((float16x4_t)vget_low_u16(_c13)); + _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c02)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c13)); + pC += 16; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 8; + } if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -9067,26 +8959,6 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); pC += 2; } - else if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - pC += 8; - } - else // if (c_elempack == 8) - { - uint16x8_t _c02 = vld1q_u16(pC); - uint16x8_t _c13 = vld1q_u16(pC + 8); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c02)); - _c1 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c13)); - _c2 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c02)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c13)); - pC += 16; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -9149,6 +9021,21 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 8) + { + uint16x8_t _c = vld1q_u16(pC); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); + pC += 8; + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (c_elempack == 4) + { + _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); + _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep * 4)); + pC += 4; + } if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -9164,19 +9051,6 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); pC += 1; } - else if (c_elempack == 4) - { - _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); - _c1 = vcvt_f32_f16((float16x4_t)vld1_u16(pC + c_hstep * 4)); - pC += 4; - } - else // if (c_elempack == 8) - { - uint16x8_t _c = vld1q_u16(pC); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); - pC += 8; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -9349,6 +9223,14 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x8_t _c23; uint16x8_t _c45; uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } if (c_elempack == 1) { _c01 = vld1q_u16(pC); @@ -9358,14 +9240,6 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma transpose8x4_u16(_c01, _c23, _c45, _c67); pC += 8; } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + 8); - _c45 = vld1q_u16(pC + 16); - _c67 = vld1q_u16(pC + 24); - pC += 32; - } _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); float32x4_t _c2 = 
vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); @@ -9451,6 +9325,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf6 = (uint16x4_t)vcvt_f16_f32(_f6); uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (out_elempack == 8) { transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); @@ -9460,7 +9335,8 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma vst1q_u16(p0 + 16, vcombine_u16(_hf2, _hf6)); vst1q_u16(p0 + 24, vcombine_u16(_hf3, _hf7)); } - else if (out_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) { uint16x4x4_t _hfa; uint16x4x4_t _hfb; @@ -9475,7 +9351,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma vst4_u16(p0, _hfa); vst4_u16(p0 + out_hstep * 4, _hfb); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1_u16(p0, _hf0); vst1_u16(p0 + out_hstep, _hf1); @@ -9558,6 +9434,16 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma float32x4_t _c1; float32x4_t _c2; float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); + _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); + _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); + _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); + pC += 16; + } if (c_elempack == 1) { uint16x4_t _cc0 = vld1_u16(pC); @@ -9571,16 +9457,6 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _c3 = vcvt_f32_f16((float16x4_t)_cc3); pC += 4; } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c01)); - _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c01)); - _c2 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c23)); - _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c23)); - pC += 16; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -9643,7 +9519,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _hf.val[3] = _hf3; vst4_u16(p0, _hf); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1_u16(p0, _hf0); vst1_u16(p0 + out_hstep, _hf1); @@ -9698,6 +9574,11 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma if (broadcast_type_C == 3) { uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } if (c_elempack == 1) { _c = uint16x8_t(); @@ -9711,11 +9592,6 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); pC += 2; } - else // if (c_elempack == 4) - { - _c = vld1q_u16(pC); - pC += 8; - } _c0 = vcvt_f32_f16((float16x4_t)vget_low_u16(_c)); float32x4_t _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); if (beta == 1.f) @@ -9770,6 +9646,11 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma if (broadcast_type_C == 3) { uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } if (c_elempack == 1) { _c = uint16x4_t(); @@ -9779,11 +9660,6 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); pC += 1; } - else // if (c_elempack == 4) - { - _c = vld1_u16(pC); - pC += 4; - } _c0 = vcvt_f32_f16((float16x4_t)_c); _f0 = vmlaq_n_f32(_f0, _c0, beta); } @@ -9939,17 +9815,19 @@ static void 
transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); } - else if (out_elempack == 4) +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) { vst1q_u16(p0, vcombine_u16(_hf0, _hf2)); vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_hf1, _hf3)); } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_hf0, 0); p0[1] = vget_lane_u16(_hf2, 0); @@ -10038,7 +9916,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_hf0, 0); p0[1] = vget_lane_u16(_hf1, 0); @@ -10267,36 +10145,41 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); } - else if (out_elempack == 8) + else { - vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); - vst1q_u16(p0 + out_hstep * 8, vcombine_u16(_hf2, _hf3)); - } - else if (out_elempack == 4) - { - vst1_u16(p0, _hf0); - vst1_u16(p0 + out_hstep * 4, _hf1); - vst1_u16(p0 + out_hstep * 8, _hf2); - vst1_u16(p0 + out_hstep * 12, _hf3); - } - else // if (out_elempack == 1) - { - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + out_hstep * 8, vcombine_u16(_hf2, _hf3)); + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + vst1_u16(p0 + out_hstep * 8, _hf2); + vst1_u16(p0 + out_hstep * 12, _hf3); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); + } } pp += 16; @@ -10348,25 +10231,34 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); uint16x4_t _hf1 = 
(uint16x4_t)vcvt_f16_f32(_f1); - if (out_hstep == 1 || out_elempack == 8) + if (out_hstep == 1) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); } - else if (out_elempack == 4) + else { - vst1_u16(p0, _hf0); - vst1_u16(p0 + out_hstep * 4, _hf1); - } - else // if (out_elempack == 1) - { - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 8) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + } } pp += 8; @@ -10395,16 +10287,23 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); - if (out_hstep == 1 || out_elempack == 4) + if (out_hstep == 1) { vst1_u16(p0, _hf0); } - else // if (out_elempack == 1) + else { - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + if (out_elempack == 4) + { + vst1_u16(p0, _hf0); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + } } pp += 4; From 854be473d7a6e3398634928967faaa2b0b1a993f Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 12 Oct 2024 14:22:23 +0800 Subject: [PATCH 52/55] cc --- src/layer/arm/gemm_int8_bf16s.h | 505 ++++++++++++++++---------------- 1 file changed, 259 insertions(+), 246 deletions(-) diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 220e2a696aa3..350f20ab4c0f 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -4385,13 +4385,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { - if (c_elempack == 1) + if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -4424,10 +4423,9 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f7 = vmlaq_f32(_f7, _c7, _beta); } _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 5); - _c45 = vld1q_u16(pC + c_hstep * 6); - _c67 = vld1q_u16(pC + c_hstep * 7); - transpose8x4_u16(_c01, _c23, _c45, _c67); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 
16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); _c2 = bfloat2float(vget_low_u16(_c23)); @@ -4459,14 +4457,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _fe = vmlaq_f32(_fe, _c6, _beta); _ff = vmlaq_f32(_ff, _c7, _beta); } - pC += 8; + pC += 32; } - else // if (c_elempack == 4) + if (c_elempack == 1) { uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + 16); - uint16x8_t _c67 = vld1q_u16(pC + 24); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -4499,9 +4498,10 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f7 = vmlaq_f32(_f7, _c7, _beta); } _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 4 + 8); - _c45 = vld1q_u16(pC + c_hstep * 4 + 16); - _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + _c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); _c2 = bfloat2float(vget_low_u16(_c23)); @@ -4533,7 +4533,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _fe = vmlaq_f32(_fe, _c6, _beta); _ff = vmlaq_f32(_ff, _c7, _beta); } - pC += 32; + pC += 8; } } if (broadcast_type_C == 4) @@ -4625,7 +4625,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + out_hstep * 4 + 24, vcombine_u16(_bfe, _bff)); p0 += 32; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); @@ -4749,17 +4749,14 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { - if (c_elempack == 1) + if (c_elempack == 4) { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - float32x4_t _c2 = bfloat2float(_cc2); - float32x4_t _c3 = bfloat2float(_cc3); + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -4775,15 +4772,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _cc0 = vld1_u16(pC + c_hstep * 4); - _cc1 = vld1_u16(pC + c_hstep * 5); - _cc2 = vld1_u16(pC + c_hstep * 6); - _cc3 = vld1_u16(pC + c_hstep * 7); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - _c2 = bfloat2float(_cc2); - _c3 = bfloat2float(_cc3); + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = 
bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0); @@ -4799,16 +4793,19 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f6 = vmlaq_f32(_f6, _c2, _beta); _f7 = vmlaq_f32(_f7, _c3, _beta); } - pC += 4; + pC += 16; } - else // if (c_elempack == 4) + if (c_elempack == 1) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -4824,12 +4821,15 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 4 + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0); @@ -4845,7 +4845,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _f6 = vmlaq_f32(_f6, _c2, _beta); _f7 = vmlaq_f32(_f7, _c3, _beta); } - pC += 16; + pC += 4; } } if (broadcast_type_C == 4) @@ -4905,7 +4905,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + out_hstep * 4 + 8, vcombine_u16(_bf6, _bf7)); p0 += 16; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); @@ -4986,6 +4986,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& { uint16x8_t _c01; uint16x8_t _c23; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + c_hstep * 4); + pC += 8; + } if (c_elempack == 1) { _c01 = uint16x8_t(); @@ -5008,12 +5014,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c23 = vsetq_lane_u16(pC[c_hstep * 7 + 1], _c23, 7); pC += 2; } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + c_hstep * 4); - pC += 8; - } _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -5066,7 +5066,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf2, _bf3)); p0 += 8; } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_bf0, 0); p0[1] = vget_lane_u16(_bf1, 0); @@ -5111,6 +5111,12 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { + if (c_elempack == 4) + { + _c0 = 
bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -5126,12 +5132,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c1 = bfloat2float(vget_high_u16(_c01)); pC += 1; } - else // if (c_elempack == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - pC += 4; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -5169,7 +5169,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1_u16(p0 + out_hstep * 4, _bf1); p0 += 4; } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_bf0, 0); p0[out_hstep] = vget_lane_u16(_bf0, 1); @@ -5324,6 +5324,14 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& uint16x8_t _c23; uint16x8_t _c45; uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } if (c_elempack == 1) { _c01 = vld1q_u16(pC); @@ -5333,14 +5341,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& transpose8x4_u16(_c01, _c23, _c45, _c67); pC += 8; } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + 8); - _c45 = vld1q_u16(pC + 16); - _c67 = vld1q_u16(pC + 24); - pC += 32; - } _c0 = bfloat2float(vget_low_u16(_c01)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -5434,7 +5434,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 24, vcombine_u16(_bf6, _bf7)); p0 += 32; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); transpose4x4_u16(_bf4, _bf5, _bf6, _bf7); @@ -5515,6 +5515,16 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& float32x4_t _c1; float32x4_t _c2; float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; + } if (c_elempack == 1) { uint16x4_t _cc0 = vld1_u16(pC); @@ -5528,16 +5538,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c3 = bfloat2float(_cc3); pC += 4; } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - pC += 16; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -5597,7 +5597,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); p0 += 16; } - else // if (out_elempack == 1) + if (out_elempack == 1) { transpose4x4_u16(_bf0, _bf1, _bf2, _bf3); vst1_u16(p0, _bf0); @@ -5653,6 +5653,11 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } if (c_elempack == 1) { _c = uint16x8_t(); @@ -5666,11 +5671,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 
7); pC += 2; } - else // if (c_elempack == 4) - { - _c = vld1q_u16(pC); - pC += 8; - } _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); if (beta == 1.f) @@ -5710,7 +5710,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); p0 += 8; } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_bf0, 0); p0[1] = vget_lane_u16(_bf1, 0); @@ -5742,6 +5742,11 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& if (broadcast_type_C == 3) { uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } if (c_elempack == 1) { _c = uint16x4_t(); @@ -5751,11 +5756,6 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); pC += 1; } - else // if (c_elempack == 4) - { - _c = vld1_u16(pC); - pC += 4; - } _c0 = bfloat2float(_c); _f0 = vmlaq_n_f32(_f0, _c0, beta); } @@ -5776,7 +5776,7 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat& vst1_u16(p0, _bf0); p0 += 4; } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_bf0, 0); p0[out_hstep] = vget_lane_u16(_bf0, 1); @@ -6575,13 +6575,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { - if (c_elempack == 1) + if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep); - uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); - uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); - transpose8x4_u16(_c01, _c23, _c45, _c67); + uint16x8_t _c23 = vld1q_u16(pC + 8); + uint16x8_t _c45 = vld1q_u16(pC + 16); + uint16x8_t _c67 = vld1q_u16(pC + 24); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -6614,10 +6613,9 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f7 = vmlaq_f32(_f7, _c7, _beta); } _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 5); - _c45 = vld1q_u16(pC + c_hstep * 6); - _c67 = vld1q_u16(pC + c_hstep * 7); - transpose8x4_u16(_c01, _c23, _c45, _c67); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c45 = vld1q_u16(pC + c_hstep * 4 + 16); + _c67 = vld1q_u16(pC + c_hstep * 4 + 24); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); _c2 = bfloat2float(vget_low_u16(_c23)); @@ -6649,14 +6647,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _fe = vmlaq_f32(_fe, _c6, _beta); _ff = vmlaq_f32(_ff, _c7, _beta); } - pC += 8; + pC += 32; } - else // if (c_elempack == 4) + if (c_elempack == 1) { uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - uint16x8_t _c45 = vld1q_u16(pC + 16); - uint16x8_t _c67 = vld1q_u16(pC + 24); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep); + uint16x8_t _c45 = vld1q_u16(pC + c_hstep * 2); + uint16x8_t _c67 = vld1q_u16(pC + c_hstep * 3); + transpose8x4_u16(_c01, _c23, _c45, _c67); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -6689,9 +6688,10 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f7 = vmlaq_f32(_f7, _c7, _beta); } _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = vld1q_u16(pC + c_hstep * 4 + 8); - _c45 = vld1q_u16(pC + c_hstep * 4 + 16); - _c67 = vld1q_u16(pC + c_hstep * 4 + 24); + 
_c23 = vld1q_u16(pC + c_hstep * 5); + _c45 = vld1q_u16(pC + c_hstep * 6); + _c67 = vld1q_u16(pC + c_hstep * 7); + transpose8x4_u16(_c01, _c23, _c45, _c67); _c0 = bfloat2float(vget_low_u16(_c01)); _c1 = bfloat2float(vget_high_u16(_c01)); _c2 = bfloat2float(vget_low_u16(_c23)); @@ -6723,7 +6723,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _fe = vmlaq_f32(_fe, _c6, _beta); _ff = vmlaq_f32(_ff, _c7, _beta); } - pC += 32; + pC += 8; } } if (broadcast_type_C == 4) @@ -6810,7 +6810,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma vst4q_u16(p0, _bfa); vst4q_u16(p0 + out_hstep * 4, _bfb); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1q_u16(p0, _bf0); vst1q_u16(p0 + out_hstep, _bf1); @@ -6932,17 +6932,14 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { - if (c_elempack == 1) + if (c_elempack == 4) { - uint16x4_t _cc0 = vld1_u16(pC); - uint16x4_t _cc1 = vld1_u16(pC + c_hstep); - uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); - uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - float32x4_t _c2 = bfloat2float(_cc2); - float32x4_t _c3 = bfloat2float(_cc3); + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); + float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -6958,15 +6955,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _cc0 = vld1_u16(pC + c_hstep * 4); - _cc1 = vld1_u16(pC + c_hstep * 5); - _cc2 = vld1_u16(pC + c_hstep * 6); - _cc3 = vld1_u16(pC + c_hstep * 7); - transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); - _c0 = bfloat2float(_cc0); - _c1 = bfloat2float(_cc1); - _c2 = bfloat2float(_cc2); - _c3 = bfloat2float(_cc3); + _c01 = vld1q_u16(pC + c_hstep * 4); + _c23 = vld1q_u16(pC + c_hstep * 4 + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0); @@ -6982,16 +6976,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f6 = vmlaq_f32(_f6, _c2, _beta); _f7 = vmlaq_f32(_f7, _c3, _beta); } - pC += 4; + pC += 16; } - else // if (c_elempack == 4) + if (c_elempack == 1) { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); - float32x4_t _c3 = bfloat2float(vget_high_u16(_c23)); + uint16x4_t _cc0 = vld1_u16(pC); + uint16x4_t _cc1 = vld1_u16(pC + c_hstep); + uint16x4_t _cc2 = vld1_u16(pC + c_hstep * 2); + uint16x4_t _cc3 = vld1_u16(pC + c_hstep * 3); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + float32x4_t _c2 = bfloat2float(_cc2); + float32x4_t _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -7007,12 +7004,15 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f2 = vmlaq_f32(_f2, _c2, _beta); _f3 = vmlaq_f32(_f3, _c3, _beta); } - _c01 = vld1q_u16(pC + c_hstep * 4); - _c23 = 
vld1q_u16(pC + c_hstep * 4 + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); + _cc0 = vld1_u16(pC + c_hstep * 4); + _cc1 = vld1_u16(pC + c_hstep * 5); + _cc2 = vld1_u16(pC + c_hstep * 6); + _cc3 = vld1_u16(pC + c_hstep * 7); + transpose4x4_u16(_cc0, _cc1, _cc2, _cc3); + _c0 = bfloat2float(_cc0); + _c1 = bfloat2float(_cc1); + _c2 = bfloat2float(_cc2); + _c3 = bfloat2float(_cc3); if (beta == 1.f) { _f4 = vaddq_f32(_f4, _c0); @@ -7028,7 +7028,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f6 = vmlaq_f32(_f6, _c2, _beta); _f7 = vmlaq_f32(_f7, _c3, _beta); } - pC += 16; + pC += 4; } } if (broadcast_type_C == 4) @@ -7085,7 +7085,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _bf.val[3] = _bf3; vst4q_u16(p0, _bf); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1q_u16(p0, _bf0); vst1q_u16(p0 + out_hstep, _bf1); @@ -7160,6 +7160,16 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { float32x4_t _c2; float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 8; + } if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -7188,16 +7198,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c3 = bfloat2float(vget_high_u16(_c23)); pC += 2; } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + c_hstep * 4); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - pC += 8; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -7260,6 +7260,12 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { + if (c_elempack == 4) + { + _c0 = bfloat2float(vld1_u16(pC)); + _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); + pC += 4; + } if (c_elempack == 1) { uint16x8_t _c01 = uint16x8_t(); @@ -7275,12 +7281,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c1 = bfloat2float(vget_high_u16(_c01)); pC += 1; } - else // if (c_elempack == 4) - { - _c0 = bfloat2float(vld1_u16(pC)); - _c1 = bfloat2float(vld1_u16(pC + c_hstep * 4)); - pC += 4; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -7453,6 +7453,14 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x8_t _c23; uint16x8_t _c45; uint16x8_t _c67; + if (c_elempack == 4) + { + _c01 = vld1q_u16(pC); + _c23 = vld1q_u16(pC + 8); + _c45 = vld1q_u16(pC + 16); + _c67 = vld1q_u16(pC + 24); + pC += 32; + } if (c_elempack == 1) { _c01 = vld1q_u16(pC); @@ -7462,14 +7470,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma transpose8x4_u16(_c01, _c23, _c45, _c67); pC += 8; } - else // if (c_elempack == 4) - { - _c01 = vld1q_u16(pC); - _c23 = vld1q_u16(pC + 8); - _c45 = vld1q_u16(pC + 16); - _c67 = vld1q_u16(pC + 24); - pC += 32; - } _c0 = bfloat2float(vget_low_u16(_c01)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c01)); float32x4_t _c2 = bfloat2float(vget_low_u16(_c23)); @@ -7570,7 +7570,7 @@ static void 
transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma vst4_u16(p0, _bfa); vst4_u16(p0 + out_hstep * 4, _bfb); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1_u16(p0, _bf0); vst1_u16(p0 + out_hstep, _bf1); @@ -7653,6 +7653,16 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma float32x4_t _c1; float32x4_t _c2; float32x4_t _c3; + if (c_elempack == 4) + { + uint16x8_t _c01 = vld1q_u16(pC); + uint16x8_t _c23 = vld1q_u16(pC + 8); + _c0 = bfloat2float(vget_low_u16(_c01)); + _c1 = bfloat2float(vget_high_u16(_c01)); + _c2 = bfloat2float(vget_low_u16(_c23)); + _c3 = bfloat2float(vget_high_u16(_c23)); + pC += 16; + } if (c_elempack == 1) { uint16x4_t _cc0 = vld1_u16(pC); @@ -7666,16 +7676,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c3 = bfloat2float(_cc3); pC += 4; } - else // if (c_elempack == 4) - { - uint16x8_t _c01 = vld1q_u16(pC); - uint16x8_t _c23 = vld1q_u16(pC + 8); - _c0 = bfloat2float(vget_low_u16(_c01)); - _c1 = bfloat2float(vget_high_u16(_c01)); - _c2 = bfloat2float(vget_low_u16(_c23)); - _c3 = bfloat2float(vget_high_u16(_c23)); - pC += 16; - } if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); @@ -7738,7 +7738,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _bf.val[3] = _bf3; vst4_u16(p0, _bf); } - else // if (out_elempack == 1) + if (out_elempack == 1) { vst1_u16(p0, _bf0); vst1_u16(p0 + out_hstep, _bf1); @@ -7793,6 +7793,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 3) { uint16x8_t _c; + if (c_elempack == 4) + { + _c = vld1q_u16(pC); + pC += 8; + } if (c_elempack == 1) { _c = uint16x8_t(); @@ -7806,11 +7811,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c = vsetq_lane_u16(pC[c_hstep * 3 + 1], _c, 7); pC += 2; } - else // if (c_elempack == 4) - { - _c = vld1q_u16(pC); - pC += 8; - } _c0 = bfloat2float(vget_low_u16(_c)); float32x4_t _c1 = bfloat2float(vget_high_u16(_c)); if (beta == 1.f) @@ -7865,6 +7865,11 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma if (broadcast_type_C == 3) { uint16x4_t _c; + if (c_elempack == 4) + { + _c = vld1_u16(pC); + pC += 4; + } if (c_elempack == 1) { _c = uint16x4_t(); @@ -7874,11 +7879,6 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _c = vset_lane_u16(pC[c_hstep * 3], _c, 3); pC += 1; } - else // if (c_elempack == 4) - { - _c = vld1_u16(pC); - pC += 4; - } _c0 = bfloat2float(_c); _f0 = vmlaq_n_f32(_f0, _c0, beta); } @@ -8039,7 +8039,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma vst1q_u16(p0, vcombine_u16(_bf0, _bf2)); vst1q_u16(p0 + out_hstep * 4, vcombine_u16(_bf1, _bf3)); } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_bf0, 0); p0[1] = vget_lane_u16(_bf2, 0); @@ -8128,7 +8128,7 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); } - else // if (out_elempack == 1) + if (out_elempack == 1) { p0[0] = vget_lane_u16(_bf0, 0); p0[1] = vget_lane_u16(_bf1, 0); @@ -8357,31 +8357,34 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); } - else if (out_elempack == 4) - { - vst1_u16(p0, _bf0); - vst1_u16(p0 + out_hstep * 4, _bf1); - vst1_u16(p0 + out_hstep * 8, _bf2); - vst1_u16(p0 + 
out_hstep * 12, _bf3); - } - else if (out_elempack == 1) + else { - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + vst1_u16(p0 + out_hstep * 8, _bf2); + vst1_u16(p0 + out_hstep * 12, _bf3); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + } } pp += 16; @@ -8437,21 +8440,24 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma { vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); } - else if (out_elempack == 4) - { - vst1_u16(p0, _bf0); - vst1_u16(p0 + out_hstep * 4, _bf1); - } - else // if (out_elempack == 1) + else { - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + } } pp += 8; @@ -8479,16 +8485,23 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _bf0 = float2bfloat(_f0); - if (out_hstep == 1 || out_elempack == 4) + if (out_hstep == 1) { vst1_u16(p0, _bf0); } - else // if (out_elempack == 1) + else { - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + if (out_elempack == 4) + { + vst1_u16(p0, _bf0); + } + if (out_elempack == 1) + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 
3); + } } pp += 4; From 1e21bb96039864fc54da65079972048e01e1729e Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 12 Oct 2024 15:10:07 +0800 Subject: [PATCH 53/55] cc --- src/layer/arm/gemm_int8.h | 9398 +++++++++++++------------------------ 1 file changed, 3280 insertions(+), 6118 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 09c94d5226f2..020df8b9c84a 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -5675,423 +5675,388 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = 
vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 #else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - { - _sum8 = vrev64q_s32(_sum8); - _sum9 = vrev64q_s32(_sum9); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sumc = vrev64q_s32(_sumc); - _sumd = vrev64q_s32(_sumd); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - _sum8 = vextq_s32(_sum8, _sum8, 2); - _sum9 = vextq_s32(_sum9, _sum9, 2); - _suma = vextq_s32(_suma, _suma, 2); - _sumb = vextq_s32(_sumb, _sumb, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _suma = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - _sum9 = vrev64q_s32(_sum9); - _sumb = vrev64q_s32(_sumb); - _sumd = vrev64q_s32(_sumd); - _sumf = vrev64q_s32(_sumf); - } - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = 
vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + _sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 
= vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } #endif - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c1); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c1); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c1); - _ff = vaddq_f32(_ff, _c1); - } - if (broadcast_type_C == 3) + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) { - if (c_elempack == 1) + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 4 * 2); + float32x4_t _c3 = 
vld1q_f32(pC + 4 * 3); + float32x4_t _c4 = vld1q_f32(pC + 4 * 4); + float32x4_t _c5 = vld1q_f32(pC + 4 * 5); + float32x4_t _c6 = vld1q_f32(pC + 4 * 6); + float32x4_t _c7 = vld1q_f32(pC + 4 * 7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 4 * 2); + _c3 = vld1q_f32(pC + c_hstep * 4 + 4 * 3); + _c4 = vld1q_f32(pC + c_hstep * 4 + 4 * 4); + _c5 = vld1q_f32(pC + c_hstep * 4 + 4 * 5); + _c6 = vld1q_f32(pC + c_hstep * 4 + 4 * 6); + _c7 = vld1q_f32(pC + c_hstep * 4 + 4 * 7); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 4 * 2); - float32x4_t _c3 = 
vld1q_f32(pC + 4 * 3); - float32x4_t _c4 = vld1q_f32(pC + 4 * 4); - float32x4_t _c5 = vld1q_f32(pC + 4 * 5); - float32x4_t _c6 = vld1q_f32(pC + 4 * 6); - float32x4_t _c7 = vld1q_f32(pC + 4 * 7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 4 * 2); - _c3 = vld1q_f32(pC + c_hstep * 4 + 4 * 3); - _c4 = vld1q_f32(pC + c_hstep * 4 + 4 * 4); - _c5 = vld1q_f32(pC + c_hstep * 4 + 4 * 5); - _c6 = vld1q_f32(pC + c_hstep * 4 + 4 * 6); - _c7 = vld1q_f32(pC + c_hstep * 4 + 4 * 7); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } + pC += 32; } - if (broadcast_type_C == 4) + if (c_elempack == 1) { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = 
vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else { float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } - _c0 = vdupq_laneq_f32(_cc0, 0); - _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); pC += 8; } } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, 
_alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); @@ -6108,220 +6073,247 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& vst1q_f32(p0 + out_hstep * 4 + 20, _fd); vst1q_f32(p0 + out_hstep * 4 + 24, _fe); vst1q_f32(p0 + out_hstep * 4 + 28, _ff); - - pp += 64; p0 += 32; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + transpose4x4_ps(_f8, _f9, _fa, _fb); + transpose4x4_ps(_fc, _fd, _fe, _ff); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + vst1q_f32(p0 + out_hstep * 4, _f8); + vst1q_f32(p0 + out_hstep * 4 + 4, _fc); + vst1q_f32(p0 + out_hstep * 5, _f9); + vst1q_f32(p0 + out_hstep * 5 + 4, _fd); + vst1q_f32(p0 + out_hstep * 6, _fa); + vst1q_f32(p0 + out_hstep * 6 + 4, _fe); + vst1q_f32(p0 + out_hstep * 7, _fb); + vst1q_f32(p0 + out_hstep * 7 + 4, _ff); + p0 += 8; + } -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 + pp += 64; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 #else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = 
vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, 
_c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c1); - _f7 = vaddq_f32(_f7, _c1); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - _c2 = vld1q_f32(pC + c_hstep * 2); - _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - transpose4x4_ps(_c0, _c1, _c2, _c3); - pC += 4; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 8); - _c3 = vld1q_f32(pC + c_hstep * 4 + 12); - pC += 16; - } - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _c = vld1q_f32(pC); - _c = vmulq_n_f32(_c, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - float32x4_t _c3 = vdupq_laneq_f32(_c, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + } + if (c_elempack == 4) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = 
vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + pC += 16; + } + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; + } + if (beta == 1.f) + { _f4 = vaddq_f32(_f4, _c0); _f5 = vaddq_f32(_f5, _c1); _f6 = vaddq_f32(_f6, _c2); _f7 = vaddq_f32(_f7, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); + float32x4_t _c = vld1q_f32(pC); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); @@ -6330,1166 +6322,259 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& vst1q_f32(p0 + out_hstep * 4 + 4, _f5); vst1q_f32(p0 + out_hstep * 4 + 8, _f6); vst1q_f32(p0 + out_hstep * 4 + 12, _f7); - - pp += 32; p0 += 16; } - for (; jj + 1 < max_jj; jj += 2) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 4, _f4); + vst1q_f32(p0 + out_hstep * 5, _f5); + vst1q_f32(p0 + out_hstep * 6, _f6); + vst1q_f32(p0 + out_hstep * 7, _f7); + p0 += 4; + } + + pp += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 #else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 
- { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep * 4); + _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + pC += 8; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - float32x4_t _c01 = vcombine_f32(_cc0, _cc1); - float32x4_t _c23 = vcombine_f32(_cc2, _cc3); - float32x4x2_t _ccc0 = vuzpq_f32(_c01, _c23); - _c0 = _ccc0.val[0]; - _c1 = _ccc0.val[1]; - float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); - float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); - float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); - float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); - float32x4_t _c45 = vcombine_f32(_cc4, _cc5); - float32x4_t _c67 = vcombine_f32(_cc6, _cc7); - float32x4x2_t _ccc1 = vuzpq_f32(_c45, _c67); - _c2 = _ccc1.val[0]; - _c3 = _ccc1.val[1]; - pC += 2; - } - else // 
if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + c_hstep * 4); - _c3 = vld1q_f32(pC + c_hstep * 4 + 4); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _ccc0 = vuzpq_f32(_c01, _c23); + _c0 = _ccc0.val[0]; + _c1 = _ccc0.val[1]; + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _c45 = vcombine_f32(_cc4, _cc5); + float32x4_t _c67 = vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc1 = vuzpq_f32(_c45, _c67); + _c2 = _ccc1.val[0]; + _c3 = _ccc1.val[1]; + pC += 2; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x2_t _c = vld1_f32(pC); - _c = vmul_n_f32(_c, beta); - _c0 = vdupq_lane_f32(_c, 0); - _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 2; + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + out_hstep * 4, _f2); vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - - pp += 16; p0 += 8; } - for (; jj < max_jj; jj++) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); - _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); - _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); - _c1 = 
vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 4); - pC += 4; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); - - pp += 8; - p0 += 4; + if (out_elempack == 1) + { + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + float32x4x2_t _f23 = vzipq_f32(_f2, _f3); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f23.val[0])); + vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f23.val[0])); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f23.val[1])); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f23.val[1])); + p0 += 2; } + + pp += 16; } - if (out_elempack == 1) + for (; jj < max_jj; jj++) { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); - int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); - int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = 
vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); - _sume = vcombine_s32(vget_high_s32(_t4.val[1]), vget_high_s32(_t5.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); - } -#else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = 
vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); - float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); - float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); - float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); - float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); - float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); - float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); - float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); - float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - float32x4_t _cc4 = vdupq_laneq_f32(_c1, 0); - float32x4_t _cc5 = vdupq_laneq_f32(_c1, 1); - float32x4_t _cc6 = vdupq_laneq_f32(_c1, 2); - float32x4_t _cc7 = vdupq_laneq_f32(_c1, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - _f8 = vaddq_f32(_f8, _cc4); - _f9 = vaddq_f32(_f9, _cc4); - _fa = vaddq_f32(_fa, _cc5); - _fb = vaddq_f32(_fb, _cc5); - _fc = vaddq_f32(_fc, _cc6); - _fd = vaddq_f32(_fd, _cc6); - _fe = vaddq_f32(_fe, _cc7); - _ff = vaddq_f32(_ff, _cc7); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = 
vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; - } - else // if (c_elempack == 4) - { - float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - _cc1 = vld4q_f32(pC + c_hstep * 4 + 16); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _cc0.val[0]); - _f9 = vaddq_f32(_f9, _cc1.val[0]); - _fa = vaddq_f32(_fa, _cc0.val[1]); - _fb = vaddq_f32(_fb, _cc1.val[1]); - _fc = vaddq_f32(_fc, _cc0.val[2]); - _fd = vaddq_f32(_fd, _cc1.val[2]); - _fe = vaddq_f32(_fe, _cc0.val[3]); - _ff = vaddq_f32(_ff, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _cc0.val[0], _beta); - _f9 = vmlaq_f32(_f9, _cc1.val[0], _beta); - _fa = vmlaq_f32(_fa, _cc0.val[1], _beta); - _fb = vmlaq_f32(_fb, _cc1.val[1], _beta); - _fc = vmlaq_f32(_fc, _cc0.val[2], _beta); - _fd = vmlaq_f32(_fd, _cc1.val[2], _beta); - _fe = vmlaq_f32(_fe, _cc0.val[3], _beta); - _ff = vmlaq_f32(_ff, _cc1.val[3], _beta); - } - pC += 32; - } - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c1); - pC += 8; - } - } + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = 
vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + out_hstep, _f2); - vst1q_f32(p0 + out_hstep + 4, _f3); - vst1q_f32(p0 + out_hstep * 2, _f4); - vst1q_f32(p0 + out_hstep * 2 + 4, _f5); - vst1q_f32(p0 + out_hstep * 3, _f6); - vst1q_f32(p0 + out_hstep * 3 + 4, _f7); - vst1q_f32(p0 + out_hstep * 4, _f8); - vst1q_f32(p0 + out_hstep * 4 + 4, _f9); - vst1q_f32(p0 + out_hstep * 5, _fa); - vst1q_f32(p0 + out_hstep * 5 + 4, _fb); - vst1q_f32(p0 + out_hstep * 6, _fc); - vst1q_f32(p0 + out_hstep * 6 + 4, _fd); - vst1q_f32(p0 + out_hstep * 7, _fe); - vst1q_f32(p0 + out_hstep * 7 + 4, _ff); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - pp += 64; - p0 += 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (pC) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); 
- _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); - float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); - float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); - float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); - float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); -#endif - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); -#if __aarch64__ - _cc0 = vdupq_laneq_f32(_c1, 0); - _cc1 = vdupq_laneq_f32(_c1, 1); - _cc2 = vdupq_laneq_f32(_c1, 2); - _cc3 = vdupq_laneq_f32(_c1, 3); -#else - _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); - _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); - _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); - _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif - _f4 = vaddq_f32(_f4, _cc0); - _f5 = vaddq_f32(_f5, _cc1); - _f6 = vaddq_f32(_f6, _cc2); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 
= vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - float32x4x4_t _cc0 = vld4q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc0.val[2]); - _f3 = vaddq_f32(_f3, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta); - _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _cc0.val[0]); - _f5 = vaddq_f32(_f5, _cc0.val[1]); - _f6 = vaddq_f32(_f6, _cc0.val[2]); - _f7 = vaddq_f32(_f7, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta); - _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta); - _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta); - } - pC += 16; - } - } - if (broadcast_type_C == 4) + if (c_elempack == 4) { _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + _c1 = vld1q_f32(pC + c_hstep * 4); pC += 4; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep * 2, _f2); - vst1q_f32(p0 + out_hstep * 3, _f3); - vst1q_f32(p0 + out_hstep * 4, _f4); - vst1q_f32(p0 + out_hstep * 5, _f5); - vst1q_f32(p0 + out_hstep * 6, _f6); - vst1q_f32(p0 + out_hstep * 7, _f7); - - pp += 32; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - // e0 e1 f0 f1 - // g0 g1 h0 h1 - { - int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _sum13 = vzipq_s32(_sum2, _sum3); - _sum0 = _sum02.val[0]; - _sum1 = _sum02.val[1]; - _sum2 = _sum13.val[0]; - _sum3 = _sum13.val[1]; 
- } -#else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 - // e0 e1 f0 f1 - // g0 g1 h0 h1 - { - int32x4x2_t _t0 = vuzpq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vuzpq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_t0.val[0], _t1.val[0]); - int32x4x2_t _t3 = vzipq_s32(_t1.val[1], _t0.val[1]); - _sum0 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4x2_t _descale01 = vzipq_f32(_descale0, _descale0); - float32x4x2_t _descale23 = vzipq_f32(_descale1, _descale1); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale23.val[0]); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale23.val[1]); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + if (c_elempack == 1) { - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc1.val[0]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + pC += 1; } - if (broadcast_type_C == 3) + if (beta == 1.f) { - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - _c0 = vcombine_f32(_cc0, _cc1); - _c1 = vcombine_f32(_cc2, _cc3); - float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); - float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); - float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); - float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); - float32x4_t _c2 = vcombine_f32(_cc4, _cc5); - float32x4_t _c3 = vcombine_f32(_cc6, _cc7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - float32x4x2_t _c23 = vzipq_f32(_c2, _c3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); - _f2 = vaddq_f32(_f2, _c23.val[0]); - _f3 = vaddq_f32(_f3, _c23.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c01.val[0], _beta); - _f1 = vmlaq_f32(_f1, 
_c01.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c23.val[0], _beta); - _f3 = vmlaq_f32(_f3, _c23.val[1], _beta); - } - pC += 8; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - float32x2_t _cc0 = vld1_f32(pC); - _cc0 = vmul_n_f32(_cc0, beta); - _c0 = vcombine_f32(_cc0, _cc0); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 2; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; } + } - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f2)); - vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f2)); - vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f3)); - vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f3)); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - pp += 16; - p0 += 2; + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + p0 += 4; } - for (; jj < max_jj; jj++) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); - _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); - _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); - _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 4); - pC += 4; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); @@ -7498,10 +6583,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); - - pp += 8; p0++; } + + pp += 8; } } for (; ii + 3 < max_ii; ii += 4) @@ -7533,213 
+6618,213 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 #else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); - - if (pC) - { - if (broadcast_type_C == 0) + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // 
a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + _c4 = vld1q_f32(pC + 16); + _c5 = vld1q_f32(pC + 20); + _c6 = vld1q_f32(pC + 24); + _c7 = vld1q_f32(pC + 28); + pC += 32; } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + if (c_elempack == 1) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep); + _c3 = vld1q_f32(pC + c_hstep + 4); + _c4 = vld1q_f32(pC + c_hstep * 2); + _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + _c6 = vld1q_f32(pC + c_hstep * 3); + _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + pC += 8; + } + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - 
_f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } - if (broadcast_type_C == 3) + else { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - float32x4_t _c4; - float32x4_t _c5; - float32x4_t _c6; - float32x4_t _c7; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + c_hstep); - _c3 = vld1q_f32(pC + c_hstep + 4); - _c4 = vld1q_f32(pC + c_hstep * 2); - _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - _c6 = vld1q_f32(pC + c_hstep * 3); - _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - pC += 8; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - _c4 = vld1q_f32(pC + 16); - _c5 = vld1q_f32(pC + 20); - _c6 = vld1q_f32(pC + 24); - _c7 = vld1q_f32(pC + 28); - pC += 32; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - if (broadcast_type_C == 4) + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); - } - _c0 = vdupq_laneq_f32(_cc0, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = 
vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); @@ -7748,899 +6833,648 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& vst1q_f32(p0 + 20, _f5); vst1q_f32(p0 + 24, _f6); vst1q_f32(p0 + 28, _f7); - - pp += 32; p0 += 32; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + transpose4x4_ps(_f0, _f1, _f2, _f3); + transpose4x4_ps(_f4, _f5, _f6, _f7); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + p0 += 8; + } + + pp += 32; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 #else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + { + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = 
vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + pC += 16; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 1); - _c2 = vld1q_f32(pC + c_hstep * 2); - _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - pC += 4; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - pC += 16; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 1); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _c = vld1q_f32(pC); - _c = vmulq_n_f32(_c, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_c, 0); - float32x4_t _c1 = vdupq_laneq_f32(_c, 1); - float32x4_t _c2 = vdupq_laneq_f32(_c, 2); - float32x4_t _c3 = vdupq_laneq_f32(_c, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if 
(broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x4_t _c = vld1q_f32(pC); + _c = vmulq_n_f32(_c, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_c, 0); + float32x4_t _c1 = vdupq_laneq_f32(_c, 1); + float32x4_t _c2 = vdupq_laneq_f32(_c, 2); + float32x4_t _c3 = vdupq_laneq_f32(_c, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_c), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_c), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_c), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_c), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; } + } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); vst1q_f32(p0 + 8, _f2); vst1q_f32(p0 + 12, _f3); - - pp += 16; p0 += 16; } - for (; jj + 1 < max_jj; jj += 2) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + transpose4x4_ps(_f0, _f1, _f2, _f3); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 3, _f3); + p0 += 4; + } + + pp += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 #else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - { - _sum1 = vrev64q_s32(_sum1); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - } + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + pC += 8; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t 
_cc3 = vld1_f32(pC + c_hstep * 3); - float32x4_t _c01 = vcombine_f32(_cc0, _cc1); - float32x4_t _c23 = vcombine_f32(_cc2, _cc3); - float32x4x2_t _cc = vuzpq_f32(_c01, _c23); - _c0 = _cc.val[0]; - _c1 = _cc.val[1]; - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _c01 = vcombine_f32(_cc0, _cc1); + float32x4_t _c23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_c01, _c23); + _c0 = _cc.val[0]; + _c1 = _cc.val[1]; + pC += 2; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x2_t _c = vld1_f32(pC); - _c = vmul_n_f32(_c, beta); - _c0 = vdupq_lane_f32(_c, 0); - float32x4_t _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 2; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + float32x4_t _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; } + } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); - - pp += 8; p0 += 8; } - for (; jj < max_jj; jj++) + if (out_elempack == 1) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - pC += 4; - } - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; - } - } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1q_f32(p0, _f0); - - pp += 4; - p0 += 4; + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + p0 += 2; } + + pp += 8; } - if (out_elempack == 1) + for (; jj < max_jj; jj++) { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + float32x4_t _f0 = 
vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); } -#else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, 
_c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) + if (c_elempack == 4) { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - pC += 8; - } - else // if (c_elempack == 4) - { - float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - pC += 32; - } + _c0 = vld1q_f32(pC); + pC += 4; } - if (broadcast_type_C == 4) + if (c_elempack == 1) { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - pC += 8; + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + pC += 1; } + _f0 = vmlaq_n_f32(_f0, _c0, beta); } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); 
- _f7 = vmulq_f32(_f7, _alpha); + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; } + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + out_hstep, _f2); - vst1q_f32(p0 + out_hstep + 4, _f3); - vst1q_f32(p0 + out_hstep * 2, _f4); - vst1q_f32(p0 + out_hstep * 2 + 4, _f5); - vst1q_f32(p0 + out_hstep * 3, _f6); - vst1q_f32(p0 + out_hstep * 3 + 4, _f7); + _f0 = vmulq_n_f32(_f0, alpha); - pp += 32; - p0 += 8; + if (out_elempack == 4) + { + vst1q_f32(p0, _f0); + p0 += 4; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0++; + } -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - { - int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); - } -#else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - { - _sum1 = vextq_s32(_sum1, _sum1, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } -#endif // __ARM_FEATURE_DOTPROD + pp += 4; + } + } +#endif // __ARM_NEON + for (; ii + 1 < max_ii; ii += 2) + { + // out_elempack == 1 + float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); + const float descale0 = descales[ii]; + const float descale1 = descales[ii + 1]; +#if __ARM_NEON + float32x2_t _descale = vld1_f32((const float*)descales + ii); +#endif + + float c0; + float c1; +#if __ARM_NEON + float32x4_t _c0; + float32x4_t _c1; +#endif + if (pC) + { + if (broadcast_type_C == 0) + { + c0 = pC[0] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); +#endif + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; 
+ c0 = pC[0] * beta; + c1 = pC[1] * beta; +#if __ARM_NEON + _c0 = vdupq_n_f32(c0); + _c1 = vdupq_n_f32(c1); #endif + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + pC = (const float*)C + (i + ii) * c_hstep + j; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } + + int jj = 0; +#if __ARM_NEON +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - if (pC) + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3) + else { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + c_hstep * 1); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - float32x4x4_t _c = vld4q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c.val[0]); - _f1 = vaddq_f32(_f1, _c.val[1]); - _f2 = vaddq_f32(_f2, _c.val[2]); - _f3 = vaddq_f32(_f3, _c.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c.val[2], _beta); - _f3 = vmlaq_f32(_f3, _c.val[3], _beta); - } - pC += 16; - } + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - if (broadcast_type_C == 4) + pC 
+= 8; + } + if (broadcast_type_C == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + if (beta != 1.f) { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep * 2, _f2); - vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep, _f2); + vst1q_f32(p0 + out_hstep + 4, _f3); - pp += 16; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + pp += 16; + p0 += 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); - _sum0 = _sum01.val[0]; - _sum1 = _sum01.val[1]; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 - - // to - // a0 a1 b0 b1 - // c0 c1 d0 d1 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - int32x4_t _t0 = vuzpq_s32(_sum0, _sum1).val[0]; - int32x4_t _t1 = vuzpq_s32(_sum1, _sum0).val[1]; - int32x4x2_t _t3 = vuzpq_s32(_t0, _t1); - _sum0 = _t3.val[0]; - _sum1 = _t3.val[1]; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4x2_t _descale01 = vzipq_f32(_descale, _descale); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale01.val[0]); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale01.val[1]); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - _c0 = vcombine_f32(_cc0, _cc1); - float32x4_t _c1 = vcombine_f32(_cc2, _cc3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 2; - } - else // if 
(c_elempack == 4) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c01.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c01.val[1], _beta); - } - pC += 8; - } + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - float32x2_t _cc0 = vld1_f32(pC); - _cc0 = vmul_n_f32(_cc0, beta); - _c0 = vcombine_f32(_cc0, _cc0); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 2; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 4; } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; } - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - - pp += 8; - p0 += 2; } - for (; jj < max_jj; jj++) + + if (alpha != 1.f) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); - if (pC) + float32x2x2_t _descale01 = vzip_f32(_descale, _descale); + float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); + + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - pC += 4; - } - _f0 = vmlaq_n_f32(_f0, _c0, beta); - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; - } + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); + _f0 = vaddq_f32(_f0, _c0011); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vcombine_f32(vld1_f32(pC), vld1_f32(pC + c_hstep)); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + float32x2_t _c = vld1_f32(pC); + _c0 = vcombine_f32(_c, _c); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - pp += 4; - p0++; + pp += 4; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; + 
+ if (pC) + { + if (broadcast_type_C == 0) + { + f0 += c0; + f1 += c0; + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + f0 += c0; + f1 += c1; + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; + } + if (broadcast_type_C == 4) + { + f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; + } } + + f0 *= alpha; + f1 *= alpha; + + p0[0] = f0; + p0[out_hstep] = f1; + + pp += 2; + p0++; } } -#endif // __ARM_NEON - for (; ii + 1 < max_ii; ii += 2) + for (; ii < max_ii; ii += 1) { // out_elempack == 1 float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; - const float descale0 = descales[ii]; - const float descale1 = descales[ii + 1]; + const float descale = descales[ii]; #if __ARM_NEON - float32x2_t _descale = vld1_f32((const float*)descales + ii); + float32x4_t _descale = vdupq_n_f32(descale); #endif float c0; - float c1; #if __ARM_NEON float32x4_t _c0; - float32x4_t _c1; #endif if (pC) { @@ -8655,10 +7489,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& { pC = (const float*)C + i + ii; c0 = pC[0] * beta; - c1 = pC[1] * beta; #if __ARM_NEON _c0 = vdupq_n_f32(c0); - _c1 = vdupq_n_f32(c1); #endif } if (broadcast_type_C == 3) @@ -8672,484 +7504,197 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat& } } - // if (out_elempack == 1) - { - int jj = 0; + int jj = 0; #if __ARM_NEON -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 15 < max_jj; jj += 16) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale, 1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = 
vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } + pC += 16; } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + out_hstep, _f2); - vst1q_f32(p0 + out_hstep + 4, _f3); - - pp += 16; - p0 += 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + + if (alpha != 1.f) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale, 1); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 4; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } + pp += 16; + p0 += 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - pp += 8; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) + if (pC) { - int32x4_t _sum0 = vld1q_s32(pp); - - float32x2x2_t _descale01 = vzip_f32(_descale, _descale); - float32x4_t _descale0011 = vcombine_f32(_descale01.val[0], _descale01.val[1]); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0011); - - if (pC) + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _c0011 = vcombine_f32(vget_low_f32(_c0), vget_high_f32(_c1)); - _f0 = 
vaddq_f32(_f0, _c0011); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vcombine_f32(vld1_f32(pC), vld1_f32(pC + c_hstep)); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } - if (broadcast_type_C == 4) - { - float32x2_t _c = vld1_f32(pC); - _c0 = vcombine_f32(_c, _c); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - - pp += 4; - p0 += 2; - } -#endif // __ARM_NEON - for (; jj < max_jj; jj++) - { - float f0 = pp[0] * descale0; - float f1 = pp[1] * descale1; - - if (pC) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0) - { - f0 += c0; - f1 += c0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + // out_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + if (beta == 1.f) { - f0 += c0; - f1 += c1; - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - f0 += pC[0] * beta; - f1 += pC[c_hstep] * beta; - pC += 1; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - f0 += pC[0] * beta; - f1 += pC[0] * beta; - pC += 1; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 8; } - - f0 *= alpha; - f1 *= alpha; - - p0[0] = f0; - p0[out_hstep] = f1; - - pp += 2; - p0++; } - } - } - for (; ii < max_ii; ii += 1) - { - // out_elempack == 1 - float* p0 = (float*)top_blob + (i + ii) * out_hstep + j; - - const float descale = descales[ii]; -#if __ARM_NEON - float32x4_t _descale = vdupq_n_f32(descale); -#endif - float c0; -#if __ARM_NEON - float32x4_t _c0; -#endif - if (pC) - { - if (broadcast_type_C == 0) - { - c0 = pC[0] * beta; -#if __ARM_NEON - _c0 = vdupq_n_f32(c0); -#endif - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - pC = (const float*)C + i + ii; - c0 = pC[0] * beta; -#if __ARM_NEON - _c0 = vdupq_n_f32(c0); -#endif - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - pC = (const float*)C + (i + ii) * c_hstep + j; - } - if (broadcast_type_C == 4) + if (alpha != 1.f) { - pC = (const float*)C + j; + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); } - } - // if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON - for (; jj + 15 < max_jj; jj += 16) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + pp += 8; + p0 += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = 
vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } + _f0 = vaddq_f32(_f0, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + // out_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; } + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); + _f0 = vmulq_n_f32(_f0, alpha); - pp += 16; - p0 += 16; - } - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + vst1q_f32(p0, _f0); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + pp += 4; + p0 += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; - } + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); } - - if (alpha != 1.f) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + // out_elempack == 1 + float32x2_t _c = vld1_f32(pC); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - - pp += 8; - p0 += 8; } - for (; jj + 3 < max_jj; jj += 4) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; - } - } + _f0 = vmul_n_f32(_f0, alpha); - _f0 = vmulq_n_f32(_f0, alpha); + vst1_f32(p0, _f0); - vst1q_f32(p0, _f0); + pp += 2; + p0 += 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj++) + { + float f0 = pp[0] * descale; - pp += 4; - p0 += 4; - } - for (; jj + 1 < max_jj; jj += 2) + if (pC) { - float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - - if (pC) + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vadd_f32(_f0, vget_low_f32(_c0)); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - float32x2_t _c = vld1_f32(pC); - _f0 = vmla_n_f32(_f0, _c, beta); - pC += 2; - } + f0 += c0; } - - _f0 = 
vmul_n_f32(_f0, alpha); - - vst1_f32(p0, _f0); - - pp += 2; - p0 += 2; - } -#endif // __ARM_NEON - for (; jj < max_jj; jj++) - { - float f0 = pp[0] * descale; - - if (pC) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - f0 += pC[0] * beta; - pC += 1; - } + // out_elempack == 1 + f0 += pC[0] * beta; + pC += 1; } + } - f0 *= alpha; + f0 *= alpha; - p0[0] = f0; + p0[0] = f0; - pp += 1; - p0++; - } + pp += 1; + p0++; } } } @@ -9210,1126 +7755,415 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } - if (out_elempack == 4) - { - int jj = 0; + int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum8, _sum9); - int32x4x2_t _t3 = vzipq_s32(_suma, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t5 = vzipq_s32(_sum6, _sum7); - int32x4x2_t _t6 = vzipq_s32(_sumc, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sume, _sumf); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t4.val[1]), vget_low_s32(_t5.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t6.val[1]), vget_low_s32(_t7.val[1])); - _sume = vcombine_s32(vget_high_s32(_t4.val[1]), 
vget_high_s32(_t5.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t6.val[1]), vget_high_s32(_t7.val[1])); - } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 - // e0 e1 e2 e3 - // e4 e5 e6 e7 - // f0 f1 f2 f3 - // f4 f5 f6 f7 - // g0 g1 g2 g3 - // g4 g5 g6 g7 - // h0 h1 h2 h3 - // h4 h5 h6 h7 - { - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t3 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t4 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t5 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum9 = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _suma = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sumb = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sumc = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sumd = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sume = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale0, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale0, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale0, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale0, 3); - float32x4_t _f8 = vmulq_laneq_f32(vcvtq_f32_s32(_sum8), _descale1, 0); - 
float32x4_t _f9 = vmulq_laneq_f32(vcvtq_f32_s32(_sum9), _descale1, 0); - float32x4_t _fa = vmulq_laneq_f32(vcvtq_f32_s32(_suma), _descale1, 1); - float32x4_t _fb = vmulq_laneq_f32(vcvtq_f32_s32(_sumb), _descale1, 1); - float32x4_t _fc = vmulq_laneq_f32(vcvtq_f32_s32(_sumc), _descale1, 2); - float32x4_t _fd = vmulq_laneq_f32(vcvtq_f32_s32(_sumd), _descale1, 2); - float32x4_t _fe = vmulq_laneq_f32(vcvtq_f32_s32(_sume), _descale1, 3); - float32x4_t _ff = vmulq_laneq_f32(vcvtq_f32_s32(_sumf), _descale1, 3); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - _cc0 = vdupq_laneq_f32(_c1, 0); - _cc1 = vdupq_laneq_f32(_c1, 1); - _cc2 = vdupq_laneq_f32(_c1, 2); - _cc3 = vdupq_laneq_f32(_c1, 3); - _f8 = vaddq_f32(_f8, _cc0); - _f9 = vaddq_f32(_f9, _cc0); - _fa = vaddq_f32(_fa, _cc1); - _fb = vaddq_f32(_fb, _cc1); - _fc = vaddq_f32(_fc, _cc2); - _fd = vaddq_f32(_fd, _cc2); - _fe = vaddq_f32(_fe, _cc3); - _ff = vaddq_f32(_ff, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, 
_beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; - } - else // if (c_elempack == 4) - { - float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - _cc1 = vld4q_f32(pC + c_hstep * 4 + 16); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _cc0.val[0]); - _f9 = vaddq_f32(_f9, _cc1.val[0]); - _fa = vaddq_f32(_fa, _cc0.val[1]); - _fb = vaddq_f32(_fb, _cc1.val[1]); - _fc = vaddq_f32(_fc, _cc0.val[2]); - _fd = vaddq_f32(_fd, _cc1.val[2]); - _fe = vaddq_f32(_fe, _cc0.val[3]); - _ff = vaddq_f32(_ff, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _cc0.val[0], _beta); - _f9 = vmlaq_f32(_f9, _cc1.val[0], _beta); - _fa = vmlaq_f32(_fa, _cc0.val[1], _beta); - _fb = vmlaq_f32(_fb, _cc1.val[1], _beta); - _fc = vmlaq_f32(_fc, _cc0.val[2], _beta); - _fd = vmlaq_f32(_fd, _cc1.val[2], _beta); - _fe = vmlaq_f32(_fe, _cc0.val[3], _beta); - _ff = vmlaq_f32(_ff, _cc1.val[3], _beta); - } - pC += 32; - } - } - if (broadcast_type_C == 4) + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + int32x4_t _sum8 = vld1q_s32(pp + 32); + int32x4_t _sum9 = vld1q_s32(pp + 36); + int32x4_t _suma = vld1q_s32(pp + 40); + int32x4_t _sumb = vld1q_s32(pp + 44); + int32x4_t _sumc = vld1q_s32(pp + 48); + int32x4_t _sumd = vld1q_s32(pp + 52); + int32x4_t _sume = vld1q_s32(pp + 56); + int32x4_t _sumf = vld1q_s32(pp + 60); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 +#else + // from + // a0 b1 c2 d3 + // e4 f5 g6 h7 + // e0 f1 g2 h3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // g4 h5 e6 f7 + // g0 h1 e2 f3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // e7 f6 g5 h4 + // e3 f2 g1 h0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // g7 h6 e5 f4 + // g3 h2 e1 f0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + // e4 f4 g4 h4 + // e5 f5 g5 h5 + // e6 f6 g6 h6 + // e7 f7 g7 h7 + { + 
_sum8 = vrev64q_s32(_sum8); + _sum9 = vrev64q_s32(_sum9); + _suma = vrev64q_s32(_suma); + _sumb = vrev64q_s32(_sumb); + _sumc = vrev64q_s32(_sumc); + _sumd = vrev64q_s32(_sumd); + _sume = vrev64q_s32(_sume); + _sumf = vrev64q_s32(_sumf); + _sum8 = vextq_s32(_sum8, _sum8, 2); + _sum9 = vextq_s32(_sum9, _sum9, 2); + _suma = vextq_s32(_suma, _suma, 2); + _sumb = vextq_s32(_sumb, _sumb, 2); + _sumc = vextq_s32(_sumc, _sumc, 2); + _sumd = vextq_s32(_sumd, _sumd, 2); + _sume = vextq_s32(_sume, _sume, 2); + _sumf = vextq_s32(_sumf, _sumf, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); + int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); + int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); + int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); + int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); + int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); + int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); + int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum8 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); + _sum9 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); + _suma = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); + _sumb = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); + _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); + _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); + _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); + _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + _sum9 = vrev64q_s32(_sum9); + _sumb = vrev64q_s32(_sumb); + _sumd = vrev64q_s32(_sumd); + _sumf = vrev64q_s32(_sumf); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); + float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); + float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); + float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); + float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = 
vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c0); + _fa = vaddq_f32(_fa, _c0); + _fb = vaddq_f32(_fb, _c0); + _fc = vaddq_f32(_fc, _c0); + _fd = vaddq_f32(_fd, _c0); + _fe = vaddq_f32(_fe, _c0); + _ff = vaddq_f32(_ff, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + _f8 = vaddq_f32(_f8, _c1); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c1); + _fb = vaddq_f32(_fb, _c1); + _fc = vaddq_f32(_fc, _c1); + _fd = vaddq_f32(_fd, _c1); + _fe = vaddq_f32(_fe, _c1); + _ff = vaddq_f32(_ff, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) { _c0 = vld1q_f32(pC); _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + float32x4_t _c4 = vld1q_f32(pC + 16); + float32x4_t _c5 = vld1q_f32(pC + 20); + float32x4_t _c6 = vld1q_f32(pC + 24); + float32x4_t _c7 = vld1q_f32(pC + 28); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + } + else { float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c1); - pC += 8; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f2); - vst1q_f32(p0 + 8, _f4); - vst1q_f32(p0 + 12, _f6); - vst1q_f32(p0 + 16, _f8); - vst1q_f32(p0 + 20, _fa); - vst1q_f32(p0 + 24, _fc); - vst1q_f32(p0 + 28, _fe); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - vst1q_f32(p0 + out_hstep * 4 + 8, _f5); - vst1q_f32(p0 + out_hstep * 4 + 12, _f7); - vst1q_f32(p0 + out_hstep * 4 + 16, _f9); - vst1q_f32(p0 + out_hstep * 4 + 20, _fb); - vst1q_f32(p0 + out_hstep * 4 + 24, _fd); - vst1q_f32(p0 + out_hstep * 4 + 
28, _ff); - pp += 64; - p0 += out_hstep * 8; - } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 - { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); - } -#else // __ARM_FEATURE_DOTPROD - - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 - // e0 e1 e2 e3 - // f0 f1 f2 f3 - // g0 g1 g2 g3 - // h0 h1 h2 h3 - { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale0, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale0, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale0, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale0, 3); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale1, 0); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale1, 1); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale1, 2); - float32x4_t _f7 = 
vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale1, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale0), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale0), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale0), 0); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale0), 1); - float32x4_t _f4 = vmulq_lane_f32(vcvtq_f32_s32(_sum4), vget_low_f32(_descale1), 0); - float32x4_t _f5 = vmulq_lane_f32(vcvtq_f32_s32(_sum5), vget_low_f32(_descale1), 1); - float32x4_t _f6 = vmulq_lane_f32(vcvtq_f32_s32(_sum6), vget_high_f32(_descale1), 0); - float32x4_t _f7 = vmulq_lane_f32(vcvtq_f32_s32(_sum7), vget_high_f32(_descale1), 1); -#endif - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); -#if __aarch64__ - _cc0 = vdupq_laneq_f32(_c1, 0); - _cc1 = vdupq_laneq_f32(_c1, 1); - _cc2 = vdupq_laneq_f32(_c1, 2); - _cc3 = vdupq_laneq_f32(_c1, 3); -#else - _cc0 = vdupq_lane_f32(vget_low_f32(_c1), 0); - _cc1 = vdupq_lane_f32(vget_low_f32(_c1), 1); - _cc2 = vdupq_lane_f32(vget_high_f32(_c1), 0); - _cc3 = vdupq_lane_f32(vget_high_f32(_c1), 1); -#endif - _f4 = vaddq_f32(_f4, _cc0); - _f5 = vaddq_f32(_f5, _cc1); - _f6 = vaddq_f32(_f6, _cc2); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + _c4 = vld1q_f32(pC + c_hstep * 4 + 16); + _c5 = vld1q_f32(pC + c_hstep * 4 + 20); + _c6 = vld1q_f32(pC + c_hstep * 4 + 24); + _c7 = vld1q_f32(pC + c_hstep * 4 + 28); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - 
pC += 4; + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); } - else // if (c_elempack == 4) + else { - float32x4x4_t _cc0 = vld4q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 = vaddq_f32(_f2, _cc0.val[2]); - _f3 = vaddq_f32(_f3, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc0.val[1], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[2], _beta); - _f3 = vmlaq_f32(_f3, _cc0.val[3], _beta); - } - _cc0 = vld4q_f32(pC + c_hstep * 4); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _cc0.val[0]); - _f5 = vaddq_f32(_f5, _cc0.val[1]); - _f6 = vaddq_f32(_f6, _cc0.val[2]); - _f7 = vaddq_f32(_f7, _cc0.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _cc0.val[0], _beta); - _f5 = vmlaq_f32(_f5, _cc0.val[1], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[2], _beta); - _f7 = vmlaq_f32(_f7, _cc0.val[3], _beta); - } - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } + pC += 32; } - if (broadcast_type_C == 4) + if (c_elempack == 1) { _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - pC += 4; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); - vst1q_f32(p0 + 16, _f4); - vst1q_f32(p0 + 20, _f5); - vst1q_f32(p0 + 24, _f6); - vst1q_f32(p0 + 28, _f7); - pp += 32; - p0 += out_hstep * 4; - } - } - if (out_elempack == 1) - { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); - int32x4_t _sum8 = vld1q_s32(pp + 32); - int32x4_t _sum9 = vld1q_s32(pp + 36); - int32x4_t _suma = vld1q_s32(pp + 40); - int32x4_t _sumb = vld1q_s32(pp + 44); - int32x4_t _sumc = vld1q_s32(pp + 48); - int32x4_t _sumd = vld1q_s32(pp + 52); - int32x4_t _sume = vld1q_s32(pp + 56); - int32x4_t _sumf = vld1q_s32(pp + 60); - -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 
d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_suma), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sumb), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); -#else - // from - // a0 b1 c2 d3 - // e4 f5 g6 h7 - // e0 f1 g2 h3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // g4 h5 e6 f7 - // g0 h1 e2 f3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // e7 f6 g5 h4 - // e3 f2 g1 h0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // g7 h6 e5 f4 - // g3 h2 e1 f0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - // e4 f4 g4 h4 - // e5 f5 g5 h5 - // e6 f6 g6 h6 - // e7 f7 g7 h7 - { - _sum8 = vrev64q_s32(_sum8); - _sum9 = vrev64q_s32(_sum9); - _suma = vrev64q_s32(_suma); - _sumb = vrev64q_s32(_sumb); - _sumc = vrev64q_s32(_sumc); - _sumd = vrev64q_s32(_sumd); - _sume = vrev64q_s32(_sume); - _sumf = vrev64q_s32(_sumf); - _sum8 = vextq_s32(_sum8, _sum8, 2); - _sum9 = vextq_s32(_sum9, _sum9, 2); - _suma = vextq_s32(_suma, _suma, 2); - _sumb = vextq_s32(_sumb, _sumb, 2); - _sumc = vextq_s32(_sumc, _sumc, 2); - _sumd = vextq_s32(_sumd, _sumd, 2); - _sume = vextq_s32(_sume, _sume, 2); - _sumf = vextq_s32(_sumf, _sumf, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sumc); - int32x4x2_t _t1 = vzipq_s32(_sum4, _sum8); - int32x4x2_t _t2 = vzipq_s32(_sum2, _sume); - int32x4x2_t _t3 = vzipq_s32(_sum6, _suma); - int32x4x2_t _t4 = vzipq_s32(_sum3, _sumf); - int32x4x2_t _t5 = vzipq_s32(_sum7, _sumb); - int32x4x2_t _t6 = vzipq_s32(_sum1, _sumd); - int32x4x2_t _t7 = vzipq_s32(_sum5, _sum9); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t4.val[0]), vget_low_s32(_t5.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t4.val[0]), vget_high_s32(_t5.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t5.val[1]), vget_low_s32(_t4.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t5.val[1]), vget_high_s32(_t4.val[1])); - _sum8 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum9 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _suma = 
vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sumb = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sumc = vcombine_s32(vget_low_s32(_t6.val[0]), vget_low_s32(_t7.val[0])); - _sumd = vcombine_s32(vget_high_s32(_t6.val[0]), vget_high_s32(_t7.val[0])); - _sume = vcombine_s32(vget_low_s32(_t7.val[1]), vget_low_s32(_t6.val[1])); - _sumf = vcombine_s32(vget_high_s32(_t7.val[1]), vget_high_s32(_t6.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - _sum9 = vrev64q_s32(_sum9); - _sumb = vrev64q_s32(_sumb); - _sumd = vrev64q_s32(_sumd); - _sumf = vrev64q_s32(_sumf); - } - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale0); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale0); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale0); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale0); - float32x4_t _f8 = vmulq_f32(vcvtq_f32_s32(_sum8), _descale1); - float32x4_t _f9 = vmulq_f32(vcvtq_f32_s32(_sum9), _descale1); - float32x4_t _fa = vmulq_f32(vcvtq_f32_s32(_suma), _descale1); - float32x4_t _fb = vmulq_f32(vcvtq_f32_s32(_sumb), _descale1); - float32x4_t _fc = vmulq_f32(vcvtq_f32_s32(_sumc), _descale1); - float32x4_t _fd = vmulq_f32(vcvtq_f32_s32(_sumd), _descale1); - float32x4_t _fe = vmulq_f32(vcvtq_f32_s32(_sume), _descale1); - float32x4_t _ff = vmulq_f32(vcvtq_f32_s32(_sumf), _descale1); -#endif // __ARM_FEATURE_DOTPROD - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c0); - _fa = vaddq_f32(_fa, _c0); - _fb = vaddq_f32(_fb, _c0); - _fc = vaddq_f32(_fc, _c0); - _fd = vaddq_f32(_fd, _c0); - _fe = vaddq_f32(_fe, _c0); - _ff = vaddq_f32(_ff, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - _f8 = vaddq_f32(_f8, _c1); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c1); - _fb = vaddq_f32(_fb, _c1); - _fc = vaddq_f32(_fc, _c1); - _fd = vaddq_f32(_fd, _c1); - _fe = vaddq_f32(_fe, _c1); - _ff = vaddq_f32(_ff, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); + float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); 
- float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 5); - _c3 = vld1q_f32(pC + c_hstep * 5 + 4); - _c4 = vld1q_f32(pC + c_hstep * 6); - _c5 = vld1q_f32(pC + c_hstep * 6 + 4); - _c6 = vld1q_f32(pC + c_hstep * 7); - _c7 = vld1q_f32(pC + c_hstep * 7 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 8; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - float32x4_t _c4 = vld1q_f32(pC + 16); - float32x4_t _c5 = vld1q_f32(pC + 20); - float32x4_t _c6 = vld1q_f32(pC + 24); - float32x4_t _c7 = vld1q_f32(pC + 28); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 8); - _c3 = vld1q_f32(pC + c_hstep * 4 + 12); - _c4 = vld1q_f32(pC + c_hstep * 4 + 16); - _c5 = vld1q_f32(pC + c_hstep * 4 + 20); - _c6 = vld1q_f32(pC + c_hstep * 4 + 24); - _c7 = vld1q_f32(pC + c_hstep * 4 + 28); - if (beta == 1.f) - { - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f8 = vmlaq_f32(_f8, _c0, _beta); - _f9 = 
vmlaq_f32(_f9, _c1, _beta); - _fa = vmlaq_f32(_fa, _c2, _beta); - _fb = vmlaq_f32(_fb, _c3, _beta); - _fc = vmlaq_f32(_fc, _c4, _beta); - _fd = vmlaq_f32(_fd, _c5, _beta); - _fe = vmlaq_f32(_fe, _c6, _beta); - _ff = vmlaq_f32(_ff, _c7, _beta); - } - pC += 32; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + _f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); } - } - if (broadcast_type_C == 4) - { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 5); + _c3 = vld1q_f32(pC + c_hstep * 5 + 4); + _c4 = vld1q_f32(pC + c_hstep * 6); + _c5 = vld1q_f32(pC + c_hstep * 6 + 4); + _c6 = vld1q_f32(pC + c_hstep * 7); + _c7 = vld1q_f32(pC + c_hstep * 7 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + if (beta == 1.f) + { + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + } + else { float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); + _f8 = vmlaq_f32(_f8, _c0, _beta); + _f9 = vmlaq_f32(_f9, _c1, _beta); + _fa = vmlaq_f32(_fa, _c2, _beta); + _fb = vmlaq_f32(_fb, _c3, _beta); + _fc = vmlaq_f32(_fc, _c4, _beta); + _fd = vmlaq_f32(_fd, _c5, _beta); + _fe = vmlaq_f32(_fe, _c6, _beta); + _ff = vmlaq_f32(_ff, _c7, _beta); } - _c0 = vdupq_laneq_f32(_cc0, 0); - _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - _f8 = vaddq_f32(_f8, _c0); - _f9 = vaddq_f32(_f9, _c1); - _fa = vaddq_f32(_fa, _c2); - _fb = vaddq_f32(_fb, _c3); - _fc = vaddq_f32(_fc, _c4); - _fd = vaddq_f32(_fd, _c5); - _fe = vaddq_f32(_fe, _c6); - _ff = vaddq_f32(_ff, _c7); pC += 8; } } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - _f8 = vmulq_f32(_f8, _alpha); - _f9 = vmulq_f32(_f9, _alpha); - _fa = vmulq_f32(_fa, _alpha); - _fb = vmulq_f32(_fb, _alpha); - _fc = vmulq_f32(_fc, _alpha); - _fd = vmulq_f32(_fd, _alpha); - _fe = vmulq_f32(_fe, _alpha); - _ff = vmulq_f32(_ff, _alpha); - } - + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = 
vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + _f8 = vaddq_f32(_f8, _c0); + _f9 = vaddq_f32(_f9, _c1); + _fa = vaddq_f32(_fa, _c2); + _fb = vaddq_f32(_fb, _c3); + _fc = vaddq_f32(_fc, _c4); + _fd = vaddq_f32(_fd, _c5); + _fe = vaddq_f32(_fe, _c6); + _ff = vaddq_f32(_ff, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + _f8 = vmulq_f32(_f8, _alpha); + _f9 = vmulq_f32(_f9, _alpha); + _fa = vmulq_f32(_fa, _alpha); + _fb = vmulq_f32(_fb, _alpha); + _fc = vmulq_f32(_fc, _alpha); + _fd = vmulq_f32(_fd, _alpha); + _fe = vmulq_f32(_fe, _alpha); + _ff = vmulq_f32(_ff, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _ffa; + float32x4x4_t _ffb; + float32x4x4_t _ffc; + float32x4x4_t _ffd; + _ffa.val[0] = _f0; + _ffa.val[1] = _f1; + _ffa.val[2] = _f2; + _ffa.val[3] = _f3; + _ffb.val[0] = _f4; + _ffb.val[1] = _f5; + _ffb.val[2] = _f6; + _ffb.val[3] = _f7; + _ffc.val[0] = _f8; + _ffc.val[1] = _f9; + _ffc.val[2] = _fa; + _ffc.val[3] = _fb; + _ffd.val[0] = _fc; + _ffd.val[1] = _fd; + _ffd.val[2] = _fe; + _ffd.val[3] = _ff; + vst4q_f32(p0, _ffa); + vst4q_f32(p0 + 16, _ffc); + vst4q_f32(p0 + out_hstep * 4, _ffb); + vst4q_f32(p0 + out_hstep * 4 + 16, _ffd); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f8); vst1q_f32(p0 + out_hstep, _f1); @@ -10346,1103 +8180,664 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma vst1q_f32(p0 + out_hstep * 6 + 4, _fe); vst1q_f32(p0 + out_hstep * 7, _f7); vst1q_f32(p0 + out_hstep * 7 + 4, _ff); - - pp += 64; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 + pp += 64; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 #else - // from - // a0 b1 c2 d3 - // e0 f1 g2 h3 - // c0 d1 a2 b3 - // g0 h1 e2 f3 - // a3 b2 c1 d0 - // e3 f2 g1 h0 - // c3 d2 a1 b0 - // g3 h2 e1 f0 - - // to - // a0 
b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - // e2 f2 g2 h2 - // e3 f3 g3 h3 - - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c1); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c1); - _f7 = vaddq_f32(_f7, _c1); - } - if (broadcast_type_C == 3) + // from + // a0 b1 c2 d3 + // e0 f1 g2 h3 + // c0 d1 a2 b3 + // g0 h1 e2 f3 + // a3 b2 c1 d0 + // e3 f2 g1 h0 + // c3 d2 a1 b0 + // g3 h2 e1 f0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // e0 f0 g0 h0 + // e1 f1 g1 h1 + // e2 f2 g2 h2 + // e3 f3 g3 h3 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = 
vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale0); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale0); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale1); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale1); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale1); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c1); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c1); + _f7 = vaddq_f32(_f7, _c1); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 4) { - if (c_elempack == 1) + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep * 5); - _c2 = vld1q_f32(pC + c_hstep * 6); - _c3 = vld1q_f32(pC + c_hstep * 7); - transpose4x4_ps(_c0, _c1, _c2, _c3); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 4; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - _c0 = vld1q_f32(pC + c_hstep * 4); - _c1 = vld1q_f32(pC + c_hstep 
* 4 + 4); - _c2 = vld1q_f32(pC + c_hstep * 4 + 8); - _c3 = vld1q_f32(pC + c_hstep * 4 + 12); - if (beta == 1.f) - { - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f4 = vmlaq_f32(_f4, _c0, _beta); - _f5 = vmlaq_f32(_f5, _c1, _beta); - _f6 = vmlaq_f32(_f6, _c2, _beta); - _f7 = vmlaq_f32(_f7, _c3, _beta); - } - pC += 16; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - float32x4_t _cc = vld1q_f32(pC); - _cc = vmulq_n_f32(_cc, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_cc, 0); - _c1 = vdupq_laneq_f32(_cc, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); - _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); - float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); -#endif - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c2); - _f7 = vaddq_f32(_f7, _c3); - pC += 4; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f4); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep + 4, _f5); - vst1q_f32(p0 + out_hstep * 2, _f2); - vst1q_f32(p0 + out_hstep * 2 + 4, _f6); - vst1q_f32(p0 + out_hstep * 3, _f3); - vst1q_f32(p0 + out_hstep * 3 + 4, _f7); - - pp += 32; - p0 += out_hstep * 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 -#else - // from - // a0 b1 c0 d1 - // e0 f1 g0 h1 - // a1 b0 c1 d0 - // e1 f0 g1 h0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // e0 f0 g0 h0 - // e1 f1 g1 h1 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 
= vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 4 + 4); + _c2 = vld1q_f32(pC + c_hstep * 4 + 8); + _c3 = vld1q_f32(pC + c_hstep * 4 + 12); + if (beta == 1.f) { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); - float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); - float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); - float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); - float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); - float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); - float32x4_t _cc45 = vcombine_f32(_cc4, _cc5); - float32x4_t _cc67 = vcombine_f32(_cc6, _cc7); - float32x4x2_t _ccc0 = vuzpq_f32(_cc01, _cc23); - float32x4x2_t _ccc1 = vuzpq_f32(_cc45, _cc67); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _ccc0.val[0]); - _f1 = vaddq_f32(_f1, _ccc0.val[1]); - _f2 = vaddq_f32(_f2, _ccc1.val[0]); - _f3 = vaddq_f32(_f3, _ccc1.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _ccc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _ccc0.val[1], _beta); - _f2 = vmlaq_f32(_f2, _ccc1.val[0], _beta); - _f3 = vmlaq_f32(_f3, _ccc1.val[1], _beta); - } - pC += 2; + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); } - else // if (c_elempack == 4) + else { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); } + pC += 16; } - if (broadcast_type_C == 4) - { - float32x2_t _cc = vld1_f32(pC); - _cc = vmul_n_f32(_cc, beta); - _c0 = vdupq_lane_f32(_cc, 0); - _c1 = vdupq_lane_f32(_cc, 1); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 2; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f2); - vst1q_f32(p0 + out_hstep, _f1); - vst1q_f32(p0 + out_hstep + 4, _f3); - - pp += 16; - p0 += out_hstep * 2; - } - for (; jj < max_jj; jj += 1) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = 
vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); - _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); - _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); - _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep * 4); - pC += 4; - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } else { float32x4_t _beta = vdupq_n_f32(beta); _f0 = vmlaq_f32(_f0, _c0, _beta); _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - } - if (broadcast_type_C == 4) - { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 1; + _c0 = vld1q_f32(pC + c_hstep * 4); + _c1 = vld1q_f32(pC + c_hstep * 5); + _c2 = vld1q_f32(pC + c_hstep * 6); + _c3 = vld1q_f32(pC + c_hstep * 7); + transpose4x4_ps(_c0, _c1, _c2, _c3); + if (beta == 1.f) + { + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f4 = vmlaq_f32(_f4, _c0, _beta); + _f5 = vmlaq_f32(_f5, _c1, _beta); + _f6 = vmlaq_f32(_f6, _c2, _beta); + _f7 = vmlaq_f32(_f7, _c3, _beta); + } + pC += 4; } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - + float32x4_t _cc = vld1q_f32(pC); + _cc = vmulq_n_f32(_cc, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_cc, 0); + _c1 = vdupq_laneq_f32(_cc, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); + _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c1); + _f6 = vaddq_f32(_f6, _c2); + _f7 = vaddq_f32(_f7, _c3); + pC += 4; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _fa; + float32x4x4_t _fb; + _fa.val[0] = _f0; + _fa.val[1] = _f1; + _fa.val[2] = _f2; + _fa.val[3] = _f3; + _fb.val[0] = _f4; + _fb.val[1] = _f5; + _fb.val[2] = _f6; + _fb.val[3] = _f7; + vst4q_f32(p0, _fa); + vst4q_f32(p0 + 16, _fb); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - pp += 8; - p0 += out_hstep; + vst1q_f32(p0 + 4, _f4); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f5); + vst1q_f32(p0 + out_hstep * 2, _f2); + vst1q_f32(p0 + out_hstep * 2 + 4, _f6); + vst1q_f32(p0 + out_hstep * 3, _f3); + vst1q_f32(p0 + out_hstep * 3 + 4, _f7); } + + pp += 32; + p0 += out_hstep * 4; } - } - for (; ii + 3 < max_ii; ii += 4) - { 
- float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _descale = vld1q_f32((const float*)descales + ii); +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 +#else + // from + // a0 b1 c0 d1 + // e0 f1 g0 h1 + // a1 b0 c1 d0 + // e1 f0 g1 h0 - float32x4_t _c0; - if (pC) - { - if (broadcast_type_C == 0) + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // e0 f0 g0 h0 + // e1 f1 g1 h1 { - _c0 = vdupq_n_f32(pC[0] * beta); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - pC = (const float*)C + i + ii; - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum2); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[0]), vget_low_s32(_t1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[0]), vget_high_s32(_t1.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); } - if (broadcast_type_C == 3) - { - pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; - } - if (broadcast_type_C == 4) - { - pC = (const float*)C + j; - } - } +#endif // __ARM_FEATURE_DOTPROD - if (out_elempack == 4) - { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale0); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale1); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale1); -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum3); - int32x4x2_t _t2 = vzipq_s32(_sum4, _sum5); - int32x4x2_t _t3 = vzipq_s32(_sum6, _sum7); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t0.val[1]), vget_low_s32(_t1.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t2.val[1]), vget_low_s32(_t3.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t0.val[1]), vget_high_s32(_t1.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t2.val[1]), vget_high_s32(_t3.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } -#else // __ARM_FEATURE_DOTPROD - - // 
from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - // c0 c1 c2 c3 - // c4 c5 c6 c7 - // d0 d1 d2 d3 - // d4 d5 d6 d7 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum2 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum3 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum4 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum5 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum6 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 0); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 1); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 1); - float32x4_t _f4 = vmulq_laneq_f32(vcvtq_f32_s32(_sum4), _descale, 2); - float32x4_t _f5 = vmulq_laneq_f32(vcvtq_f32_s32(_sum5), _descale, 2); - float32x4_t _f6 = vmulq_laneq_f32(vcvtq_f32_s32(_sum6), _descale, 3); - float32x4_t _f7 = vmulq_laneq_f32(vcvtq_f32_s32(_sum7), _descale, 3); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + if (c_elempack == 1) { - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc0); - _f2 = vaddq_f32(_f2, _cc1); - _f3 = vaddq_f32(_f3, _cc1); - _f4 = vaddq_f32(_f4, _cc2); - _f5 = vaddq_f32(_f5, _cc2); - _f6 = vaddq_f32(_f6, _cc3); - _f7 = vaddq_f32(_f7, _cc3); - } - if (broadcast_type_C == 3) - { - if (c_elempack == 1) + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x2_t _cc4 = vld1_f32(pC + c_hstep * 4); + float32x2_t _cc5 = vld1_f32(pC + c_hstep * 5); + float32x2_t _cc6 = vld1_f32(pC + c_hstep * 6); + float32x2_t _cc7 = vld1_f32(pC + c_hstep * 7); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4_t _cc45 = vcombine_f32(_cc4, _cc5); + float32x4_t _cc67 = 
vcombine_f32(_cc6, _cc7); + float32x4x2_t _ccc0 = vuzpq_f32(_cc01, _cc23); + float32x4x2_t _ccc1 = vuzpq_f32(_cc45, _cc67); + if (beta == 1.f) + { + _f0 = vaddq_f32(_f0, _ccc0.val[0]); + _f1 = vaddq_f32(_f1, _ccc0.val[1]); + _f2 = vaddq_f32(_f2, _ccc1.val[0]); + _f3 = vaddq_f32(_f3, _ccc1.val[1]); + } + else { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4_t _c4 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - float32x4_t _c6 = vld1q_f32(pC + c_hstep * 3); - float32x4_t _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _ccc0.val[0], _beta); + _f1 = vmlaq_f32(_f1, _ccc0.val[1], _beta); + _f2 = vmlaq_f32(_f2, _ccc1.val[0], _beta); + _f3 = vmlaq_f32(_f3, _ccc1.val[1], _beta); } - else // if (c_elempack == 4) + pC += 2; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep * 4); + float32x4_t _c3 = vld1q_f32(pC + c_hstep * 4 + 4); + if (beta == 1.f) { - float32x4x4_t _cc0 = vld4q_f32(pC); - float32x4x4_t _cc1 = vld4q_f32(pC + 16); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc1.val[0]); - _f2 = vaddq_f32(_f2, _cc0.val[1]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - _f4 = vaddq_f32(_f4, _cc0.val[2]); - _f5 = vaddq_f32(_f5, _cc1.val[2]); - _f6 = vaddq_f32(_f6, _cc0.val[3]); - _f7 = vaddq_f32(_f7, _cc1.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _cc0.val[0], _beta); - _f1 = vmlaq_f32(_f1, _cc1.val[0], _beta); - _f2 = vmlaq_f32(_f2, _cc0.val[1], _beta); - _f3 = vmlaq_f32(_f3, _cc1.val[1], _beta); - _f4 = vmlaq_f32(_f4, _cc0.val[2], _beta); - _f5 = vmlaq_f32(_f5, _cc1.val[2], _beta); - _f6 = vmlaq_f32(_f6, _cc0.val[3], _beta); - _f7 = vmlaq_f32(_f7, _cc1.val[3], _beta); - } - pC += 32; + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) + else { float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c1); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c1); pC += 8; } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 
= vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); + float32x2_t _cc = vld1_f32(pC); + _cc = vmul_n_f32(_cc, beta); + _c0 = vdupq_lane_f32(_cc, 0); + _c1 = vdupq_lane_f32(_cc, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 2; } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f2); - vst1q_f32(p0 + 8, _f4); - vst1q_f32(p0 + 12, _f6); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - vst1q_f32(p0 + out_hstep * 4 + 8, _f5); - vst1q_f32(p0 + out_hstep * 4 + 12, _f7); - - pp += 32; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + + if (alpha != 1.f) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } -#if __ARM_FEATURE_DOTPROD - // from - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f2); + vst1q_f32(p0 + out_hstep, _f1); + vst1q_f32(p0 + out_hstep + 4, _f3); + + pp += 16; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale0); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp + 4)), _descale1); + + if (pC) + { + if (broadcast_type_C == 0) { - int32x4x2_t _r01 = vzipq_s32(_sum0, _sum1); - int32x4x2_t _r23 = vzipq_s32(_sum2, _sum3); - _sum0 = vcombine_s32(vget_low_s32(_r01.val[0]), vget_low_s32(_r23.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_r01.val[0]), vget_high_s32(_r23.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_r01.val[1]), vget_low_s32(_r23.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_r01.val[1]), vget_high_s32(_r23.val[1])); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); } -#else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 a1 a2 a3 - // b0 b1 b2 b3 - // c0 c1 c2 c3 - // d0 d1 d2 d3 + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - _sum1 = vextq_s32(_sum1, _sum1, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } -#endif // __ARM_FEATURE_DOTPROD - -#if __aarch64__ - float32x4_t _f0 = vmulq_laneq_f32(vcvtq_f32_s32(_sum0), _descale, 0); - float32x4_t _f1 = vmulq_laneq_f32(vcvtq_f32_s32(_sum1), _descale, 1); - float32x4_t _f2 = vmulq_laneq_f32(vcvtq_f32_s32(_sum2), _descale, 2); - float32x4_t _f3 = vmulq_laneq_f32(vcvtq_f32_s32(_sum3), _descale, 3); -#else - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), vget_low_f32(_descale), 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), vget_low_f32(_descale), 1); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), vget_high_f32(_descale), 0); 
- float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), vget_high_f32(_descale), 1); -#endif - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) + if (c_elempack == 1) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + _c1 = vsetq_lane_f32(pC[c_hstep * 4], _c1, 0); + _c1 = vsetq_lane_f32(pC[c_hstep * 5], _c1, 1); + _c1 = vsetq_lane_f32(pC[c_hstep * 6], _c1, 2); + _c1 = vsetq_lane_f32(pC[c_hstep * 7], _c1, 3); + pC += 1; } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + else // if (c_elempack == 4) { -#if __aarch64__ - float32x4_t _cc0 = vdupq_laneq_f32(_c0, 0); - float32x4_t _cc1 = vdupq_laneq_f32(_c0, 1); - float32x4_t _cc2 = vdupq_laneq_f32(_c0, 2); - float32x4_t _cc3 = vdupq_laneq_f32(_c0, 3); -#else - float32x4_t _cc0 = vdupq_lane_f32(vget_low_f32(_c0), 0); - float32x4_t _cc1 = vdupq_lane_f32(vget_low_f32(_c0), 1); - float32x4_t _cc2 = vdupq_lane_f32(vget_high_f32(_c0), 0); - float32x4_t _cc3 = vdupq_lane_f32(vget_high_f32(_c0), 1); -#endif - _f0 = vaddq_f32(_f0, _cc0); - _f1 = vaddq_f32(_f1, _cc1); - _f2 = vaddq_f32(_f2, _cc2); - _f3 = vaddq_f32(_f3, _cc3); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep * 4); + pC += 4; } - if (broadcast_type_C == 3) + if (beta == 1.f) { - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + c_hstep); - float32x4_t _c2 = vld1q_f32(pC + c_hstep * 2); - float32x4_t _c3 = vld1q_f32(pC + c_hstep * 3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 4; - } - else // if (c_elempack == 4) - { - float32x4x4_t _c = vld4q_f32(pC); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c.val[0]); - _f1 = vaddq_f32(_f1, _c.val[1]); - _f2 = vaddq_f32(_f2, _c.val[2]); - _f3 = vaddq_f32(_f3, _c.val[3]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c.val[2], _beta); - _f3 = vmlaq_f32(_f3, _c.val[3], _beta); - } - pC += 16; - } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 4) + else { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 1; } + } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); - - pp += 16; - p0 += out_hstep * 4; + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 
4, _f1); + pp += 8; + p0 += out_hstep; } - if (out_elempack == 1) + } + for (; ii + 3 < max_ii; ii += 4) + { + float* p0 = (float*)top_blob + j * out_hstep + (i + ii) * out_elempack; + + float32x4_t _descale = vld1q_f32((const float*)descales + ii); + + float32x4_t _c0; + if (pC) { - int jj = 0; -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) + if (broadcast_type_C == 0) + { + _c0 = vdupq_n_f32(pC[0] * beta); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + pC = (const float*)C + i + ii; + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + } + if (broadcast_type_C == 3) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - int32x4_t _sum4 = vld1q_s32(pp + 16); - int32x4_t _sum5 = vld1q_s32(pp + 20); - int32x4_t _sum6 = vld1q_s32(pp + 24); - int32x4_t _sum7 = vld1q_s32(pp + 28); + pC = (const float*)C + (i + ii) * c_hstep + j * c_elempack; + } + if (broadcast_type_C == 4) + { + pC = (const float*)C + j; + } + } -#if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 + int jj = 0; +#if __aarch64__ + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); + int32x4_t _sum4 = vld1q_s32(pp + 16); + int32x4_t _sum5 = vld1q_s32(pp + 20); + int32x4_t _sum6 = vld1q_s32(pp + 24); + int32x4_t _sum7 = vld1q_s32(pp + 28); + +#if __ARM_FEATURE_DOTPROD + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 #else - // from - // a0 b1 c2 d3 - // a4 b5 c6 d7 - // c0 d1 a2 b3 - // c4 d5 a6 b7 - // a3 b2 c1 d0 - // a7 b6 c5 d4 - // c3 d2 a1 b0 - // c7 d6 a5 b4 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - // a4 b4 c4 d4 - // a5 b5 c5 d5 - // a6 b6 c6 d6 - // a7 b7 c7 d7 - { - _sum4 = vrev64q_s32(_sum4); - _sum5 = vrev64q_s32(_sum5); - _sum6 = vrev64q_s32(_sum6); - _sum7 = vrev64q_s32(_sum7); - _sum4 = vextq_s32(_sum4, _sum4, 2); - _sum5 = vextq_s32(_sum5, _sum5, 2); - _sum6 = vextq_s32(_sum6, _sum6, 2); - _sum7 = vextq_s32(_sum7, _sum7, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); - int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); - int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); - int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); - _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); - _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); - _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - _sum5 = vrev64q_s32(_sum5); - _sum7 = vrev64q_s32(_sum7); - } -#endif // __ARM_FEATURE_DOTPROD - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = 
vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); - float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); - float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); - float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); - - if (pC) - { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + // from + // a0 b1 c2 d3 + // a4 b5 c6 d7 + // c0 d1 a2 b3 + // c4 d5 a6 b7 + // a3 b2 c1 d0 + // a7 b6 c5 d4 + // c3 d2 a1 b0 + // c7 d6 a5 b4 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + // a4 b4 c4 d4 + // a5 b5 c5 d5 + // a6 b6 c6 d6 + // a7 b7 c7 d7 + { + _sum4 = vrev64q_s32(_sum4); + _sum5 = vrev64q_s32(_sum5); + _sum6 = vrev64q_s32(_sum6); + _sum7 = vrev64q_s32(_sum7); + _sum4 = vextq_s32(_sum4, _sum4, 2); + _sum5 = vextq_s32(_sum5, _sum5, 2); + _sum6 = vextq_s32(_sum6, _sum6, 2); + _sum7 = vextq_s32(_sum7, _sum7, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum6); + int32x4x2_t _t1 = vzipq_s32(_sum2, _sum4); + int32x4x2_t _t2 = vzipq_s32(_sum1, _sum7); + int32x4x2_t _t3 = vzipq_s32(_sum3, _sum5); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum4 = vcombine_s32(vget_low_s32(_t2.val[0]), vget_low_s32(_t3.val[0])); + _sum5 = vcombine_s32(vget_high_s32(_t2.val[0]), vget_high_s32(_t3.val[0])); + _sum6 = vcombine_s32(vget_low_s32(_t3.val[1]), vget_low_s32(_t2.val[1])); + _sum7 = vcombine_s32(vget_high_s32(_t3.val[1]), vget_high_s32(_t2.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + _sum5 = vrev64q_s32(_sum5); + _sum7 = vrev64q_s32(_sum7); + } +#endif // __ARM_FEATURE_DOTPROD + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f4 = vmulq_f32(vcvtq_f32_s32(_sum4), _descale); + float32x4_t _f5 = vmulq_f32(vcvtq_f32_s32(_sum5), _descale); + float32x4_t _f6 = vmulq_f32(vcvtq_f32_s32(_sum6), _descale); + float32x4_t _f7 = vmulq_f32(vcvtq_f32_s32(_sum7), _descale); + + if (pC) + { + if (broadcast_type_C == 0) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + _f4 = vaddq_f32(_f4, _c0); + _f5 = vaddq_f32(_f5, _c0); + _f6 = vaddq_f32(_f6, _c0); + _f7 = vaddq_f32(_f7, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + float32x4_t _c4; + float32x4_t _c5; + float32x4_t _c6; + float32x4_t _c7; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - _f4 = 
vaddq_f32(_f4, _c0); - _f5 = vaddq_f32(_f5, _c0); - _f6 = vaddq_f32(_f6, _c0); - _f7 = vaddq_f32(_f7, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + _c4 = vld1q_f32(pC + 16); + _c5 = vld1q_f32(pC + 20); + _c6 = vld1q_f32(pC + 24); + _c7 = vld1q_f32(pC + 28); + pC += 32; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - float32x4_t _c4; - float32x4_t _c5; - float32x4_t _c6; - float32x4_t _c7; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + c_hstep); - _c3 = vld1q_f32(pC + c_hstep + 4); - _c4 = vld1q_f32(pC + c_hstep * 2); - _c5 = vld1q_f32(pC + c_hstep * 2 + 4); - _c6 = vld1q_f32(pC + c_hstep * 3); - _c7 = vld1q_f32(pC + c_hstep * 3 + 4); - transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); - pC += 8; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - _c4 = vld1q_f32(pC + 16); - _c5 = vld1q_f32(pC + 20); - _c6 = vld1q_f32(pC + 24); - _c7 = vld1q_f32(pC + 28); - pC += 32; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - _f4 = vaddq_f32(_f4, _c4); - _f5 = vaddq_f32(_f5, _c5); - _f6 = vaddq_f32(_f6, _c6); - _f7 = vaddq_f32(_f7, _c7); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - _f4 = vmlaq_f32(_f4, _c4, _beta); - _f5 = vmlaq_f32(_f5, _c5, _beta); - _f6 = vmlaq_f32(_f6, _c6, _beta); - _f7 = vmlaq_f32(_f7, _c7, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + c_hstep); + _c3 = vld1q_f32(pC + c_hstep + 4); + _c4 = vld1q_f32(pC + c_hstep * 2); + _c5 = vld1q_f32(pC + c_hstep * 2 + 4); + _c6 = vld1q_f32(pC + c_hstep * 3); + _c7 = vld1q_f32(pC + c_hstep * 3 + 4); + transpose8x4_ps(_c0, _c1, _c2, _c3, _c4, _c5, _c6, _c7); + pC += 8; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _cc0 = vld1q_f32(pC); - float32x4_t _cc1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _cc0 = vmulq_f32(_cc0, _beta); - _cc1 = vmulq_f32(_cc1, _beta); - } - _c0 = vdupq_laneq_f32(_cc0, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); - float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); - float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); - float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); - float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); @@ -11451,23 +8846,80 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f5 = vaddq_f32(_f5, _c5); _f6 = vaddq_f32(_f6, _c6); _f7 = vaddq_f32(_f7, _c7); - pC += 8; } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - _f4 = vmulq_f32(_f4, _alpha); - _f5 = vmulq_f32(_f5, _alpha); - _f6 = vmulq_f32(_f6, _alpha); - _f7 = vmulq_f32(_f7, _alpha); - } - + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); + 
_f4 = vmlaq_f32(_f4, _c4, _beta); + _f5 = vmlaq_f32(_f5, _c5, _beta); + _f6 = vmlaq_f32(_f6, _c6, _beta); + _f7 = vmlaq_f32(_f7, _c7, _beta); + } + } + if (broadcast_type_C == 4) + { + float32x4_t _cc0 = vld1q_f32(pC); + float32x4_t _cc1 = vld1q_f32(pC + 4); + if (beta != 1.f) + { + float32x4_t _beta = vdupq_n_f32(beta); + _cc0 = vmulq_f32(_cc0, _beta); + _cc1 = vmulq_f32(_cc1, _beta); + } + _c0 = vdupq_laneq_f32(_cc0, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc0, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc0, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc0, 3); + float32x4_t _c4 = vdupq_laneq_f32(_cc1, 0); + float32x4_t _c5 = vdupq_laneq_f32(_cc1, 1); + float32x4_t _c6 = vdupq_laneq_f32(_cc1, 2); + float32x4_t _c7 = vdupq_laneq_f32(_cc1, 3); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + _f4 = vaddq_f32(_f4, _c4); + _f5 = vaddq_f32(_f5, _c5); + _f6 = vaddq_f32(_f6, _c6); + _f7 = vaddq_f32(_f7, _c7); + pC += 8; + } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + _f4 = vmulq_f32(_f4, _alpha); + _f5 = vmulq_f32(_f5, _alpha); + _f6 = vmulq_f32(_f6, _alpha); + _f7 = vmulq_f32(_f7, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _fa; + float32x4x4_t _fb; + _fa.val[0] = _f0; + _fa.val[1] = _f1; + _fa.val[2] = _f2; + _fa.val[3] = _f3; + _fb.val[0] = _f4; + _fb.val[1] = _f5; + _fb.val[2] = _f6; + _fb.val[3] = _f7; + vst4q_f32(p0, _fa); + vst4q_f32(p0 + out_hstep * 4, _fb); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + out_hstep, _f1); vst1q_f32(p0 + out_hstep * 2, _f2); @@ -11476,296 +8928,308 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma vst1q_f32(p0 + out_hstep * 5, _f5); vst1q_f32(p0 + out_hstep * 6, _f6); vst1q_f32(p0 + out_hstep * 7, _f7); - - pp += 32; - p0 += out_hstep * 8; } + + pp += 32; + p0 += out_hstep * 8; + } #endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 3 < max_jj; jj += 4) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 #else - // from - // a0 b1 c2 d3 - // c0 d1 a2 b3 - // a3 b2 c1 d0 - // c3 d2 a1 b0 - - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - // a2 b2 c2 d2 - // a3 b3 c3 d3 - { - _sum2 = vrev64q_s32(_sum2); - _sum3 = vrev64q_s32(_sum3); - _sum2 = vextq_s32(_sum2, _sum2, 2); - _sum3 = vextq_s32(_sum3, _sum3, 2); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); - int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); - _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); - _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - _sum3 = vrev64q_s32(_sum3); - } + // from + // a0 b1 c2 d3 + // c0 d1 a2 b3 + // a3 b2 c1 d0 + // c3 d2 a1 b0 + + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + // a2 b2 c2 d2 + // a3 b3 c3 d3 + 
{ + _sum2 = vrev64q_s32(_sum2); + _sum3 = vrev64q_s32(_sum3); + _sum2 = vextq_s32(_sum2, _sum2, 2); + _sum3 = vextq_s32(_sum3, _sum3, 2); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum3); + int32x4x2_t _t1 = vzipq_s32(_sum1, _sum2); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_t1.val[1]), vget_low_s32(_t0.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_t1.val[1]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + _sum3 = vrev64q_s32(_sum3); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + float32x4_t _c2; + float32x4_t _c3; + if (c_elempack == 4) { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + _c2 = vld1q_f32(pC + 8); + _c3 = vld1q_f32(pC + 12); + pC += 16; } - if (broadcast_type_C == 3) + if (c_elempack == 1) { - float32x4_t _c1; - float32x4_t _c2; - float32x4_t _c3; - if (c_elempack == 1) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - _c2 = vld1q_f32(pC + c_hstep * 2); - _c3 = vld1q_f32(pC + c_hstep * 3); - transpose4x4_ps(_c0, _c1, _c2, _c3); - pC += 4; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - _c2 = vld1q_f32(pC + 8); - _c3 = vld1q_f32(pC + 12); - pC += 16; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + _c2 = vld1q_f32(pC + c_hstep * 2); + _c3 = vld1q_f32(pC + c_hstep * 3); + transpose4x4_ps(_c0, _c1, _c2, _c3); + pC += 4; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x4_t _cc = vld1q_f32(pC); - _cc = vmulq_n_f32(_cc, beta); -#if __aarch64__ - _c0 = vdupq_laneq_f32(_cc, 0); - float32x4_t _c1 = vdupq_laneq_f32(_cc, 1); - float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); - float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); -#else - _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); - float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); - float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); - float32x4_t _c3 = 
vdupq_lane_f32(vget_high_f32(_cc), 1); -#endif _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); _f2 = vaddq_f32(_f2, _c2); _f3 = vaddq_f32(_f3, _c3); - pC += 4; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x4_t _cc = vld1q_f32(pC); + _cc = vmulq_n_f32(_cc, beta); +#if __aarch64__ + _c0 = vdupq_laneq_f32(_cc, 0); + float32x4_t _c1 = vdupq_laneq_f32(_cc, 1); + float32x4_t _c2 = vdupq_laneq_f32(_cc, 2); + float32x4_t _c3 = vdupq_laneq_f32(_cc, 3); +#else + _c0 = vdupq_lane_f32(vget_low_f32(_cc), 0); + float32x4_t _c1 = vdupq_lane_f32(vget_low_f32(_cc), 1); + float32x4_t _c2 = vdupq_lane_f32(vget_high_f32(_cc), 0); + float32x4_t _c3 = vdupq_lane_f32(vget_high_f32(_cc), 1); +#endif + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); + pC += 4; } + } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + + if (out_elempack == 4) + { + float32x4x4_t _f; + _f.val[0] = _f0; + _f.val[1] = _f1; + _f.val[2] = _f2; + _f.val[3] = _f3; + vst4q_f32(p0, _f); + } + if (out_elempack == 1) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + out_hstep, _f1); vst1q_f32(p0 + out_hstep * 2, _f2); vst1q_f32(p0 + out_hstep * 3, _f3); - - pp += 16; - p0 += out_hstep * 4; } - for (; jj + 1 < max_jj; jj += 2) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + + pp += 16; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); #if __ARM_FEATURE_DOTPROD - // from/to - // a0 b0 c0 d0 - // a1 b1 c1 d1 + // from/to + // a0 b0 c0 d0 + // a1 b1 c1 d1 #else - // from - // a0 b1 c0 d1 - // a1 b0 c1 d0 + // from + // a0 b1 c0 d1 + // a1 b0 c1 d0 - // to - // a0 b0 c0 d0 - // a1 b1 c1 d1 - { - _sum1 = vrev64q_s32(_sum1); - int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); - _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); - _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); - _sum1 = vrev64q_s32(_sum1); - } + // to + // a0 b0 c0 d0 + // a1 b1 c1 d1 + { + _sum1 = vrev64q_s32(_sum1); + int32x4x2_t _t0 = vzipq_s32(_sum0, _sum1); + _sum0 = vcombine_s32(vget_low_s32(_t0.val[0]), vget_low_s32(_t0.val[1])); + _sum1 = vcombine_s32(vget_high_s32(_t0.val[0]), vget_high_s32(_t0.val[1])); + _sum1 = vrev64q_s32(_sum1); + } #endif // __ARM_FEATURE_DOTPROD - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 
2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3) + { + float32x4_t _c1; + if (c_elempack == 1) { - float32x4_t _c1; - if (c_elempack == 1) - { - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); - float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); - float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); - float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); - float32x4x2_t _cc = vuzpq_f32(_cc01, _cc23); - _c0 = _cc.val[0]; - _c1 = _cc.val[1]; - pC += 2; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - pC += 8; - } - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2_t _cc2 = vld1_f32(pC + c_hstep * 2); + float32x2_t _cc3 = vld1_f32(pC + c_hstep * 3); + float32x4_t _cc01 = vcombine_f32(_cc0, _cc1); + float32x4_t _cc23 = vcombine_f32(_cc2, _cc3); + float32x4x2_t _cc = vuzpq_f32(_cc01, _cc23); + _c0 = _cc.val[0]; + _c1 = _cc.val[1]; + pC += 2; + } + else // if (c_elempack == 4) + { + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + pC += 8; } - if (broadcast_type_C == 4) + if (beta == 1.f) { - float32x2_t _c = vld1_f32(pC); - _c = vmul_n_f32(_c, beta); - _c0 = vdupq_lane_f32(_c, 0); - float32x4_t _c1 = vdupq_lane_f32(_c, 1); _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); - pC += 2; + } + else + { + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + float32x2_t _c = vld1_f32(pC); + _c = vmul_n_f32(_c, beta); + _c0 = vdupq_lane_f32(_c, 0); + float32x4_t _c1 = vdupq_lane_f32(_c, 1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + pC += 2; } - - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep, _f1); - - pp += 8; - p0 += out_hstep * 2; } - for (; jj < max_jj; jj += 1) + + if (alpha != 1.f) { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep, _f1); + + pp += 8; + p0 += out_hstep * 2; + } + for (; jj < max_jj; jj += 1) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3) + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + } + if (broadcast_type_C == 3) + { + if (c_elempack == 1) { - if (c_elempack == 1) - { - _c0 = vsetq_lane_f32(pC[0], _c0, 0); - _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); - _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); - _c0 = vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); - pC += 1; - } - else // if (c_elempack == 4) - { - _c0 = vld1q_f32(pC); - pC += 4; - } - _f0 = vmlaq_n_f32(_f0, _c0, beta); + _c0 = vsetq_lane_f32(pC[0], _c0, 0); + _c0 = vsetq_lane_f32(pC[c_hstep], _c0, 1); + _c0 = vsetq_lane_f32(pC[c_hstep * 2], _c0, 2); + _c0 = 
vsetq_lane_f32(pC[c_hstep * 3], _c0, 3); + pC += 1; } - if (broadcast_type_C == 4) + else // if (c_elempack == 4) { - _c0 = vdupq_n_f32(pC[0] * beta); - _f0 = vaddq_f32(_f0, _c0); - pC += 1; + _c0 = vld1q_f32(pC); + pC += 4; } + _f0 = vmlaq_n_f32(_f0, _c0, beta); + } + if (broadcast_type_C == 4) + { + _c0 = vdupq_n_f32(pC[0] * beta); + _f0 = vaddq_f32(_f0, _c0); + pC += 1; } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - vst1q_f32(p0, _f0); - pp += 4; - p0 += out_hstep; - } + vst1q_f32(p0, _f0); + pp += 4; + p0 += out_hstep; } } #endif // __ARM_NEON @@ -11815,440 +9279,277 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } + int jj = 0; #if __ARM_NEON - if (out_elempack == 4) - { - int jj = 0; #if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = vld1q_s32(pp + 12); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); - float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); - float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 0); + float32x4_t _f2 = vmulq_lane_f32(vcvtq_f32_s32(_sum2), _descale01, 1); + float32x4_t _f3 = vmulq_lane_f32(vcvtq_f32_s32(_sum3), _descale01, 1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c1); + _f3 = vaddq_f32(_f3, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + c_hstep); + float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c1); - _f3 = vaddq_f32(_f3, _c1); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3) + else { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } - if (broadcast_type_C == 4) + pC += 8; + } + if (broadcast_type_C == 4) + { + 
_c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + 4); + if (beta != 1.f) { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c1); - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _c0 = vmulq_f32(_c0, _beta); + _c1 = vmulq_f32(_c1, _beta); } + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c1); + pC += 8; } + } - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f2); vst1q_f32(p0 + out_hstep * 4, _f1); vst1q_f32(p0 + out_hstep * 4 + 4, _f3); - - pp += 16; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) + if (out_elempack == 1) { - // a0 a1 a2 a3 - // b0 b1 b2 b3 + float32x4x2_t _f02 = vzipq_f32(_f0, _f2); + float32x4x2_t _f13 = vzipq_f32(_f1, _f3); + vst1_f32(p0, vget_low_f32(_f02.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f02.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f02.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f02.val[1])); + vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f13.val[0])); + vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f13.val[0])); + vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f13.val[1])); + vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f13.val[1])); + } + + pp += 16; + p0 += out_hstep * 8; + } +#endif // __aarch64__ + for (; jj + 3 < max_jj; jj += 4) + { + // a0 a1 a2 a3 + // b0 b1 b2 b3 - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); - float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); - float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); + float32x4_t _f0 = vmulq_lane_f32(vcvtq_f32_s32(_sum0), _descale01, 0); + float32x4_t _f1 = vmulq_lane_f32(vcvtq_f32_s32(_sum1), _descale01, 1); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _c1 = vld1q_f32(pC + c_hstep); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 4; - } - if (broadcast_type_C == 4) + else { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - pC += 4; + float32x4_t _beta = 
vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 4; } - - if (alpha != 1.f) + if (broadcast_type_C == 4) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + _c0 = vld1q_f32(pC); + _c0 = vmulq_n_f32(_c0, beta); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + pC += 4; } + } + + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } + if (out_elempack == 4) + { vst1q_f32(p0, _f0); vst1q_f32(p0 + 4, _f1); - - pp += 8; - p0 += out_hstep * 4; } - } -#endif // __ARM_NEON - if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON -#if __aarch64__ - for (; jj + 7 < max_jj; jj += 8) + if (out_elempack == 1) { - // a0 a1 a2 a3 - // a4 a5 a6 a7 - // b0 b1 b2 b3 - // b4 b5 b6 b7 - - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); + float32x4x2_t _f01 = vzipq_f32(_f0, _f1); + vst1_f32(p0, vget_low_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep, vget_high_f32(_f01.val[0])); + vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f01.val[1])); + vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f01.val[1])); + } - int32x4x2_t _sum02 = vzipq_s32(_sum0, _sum2); - int32x4x2_t _sum13 = vzipq_s32(_sum1, _sum3); + pp += 8; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + // a0 a1 b0 b1 + int32x2x2_t _sum0 = vld2_s32(pp); - float32x4_t _descale = vcombine_f32(_descale01, _descale01); + float32x4_t _descale = vcombine_f32(_descale01, _descale01); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum02.val[0]), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum02.val[1]), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum13.val[0]), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum13.val[1]), _descale); + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - _f1 = vaddq_f32(_f1, _cc); - _f2 = vaddq_f32(_f2, _cc); - _f3 = vaddq_f32(_f3, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + c_hstep); - float32x4_t _c3 = vld1q_f32(pC + c_hstep + 4); - float32x4x2_t _c02 = vzipq_f32(_c0, _c2); - float32x4x2_t _c13 = vzipq_f32(_c1, _c3); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c02.val[0]); - _f1 = vaddq_f32(_f1, _c02.val[1]); - _f2 = vaddq_f32(_f2, _c13.val[0]); - _f3 = vaddq_f32(_f3, _c13.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c02.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c02.val[1], _beta); - _f2 = vmlaq_f32(_f2, _c13.val[0], _beta); - _f3 = vmlaq_f32(_f3, _c13.val[1], _beta); - } - pC += 8; - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + 4); - if (beta != 1.f) - { - float32x4_t _beta = vdupq_n_f32(beta); - _c0 = vmulq_f32(_c0, _beta); - _c1 = vmulq_f32(_c1, _beta); - } - float32x4x2_t _cc0 = vzipq_f32(_c0, _c0); - float32x4x2_t _cc1 = vzipq_f32(_c1, _c1); - _f0 = vaddq_f32(_f0, _cc0.val[0]); - _f1 = vaddq_f32(_f1, _cc0.val[1]); - _f2 
= vaddq_f32(_f2, _cc1.val[0]); - _f3 = vaddq_f32(_f3, _cc1.val[1]); - pC += 8; - } + _f0 = vaddq_f32(_f0, _c0); } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); + float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; + _f0 = vaddq_f32(_f0, _cc); + } + if (broadcast_type_C == 3) + { + // c_elempack == 1 + float32x2_t _cc0 = vld1_f32(pC); + float32x2_t _cc1 = vld1_f32(pC + c_hstep); + float32x2x2_t _c01 = vzip_f32(_cc0, _cc1); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; + } + if (broadcast_type_C == 4) + { + float32x2_t _cc = vld1_f32(pC); + float32x2x2_t _c01 = vzip_f32(_cc, _cc); + _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 2; } - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - vst1_f32(p0 + out_hstep * 4, vget_low_f32(_f2)); - vst1_f32(p0 + out_hstep * 5, vget_high_f32(_f2)); - vst1_f32(p0 + out_hstep * 6, vget_low_f32(_f3)); - vst1_f32(p0 + out_hstep * 7, vget_high_f32(_f3)); - - pp += 16; - p0 += out_hstep * 8; } -#endif // __aarch64__ - for (; jj + 3 < max_jj; jj += 4) - { - // a0 a1 a2 a3 - // b0 b1 b2 b3 - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); + _f0 = vmulq_n_f32(_f0, alpha); - int32x4x2_t _sum01 = vzipq_s32(_sum0, _sum1); + vst1_f32(p0, vget_low_f32(_f0)); + vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - float32x4_t _descale = vcombine_f32(_descale01, _descale01); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum01.val[0]), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum01.val[1]), _descale); + pp += 4; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale0; + float f1 = pp[1] * descale1; - if (pC) + if (pC) + { + if (broadcast_type_C == 0) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - _f1 = vaddq_f32(_f1, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _c1 = vld1q_f32(pC + c_hstep); - float32x4x2_t _c01 = vzipq_f32(_c0, _c1); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c01.val[0]); - _f1 = vaddq_f32(_f1, _c01.val[1]); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c01.val[0], _beta); - _f1 = vmlaq_f32(_f1, _c01.val[1], _beta); - } - pC += 4; - } - if (broadcast_type_C == 4) - { - _c0 = vld1q_f32(pC); - _c0 = vmulq_n_f32(_c0, beta); - float32x4x2_t _cc = vzipq_f32(_c0, _c0); - _f0 = vaddq_f32(_f0, _cc.val[0]); - _f1 = vaddq_f32(_f1, _cc.val[1]); - pC += 4; - } + f0 += c0; + f1 += c0; } - - if (alpha != 1.f) + if (broadcast_type_C == 1 || broadcast_type_C == 2) { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); + f0 += c0; + f1 += c1; } - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - vst1_f32(p0 + out_hstep * 2, vget_low_f32(_f1)); - vst1_f32(p0 + out_hstep * 3, vget_high_f32(_f1)); - - pp += 8; - p0 += out_hstep * 4; - } - for (; jj + 1 < max_jj; jj += 2) - { - // a0 a1 b0 b1 - int32x2x2_t _sum0 = 
vld2_s32(pp); - - float32x4_t _descale = vcombine_f32(_descale01, _descale01); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vcombine_s32(_sum0.val[0], _sum0.val[1])), _descale); - - if (pC) + if (broadcast_type_C == 3) { - if (broadcast_type_C == 0) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - float32x4_t _cc = vzipq_f32(_c0, _c1).val[0]; - _f0 = vaddq_f32(_f0, _cc); - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - float32x2_t _cc0 = vld1_f32(pC); - float32x2_t _cc1 = vld1_f32(pC + c_hstep); - float32x2x2_t _c01 = vzip_f32(_cc0, _cc1); - _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } - if (broadcast_type_C == 4) - { - float32x2_t _cc = vld1_f32(pC); - float32x2x2_t _c01 = vzip_f32(_cc, _cc); - _c0 = vcombine_f32(_c01.val[0], _c01.val[1]); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 2; - } + // c_elempack == 1 + f0 += pC[0] * beta; + f1 += pC[c_hstep] * beta; + pC += 1; } - - _f0 = vmulq_n_f32(_f0, alpha); - - vst1_f32(p0, vget_low_f32(_f0)); - vst1_f32(p0 + out_hstep, vget_high_f32(_f0)); - - pp += 4; - p0 += out_hstep * 2; - } -#endif // __ARM_NEON - for (; jj < max_jj; jj += 1) - { - float f0 = pp[0] * descale0; - float f1 = pp[1] * descale1; - - if (pC) + if (broadcast_type_C == 4) { - if (broadcast_type_C == 0) - { - f0 += c0; - f1 += c0; - } - if (broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - f1 += c1; - } - if (broadcast_type_C == 3) - { - // c_elempack == 1 - f0 += pC[0] * beta; - f1 += pC[c_hstep] * beta; - pC += 1; - } - if (broadcast_type_C == 4) - { - f0 += pC[0] * beta; - f1 += pC[0] * beta; - pC += 1; - } + f0 += pC[0] * beta; + f1 += pC[0] * beta; + pC += 1; } + } - f0 *= alpha; - f1 *= alpha; + f0 *= alpha; + f1 *= alpha; - p0[0] = f0; - p0[1] = f1; + p0[0] = f0; + p0[1] = f1; - pp += 2; - p0 += out_hstep; - } + pp += 2; + p0 += out_hstep; } } for (; ii < max_ii; ii += 1) @@ -12292,235 +9593,81 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma } } + int jj = 0; #if __ARM_NEON - if (out_elempack == 4) + for (; jj + 15 < max_jj; jj += 16) { - int jj = 0; - for (; jj + 15 < max_jj; jj += 16) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } - } + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + int32x4_t _sum2 = vld1q_s32(pp + 8); + int32x4_t _sum3 = 
vld1q_s32(pp + 12); - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - if (out_hstep == 1) - { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); - } - else - { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 8, _f2); - vst1q_f32(p0 + out_hstep * 12, _f3); - } + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); + float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - pp += 16; - p0 += out_hstep * 16; - } - for (; jj + 7 < max_jj; jj += 8) + if (pC) { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } - - if (out_hstep == 1) - { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - } - else + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + _f2 = vaddq_f32(_f2, _c0); + _f3 = vaddq_f32(_f3, _c0); } - - pp += 8; - p0 += out_hstep * 8; - } - for (; jj + 3 < max_jj; jj += 4) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - - if (pC) + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + float32x4_t _c2 = vld1q_f32(pC + 8); + float32x4_t _c3 = vld1q_f32(pC + 12); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c1); + _f2 = vaddq_f32(_f2, _c2); + _f3 = vaddq_f32(_f3, _c3); } - if (broadcast_type_C == 3 || broadcast_type_C == 4) + else { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); + _f2 = vmlaq_f32(_f2, _c2, _beta); + _f3 = vmlaq_f32(_f3, _c3, _beta); } + pC += 16; } + } - _f0 = vmulq_n_f32(_f0, alpha); + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + _f2 = vmulq_f32(_f2, _alpha); + _f3 = vmulq_f32(_f3, _alpha); + } + if (out_hstep == 1) + { vst1q_f32(p0, _f0); - pp += 4; - p0 += out_hstep * 4; + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); } - } -#endif // __ARM_NEON - if (out_elempack == 1) - { - int jj = 0; -#if __ARM_NEON - for (; jj + 15 < max_jj; jj += 16) + else { - int32x4_t 
_sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - int32x4_t _sum2 = vld1q_s32(pp + 8); - int32x4_t _sum3 = vld1q_s32(pp + 12); - - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - float32x4_t _f2 = vmulq_f32(vcvtq_f32_s32(_sum2), _descale); - float32x4_t _f3 = vmulq_f32(vcvtq_f32_s32(_sum3), _descale); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); - _f2 = vaddq_f32(_f2, _c0); - _f3 = vaddq_f32(_f3, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - float32x4_t _c2 = vld1q_f32(pC + 8); - float32x4_t _c3 = vld1q_f32(pC + 12); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - _f2 = vaddq_f32(_f2, _c2); - _f3 = vaddq_f32(_f3, _c3); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - _f2 = vmlaq_f32(_f2, _c2, _beta); - _f3 = vmlaq_f32(_f3, _c3, _beta); - } - pC += 16; - } - } - - if (alpha != 1.f) - { - float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - _f2 = vmulq_f32(_f2, _alpha); - _f3 = vmulq_f32(_f3, _alpha); - } - - if (out_hstep == 1) + if (out_elempack == 4) { vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); - vst1q_f32(p0 + 8, _f2); - vst1q_f32(p0 + 12, _f3); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 8, _f2); + vst1q_f32(p0 + out_hstep * 12, _f3); } - else + if (out_elempack == 1) { p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); @@ -12539,58 +9686,66 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); } - - pp += 16; - p0 += out_hstep * 16; } - for (; jj + 7 < max_jj; jj += 8) - { - int32x4_t _sum0 = vld1q_s32(pp); - int32x4_t _sum1 = vld1q_s32(pp + 4); - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); - float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); + pp += 16; + p0 += out_hstep * 16; + } + for (; jj + 7 < max_jj; jj += 8) + { + int32x4_t _sum0 = vld1q_s32(pp); + int32x4_t _sum1 = vld1q_s32(pp + 4); + + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(_sum0), _descale); + float32x4_t _f1 = vmulq_f32(vcvtq_f32_s32(_sum1), _descale); - if (pC) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + { + _f0 = vaddq_f32(_f0, _c0); + _f1 = vaddq_f32(_f1, _c0); + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) + // c_elempack == 1 + _c0 = vld1q_f32(pC); + float32x4_t _c1 = vld1q_f32(pC + 4); + if (beta == 1.f) { _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c0); + _f1 = vaddq_f32(_f1, _c1); } - if (broadcast_type_C == 3 || broadcast_type_C == 4) + else { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - float32x4_t _c1 = vld1q_f32(pC + 4); - if (beta == 1.f) - { - _f0 = vaddq_f32(_f0, _c0); - _f1 = vaddq_f32(_f1, _c1); - } - else - { - float32x4_t _beta = vdupq_n_f32(beta); - _f0 = vmlaq_f32(_f0, _c0, _beta); - _f1 = vmlaq_f32(_f1, _c1, _beta); - } - pC += 8; + float32x4_t _beta = vdupq_n_f32(beta); + _f0 = vmlaq_f32(_f0, _c0, _beta); + _f1 = vmlaq_f32(_f1, _c1, _beta); } + pC += 8; } + } - if (alpha != 1.f) - { - 
float32x4_t _alpha = vdupq_n_f32(alpha); - _f0 = vmulq_f32(_f0, _alpha); - _f1 = vmulq_f32(_f1, _alpha); - } + if (alpha != 1.f) + { + float32x4_t _alpha = vdupq_n_f32(alpha); + _f0 = vmulq_f32(_f0, _alpha); + _f1 = vmulq_f32(_f1, _alpha); + } - if (out_hstep == 1) + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + } + else + { + if (out_elempack == 4) { vst1q_f32(p0, _f0); - vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + out_hstep * 4, _f1); } - else + if (out_elempack == 1) { p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); @@ -12601,106 +9756,113 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); } - - pp += 8; - p0 += out_hstep * 8; } - for (; jj + 3 < max_jj; jj += 4) - { - float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); - if (pC) + pp += 8; + p0 += out_hstep * 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float32x4_t _f0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(pp)), _descale); + + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vaddq_f32(_f0, _c0); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // out_elempack == 1 - _c0 = vld1q_f32(pC); - _f0 = vmlaq_n_f32(_f0, _c0, beta); - pC += 4; - } + _f0 = vaddq_f32(_f0, _c0); } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + _c0 = vld1q_f32(pC); + _f0 = vmlaq_n_f32(_f0, _c0, beta); + pC += 4; + } + } - _f0 = vmulq_n_f32(_f0, alpha); + _f0 = vmulq_n_f32(_f0, alpha); - if (out_hstep == 1) + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + } + else + { + if (out_elempack == 4) { vst1q_f32(p0, _f0); } - else + if (out_elempack == 1) { p0[0] = vgetq_lane_f32(_f0, 0); p0[out_hstep] = vgetq_lane_f32(_f0, 1); p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); } - - pp += 4; - p0 += out_hstep * 4; } - for (; jj + 1 < max_jj; jj += 2) - { - float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - - if (pC) - { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - _f0 = vadd_f32(_f0, vget_low_f32(_c0)); - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - float32x2_t _c = vld1_f32(pC); - _f0 = vmla_n_f32(_f0, _c, beta); - pC += 2; - } - } - _f0 = vmul_n_f32(_f0, alpha); + pp += 4; + p0 += out_hstep * 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + float32x2_t _f0 = vmul_f32(vcvt_f32_s32(vld1_s32(pp)), vget_low_f32(_descale)); - if (out_hstep == 1) + if (pC) + { + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - vst1_f32(p0, _f0); + _f0 = vadd_f32(_f0, vget_low_f32(_c0)); } - else + if (broadcast_type_C == 3 || broadcast_type_C == 4) { - p0[0] = vget_lane_f32(_f0, 0); - p0[out_hstep] = vget_lane_f32(_f0, 1); + // c_elempack == 1 + float32x2_t _c = vld1_f32(pC); + _f0 = vmla_n_f32(_f0, _c, beta); + pC += 2; } + } - pp += 2; - p0 += out_hstep * 2; + _f0 = vmul_n_f32(_f0, alpha); + + if (out_hstep == 1) + { + vst1_f32(p0, _f0); } -#endif // __ARM_NEON - for (; jj < max_jj; jj += 1) + else { - float f0 = pp[0] * descale; + p0[0] = vget_lane_f32(_f0, 0); + p0[out_hstep] = vget_lane_f32(_f0, 1); + } + + pp += 2; + p0 += out_hstep * 2; + } +#endif // __ARM_NEON + for (; jj < max_jj; jj += 1) + { + float f0 = pp[0] * descale; - if (pC) + if (pC) + 
{ + if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) { - if (broadcast_type_C == 0 || broadcast_type_C == 1 || broadcast_type_C == 2) - { - f0 += c0; - } - if (broadcast_type_C == 3 || broadcast_type_C == 4) - { - // c_elempack == 1 - f0 += pC[0] * beta; - pC += 1; - } + f0 += c0; + } + if (broadcast_type_C == 3 || broadcast_type_C == 4) + { + // c_elempack == 1 + f0 += pC[0] * beta; + pC += 1; } + } - f0 *= alpha; + f0 *= alpha; - p0[0] = f0; + p0[0] = f0; - pp += 1; - p0 += out_hstep; - } + pp += 1; + p0 += out_hstep; } } } From b8421f7433c283551fdc5f46ef8603745eba8552 Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 12 Oct 2024 15:47:29 +0800 Subject: [PATCH 54/55] tiled gemm int8 test --- tests/test_gemm_3.cpp | 2 +- tests/test_gemm_4.cpp | 140 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 tests/test_gemm_4.cpp diff --git a/tests/test_gemm_3.cpp b/tests/test_gemm_3.cpp index 35b30623c872..d7c6c531a05f 100644 --- a/tests/test_gemm_3.cpp +++ b/tests/test_gemm_3.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at diff --git a/tests/test_gemm_4.cpp b/tests/test_gemm_4.cpp new file mode 100644 index 000000000000..3b25cf9e9f97 --- /dev/null +++ b/tests/test_gemm_4.cpp @@ -0,0 +1,140 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +#if NCNN_INT8 +static int test_gemm_int8(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K, float alpha, int transA, int transB, int output_transpose) +{ + ncnn::ParamDict pd; + pd.set(0, alpha); + pd.set(1, 1.f); // beta + pd.set(2, transA); + pd.set(3, transB); + pd.set(14, output_transpose); + pd.set(18, 2); // int8_scale_term + + pd.set(20, TILE_M); + pd.set(21, TILE_N); + pd.set(22, TILE_K); + + std::vector weights(0); + + std::vector a(2); + a[0] = transA ? ncnn::Mat(M, K) : ncnn::Mat(K, M); + a[1] = transB ? 
ncnn::Mat(K, N) : ncnn::Mat(N, K); + + Randomize(a[0], -10.f, 10.f); + Randomize(a[1], -10.f, 10.f); + + int ret = test_layer("Gemm", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gemm_int8 failed M=%d N=%d K=%d TILE_M=%d TILE_N=%d TILE_K=%d alpha=%f transA=%d transB=%d output_transpose=%d\n", M, N, K, TILE_M, TILE_N, TILE_K, alpha, transA, transB, output_transpose); + } + + return ret; +} + +static int test_gemm_0(int M, int N, int K, int TILE_M, int TILE_N, int TILE_K) +{ + return 0 + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 0) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 2.1f, 0, 0, 1) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 3.1f, 0, 1, 1) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 4.1f, 1, 0, 1) + || test_gemm_int8(M, N, K, TILE_M, TILE_N, TILE_K, 5.1f, 1, 1, 1); +} +#endif // NCNN_INT8 + +int main() +{ + SRAND(7767517); + +#if NCNN_INT8 + int mnk[][3] = { + {1, 1, 1}, + {2, 2, 2}, + {3, 3, 3}, + {4, 4, 4}, + {5, 5, 5}, + {6, 6, 6}, + {7, 7, 7}, + {8, 8, 8}, + {15, 15, 15}, + {16, 16, 16}, + {24, 24, 24}, + {31, 31, 31}, + {31, 32, 31}, + {32, 31, 32}, + {32, 32, 32}, + {20, 32, 20}, + {40, 40, 40}, + {47, 47, 47}, + {48, 48, 48}, + {52, 52, 52}, + {63, 64, 63}, + {64, 63, 64}, + {64, 64, 64} + }; + + int tile_mnk[][3] = { + {1, 1, 1}, + {2, 2, 2}, + {4, 4, 4}, + {8, 8, 8}, + {12, 12, 12}, + {16, 16, 16}, + {20, 20, 20}, + {24, 24, 24}, + {28, 28, 28} + }; + + int mnk_count = sizeof(mnk) / sizeof(int) / 3; + int tile_mnk_count = sizeof(tile_mnk) / sizeof(int) / 3; + + for (int i = 0; i < mnk_count; i++) + { + int M = mnk[i][0]; + int N = mnk[i][1]; + int K = mnk[i][2]; + + for (int j = 0; j < tile_mnk_count; j++) + { + int TILE_M = tile_mnk[j][0]; + int TILE_N = tile_mnk[j][1]; + int TILE_K = tile_mnk[j][2]; + + if (TILE_M >= M && TILE_N >= N && TILE_K >= K) + continue; + + int ret = test_gemm_0(M, N, K, TILE_M, TILE_N, TILE_K); + if (ret != 0) + return ret; + } + + // test no tiling + int ret = test_gemm_0(M, N, K, 100, 100, 100); + if (ret != 0) + return ret; + } +#else + // test nothing for non-int8 build +#endif + + return 0; +} From 4f2431239e55fcf4a5acd3d70d70bcf55fe4d5ed Mon Sep 17 00:00:00 2001 From: nihui Date: Sat, 12 Oct 2024 17:30:32 +0800 Subject: [PATCH 55/55] opt arm64 tiles, fix asimdhp dispatch --- src/layer/arm/gemm_int8.h | 16 +++++ src/layer/arm/gemm_int8_fp16s.h | 108 ++++++++++++++++---------------- 2 files changed, 70 insertions(+), 54 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index 020df8b9c84a..68688c863102 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -14617,7 +14617,11 @@ static void get_optimal_tile_mnk_int8(int M, int N, int K, int constant_TILE_M, int tile_size = (int)sqrtf((float)l2_cache_size / (2 * sizeof(signed char) + sizeof(int))); TILE_M = std::max(8, tile_size / 8 * 8); +#if __aarch64__ + TILE_N = std::max(8, tile_size / 8 * 8); +#else TILE_N = std::max(4, tile_size / 4 * 4); +#endif TILE_K = std::max(8, tile_size / 8 * 8); if (K > 0) @@ -14630,7 +14634,11 @@ static void get_optimal_tile_mnk_int8(int M, int N, int K, int constant_TILE_M, tile_size = (int)((float)l2_cache_size / 2 / sizeof(signed char) / TILE_K); TILE_M = std::max(8, tile_size / 8 * 8); +#if __aarch64__ + TILE_N = std::max(8, tile_size / 8 * 8); 
+#else TILE_N = std::max(4, tile_size / 4 * 4); +#endif } } @@ -14645,7 +14653,11 @@ static void get_optimal_tile_mnk_int8(int M, int N, int K, int constant_TILE_M, if (N > 0) { int nn_N = (N + TILE_N - 1) / TILE_N; +#if __aarch64__ + TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 7) / 8 * 8); +#else TILE_N = std::min(TILE_N, ((N + nn_N - 1) / nn_N + 3) / 4 * 4); +#endif } if (nT > 1) @@ -14661,7 +14673,11 @@ static void get_optimal_tile_mnk_int8(int M, int N, int K, int constant_TILE_M, if (constant_TILE_N > 0) { +#if __aarch64__ + TILE_N = (constant_TILE_N + 7) / 8 * 8; +#else TILE_N = (constant_TILE_N + 3) / 4 * 4; +#endif } if (constant_TILE_K > 0) diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index a7e1f15d5ddf..e096a6caf6f6 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -393,7 +393,7 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -561,7 +561,7 @@ static void pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i p0 += 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -2038,7 +2038,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _scale0 = vld1q_f32((const float*)scales + ii); float32x4_t _scale1 = vld1q_f32((const float*)scales + ii + 4); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -2143,7 +2143,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -2512,7 +2512,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int float32x4_t _scale = vld1q_f32((const float*)scales + ii); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -2572,7 +2572,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -2809,7 +2809,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_NEON float32x4_t _scale0 = vdupq_n_f32(scale0); float32x4_t _scale1 = vdupq_n_f32(scale1); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -2849,7 +2849,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -3072,7 +3072,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int #if __ARM_NEON float32x4_t _scale = vdupq_n_f32(scale); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -3114,7 +3114,7 @@ static void transpose_pack_A_tile_fp16_to_int8(const Mat& A, Mat& AT, int i, int p0 += A_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -3414,7 +3414,7 @@ static void pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i { const unsigned 
short* p0 = (const unsigned short*)B + (j + jj) * B_hstep + k * elempack; -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -3582,7 +3582,7 @@ static void pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i p0 += 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -4440,7 +4440,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int { const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -4545,7 +4545,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -4882,7 +4882,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int { const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -4942,7 +4942,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -5150,7 +5150,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; #if __ARM_NEON -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -5190,7 +5190,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -5409,7 +5409,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int const unsigned short* p0 = (const unsigned short*)B + k * B_hstep + (j + jj) * elempack; #if __ARM_NEON -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (elempack == 8) { int kk = 0; @@ -5451,7 +5451,7 @@ static void transpose_pack_B_tile_fp16_to_int8(const Mat& B, Mat& BT, int j, int p0 += B_hstep * 8; } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (elempack == 4) { int kk = 0; @@ -5821,7 +5821,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _c08 = vld1q_u16(pC); @@ -5896,7 +5896,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } pC += 64; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); @@ -6125,7 +6125,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hfe = (uint16x4_t)vcvt_f16_f32(_fe); uint16x4_t _hff = (uint16x4_t)vcvt_f16_f32(_ff); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf8)); @@ -6138,7 +6138,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 56, vcombine_u16(_hf7, _hff)); p0 += 64; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { 
vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); @@ -6275,7 +6275,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _c04 = vld1q_u16(pC); @@ -6322,7 +6322,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } pC += 32; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); @@ -6471,7 +6471,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hf6 = (uint16x4_t)vcvt_f16_f32(_f6); uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf4)); @@ -6480,7 +6480,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& vst1q_u16(p0 + 24, vcombine_u16(_hf3, _hf7)); p0 += 32; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); @@ -6570,7 +6570,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& { float32x4_t _c2; float32x4_t _c3; -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _cc0 = vld1q_u16(pC); @@ -6581,7 +6581,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_cc1)); pC += 16; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); @@ -6660,14 +6660,14 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf2)); vst1q_u16(p0 + 8, vcombine_u16(_hf1, _hf3)); p0 += 16; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); @@ -6719,7 +6719,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& } if (broadcast_type_C == 3) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _c = vld1q_u16(pC); @@ -6727,7 +6727,7 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); pC += 8; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); @@ -6780,13 +6780,13 @@ static void unpack_output_tile_int32_to_fp16(const Mat& topT, const Mat& C, Mat& uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); p0 += 8; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { vst1_u16(p0, _hf0); @@ -8199,7 +8199,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _c08 = vld1q_u16(pC); @@ -8268,7 +8268,7 @@ static void 
transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } pC += 64; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); @@ -8489,7 +8489,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x8_t _hf6 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f6), (uint16x4_t)vcvt_f16_f32(_fe)); uint16x8_t _hf7 = vcombine_u16((uint16x4_t)vcvt_f16_f32(_f7), (uint16x4_t)vcvt_f16_f32(_ff)); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { transpose8x8_u16(_hf0, _hf1, _hf2, _hf3, _hf4, _hf5, _hf6, _hf7); @@ -8502,7 +8502,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma vst1q_u16(p0 + 48, _hf6); vst1q_u16(p0 + 56, _hf7); } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { uint16x8x4_t _hfa; @@ -8640,7 +8640,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _c04 = vld1q_u16(pC); @@ -8680,7 +8680,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } pC += 32; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); @@ -8909,7 +8909,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma { float32x4_t _c2; float32x4_t _c3; -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _c02 = vld1q_u16(pC); @@ -8920,7 +8920,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _c3 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c13)); pC += 16; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { uint16x8_t _c01 = vld1q_u16(pC); @@ -9021,7 +9021,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } if (broadcast_type_C == 3) { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (c_elempack == 8) { uint16x8_t _c = vld1q_u16(pC); @@ -9029,7 +9029,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _c1 = vcvt_f32_f16((float16x4_t)vget_high_u16(_c)); pC += 8; } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (c_elempack == 4) { _c0 = vcvt_f32_f16((float16x4_t)vld1_u16(pC)); @@ -9325,7 +9325,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf6 = (uint16x4_t)vcvt_f16_f32(_f6); uint16x4_t _hf7 = (uint16x4_t)vcvt_f16_f32(_f7); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { transpose4x4_u16(_hf0, _hf1, _hf2, _hf3); @@ -9335,7 +9335,7 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma vst1q_u16(p0 + 16, vcombine_u16(_hf2, _hf6)); vst1q_u16(p0 + 24, vcombine_u16(_hf3, _hf7)); } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { uint16x4x4_t _hfa; @@ -9815,13 +9815,13 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); } -#endif // 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { vst1q_u16(p0, vcombine_u16(_hf0, _hf2)); @@ -10147,13 +10147,13 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } else { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); vst1q_u16(p0 + out_hstep * 8, vcombine_u16(_hf2, _hf3)); } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { vst1_u16(p0, _hf0); @@ -10237,12 +10237,12 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma } else { -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __aarch64__ if (out_elempack == 8) { vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // __aarch64__ if (out_elempack == 4) { vst1_u16(p0, _hf0);