diff --git a/src/layer/arm/gemm_arm.cpp b/src/layer/arm/gemm_arm.cpp index 2d4ff8734f8..f6ea6bc87c8 100644 --- a/src/layer/arm/gemm_arm.cpp +++ b/src/layer/arm/gemm_arm.cpp @@ -3785,9 +3785,9 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; - const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); + const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -3840,8 +3840,8 @@ static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int const int i = ppi * TILE_M; // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); const int max_ii = std::min((M - i), TILE_M); @@ -3899,7 +3899,7 @@ static int gemm_arm(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int static int gemm_AT_arm(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -3994,7 +3994,7 @@ static int gemm_AT_arm(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, static int gemm_BT_arm(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -4018,8 +4018,8 @@ static int gemm_BT_arm(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, const int i = ppi * TILE_M; // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); const int max_ii = std::min((M - i), TILE_M); @@ -4329,20 +4329,20 @@ int Gemm_arm::forward(const std::vector& bottom_blobs, std::vector& to { const Mat& B = bottom_blobs[0]; M = constantM; - N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); } else if (constantB) { const Mat& A = bottom_blobs[0]; - M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); N = constantN; } else { const Mat& A = bottom_blobs[0]; const Mat& B = bottom_blobs[1]; - M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); } Mat C; @@ -4502,9 +4502,9 @@ int Gemm_arm::forward(const std::vector& bottom_blobs, std::vector& to #if NCNN_BF16 static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; - const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); + const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -4557,8 +4557,8 @@ static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo const int i = ppi * TILE_M; // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); const int max_ii = std::min((M - i), TILE_M); @@ -4617,7 +4617,7 @@ static int gemm_arm_bf16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo static int gemm_AT_arm_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -4713,7 +4713,7 @@ static int gemm_AT_arm_bf16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top static int gemm_BT_arm_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -4737,8 +4737,8 @@ static int gemm_BT_arm_bf16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top const int i = ppi * TILE_M; // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); const int max_ii = std::min((M - i), TILE_M); @@ -5001,20 +5001,20 @@ int Gemm_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vectorconvert_packing(B0, B, 1, cmd, opt); vkdev->convert_packing(C0, C, 1, cmd, opt); - const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); - const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; - const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; + const int M = constantM ? constantM : transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); + const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); int broadcast_type_C; if (constantC) @@ -303,9 +303,9 @@ int Gemm_vulkan::forward(const std::vector& bottom_blobs, std::vecto vkdev->convert_packing(B0, B, 1, cmd, opt); vkdev->convert_packing(C0, C, 1, cmd, opt); - const int M = constantM ? constantM : transA ? A.w : (A.dims == 3 ? A.c : A.h); - const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) : A.w; - const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) : B.w; + const int M = constantM ? constantM : transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = constantK ? constantK : transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); + const int N = constantN ? constantN : transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); int broadcast_type_C; if (constantC) diff --git a/src/layer/x86/gemm_x86.cpp b/src/layer/x86/gemm_x86.cpp index 19cd7ebc09a..da1a14e68d2 100644 --- a/src/layer/x86/gemm_x86.cpp +++ b/src/layer/x86/gemm_x86.cpp @@ -6843,9 +6843,9 @@ static void get_optimal_tile_mnk(int M, int N, int K, int constant_TILE_M, int c static int gemm_x86(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; - const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1) : (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); + const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -6897,9 +6897,9 @@ static int gemm_x86(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int { const int i = ppi * TILE_M; - // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + // shadowed variable for less openmp task args + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1) : (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); const int max_ii = std::min((M - i), TILE_M); @@ -6957,7 +6957,7 @@ static int gemm_x86(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int static int gemm_AT_x86(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int M, int K, int transB, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + const int N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -7052,7 +7052,7 @@ static int gemm_AT_x86(const Mat& AT, const Mat& B, const Mat& C, Mat& top_blob, static int gemm_BT_x86(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, int broadcast_type_C, int N, int K, int transA, int output_transpose, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt) { - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1) : (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); // NCNN_LOGE("M/N/K = %d %d %d", M, N, K); @@ -7076,8 +7076,8 @@ static int gemm_BT_x86(const Mat& A, const Mat& BT, const Mat& C, Mat& top_blob, const int i = ppi * TILE_M; // shadowed variable for less openmp task args - const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w; + const int M = transA ? A.w * (A.dims == 1 ? A.elempack : 1) : (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + const int K = transA ? (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack) : A.w * (A.dims == 1 ? A.elempack : 1); const int max_ii = std::min((M - i), TILE_M); @@ -7348,20 +7348,20 @@ int Gemm_x86::forward(const std::vector& bottom_blobs, std::vector& to { const Mat& B = bottom_blobs[0]; M = constantM; - N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); } else if (constantB) { const Mat& A = bottom_blobs[0]; - M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; + M = transA ? A.w * (A.dims == 1 ? A.elempack : 1): (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); N = constantN; } else { const Mat& A = bottom_blobs[0]; const Mat& B = bottom_blobs[1]; - M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack; - N = transB ? (B.dims == 3 ? B.c : B.h) * B.elempack : B.w; + M = transA ? A.w * (A.dims == 1 ? A.elempack : 1) : (A.dims == 3 ? A.c : A.h) * (A.dims == 1 ? 1 : A.elempack); + N = transB ? (B.dims == 3 ? B.c : B.h) * (B.dims == 1 ? 1 : B.elempack) : B.w * (B.dims == 1 ? B.elempack : 1); } Mat C;