Skip to content

Commit

Permalink
less openmp args
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Sep 27, 2024
1 parent 8068e6c commit a881bd9
Showing 1 changed file with 59 additions and 2 deletions.
61 changes: 59 additions & 2 deletions src/layer/arm/gemm_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5301,6 +5301,18 @@ int Gemm_arm::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
#endif // NCNN_BF16

#if NCNN_INT8
// Bundle of scalar settings consumed by the int8 gemm "omp parallel for" loops.
// Each loop body copies these members back into same-named const locals, so the
// parallel region refers to this single object instead of every scalar separately.
//
// Member order is significant: the struct is filled positionally as
// { TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta }.
struct gemm_arm_int8_omp_args
{
    // Tile sizes. TILE_M is the step between parallel iterations along M
    // (i = ppi * TILE_M); TILE_N / TILE_K are presumably the matching N and K
    // tile sizes -- their use is not visible near this definition.
    int TILE_M, TILE_N, TILE_K;

    int broadcast_type_C;

    // Selects whether M and K are read from A.w or from A's h/c * elempack.
    // The gemm_AT_* variants pass 0 here.
    int transA;

    int output_transpose;

    // Floating-point coefficients forwarded unchanged from the gemm entry points.
    float alpha, beta;
};

static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, float beta, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
{
// NCNN_LOGE("gemm_arm_int8");
Expand Down Expand Up @@ -5457,15 +5469,26 @@ static int gemm_arm_int8(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob
if (topT.empty())
return -100;

const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta };

#pragma omp parallel for num_threads(nT)
for (int ppi = 0; ppi < nn_M; ppi++)
{
const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
const int TILE_M = args.TILE_M;
const int TILE_N = args.TILE_N;
const int TILE_K = args.TILE_K;
const int broadcast_type_C = args.broadcast_type_C;
const int transA = args.transA;
const int output_transpose = args.output_transpose;
const float alpha = args.alpha;
const float beta = args.beta;

const int M = transA ? A.w : (A.dims == 3 ? A.c : A.h) * A.elempack;
const int K = transA ? (A.dims == 3 ? A.c : A.h) * A.elempack : A.w;

const int i = ppi * TILE_M;

const int max_ii = std::min((M - i), TILE_M);

Mat topT_tile = topT.channel(get_omp_thread_num());
Expand Down Expand Up @@ -5697,9 +5720,20 @@ static int gemm_AT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Mat&
if (topT.empty())
return -100;

const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, alpha, beta };

#pragma omp parallel for num_threads(nT)
for (int ppi = 0; ppi < nn_M; ppi++)
{
// shadowed variable for less openmp task args
const int TILE_M = args.TILE_M;
const int TILE_N = args.TILE_N;
const int TILE_K = args.TILE_K;
const int broadcast_type_C = args.broadcast_type_C;
const int output_transpose = args.output_transpose;
const float alpha = args.alpha;
const float beta = args.beta;

const int i = ppi * TILE_M;

const int max_ii = std::min((M - i), TILE_M);
Expand Down Expand Up @@ -5778,9 +5812,21 @@ static int gemm_BT_arm_int8(const Mat& A, const Mat& BT, float B_int8_scale, con
if (topT.empty())
return -100;

const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, transA, output_transpose, alpha, beta };

#pragma omp parallel for num_threads(nT)
for (int ppi = 0; ppi < nn_M; ppi++)
{
// shadowed variable for less openmp task args
const int TILE_M = args.TILE_M;
const int TILE_N = args.TILE_N;
const int TILE_K = args.TILE_K;
const int broadcast_type_C = args.broadcast_type_C;
const int transA = args.transA;
const int output_transpose = args.output_transpose;
const float alpha = args.alpha;
const float beta = args.beta;

const int i = ppi * TILE_M;

// shadowed variable for less openmp task args
Expand Down Expand Up @@ -5892,9 +5938,20 @@ static int gemm_AT_BT_arm_int8(const Mat& AT, const Mat& A_int8_scales, const Ma
if (topT.empty())
return -100;

const struct gemm_arm_int8_omp_args args = { TILE_M, TILE_N, TILE_K, broadcast_type_C, 0, output_transpose, alpha, beta };

#pragma omp parallel for num_threads(nT)
for (int ppi = 0; ppi < nn_M; ppi++)
{
// shadowed variable for less openmp task args
const int TILE_M = args.TILE_M;
const int TILE_N = args.TILE_N;
const int TILE_K = args.TILE_K;
const int broadcast_type_C = args.broadcast_type_C;
const int output_transpose = args.output_transpose;
const float alpha = args.alpha;
const float beta = args.beta;

const int i = ppi * TILE_M;

const int max_ii = std::min((M - i), TILE_M);
Expand Down

0 comments on commit a881bd9

Please sign in to comment.