-
Notifications
You must be signed in to change notification settings - Fork 2
/
cpu_gemm.h
28 lines (26 loc) · 1.12 KB
/
cpu_gemm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#pragma once
#include <cmath>
/// @brief Naive triple-loop batched GEMM reference implementation on the CPU.
///
/// For every batch and every (m, n) with 0 <= m < M, 0 <= n < N computes
///
///     C[ci] = alpha * sum_{k=0..K-1}( a(m,k) * b(k,n) ) + beta * C[ci]
///
/// where the element addressing is selected by the three layout flags:
///   - transA == true  : a(m,k) = A[m * K + k]     transA == false : A[k * M + m]
///   - transB == true  : b(k,n) = B[k * N + n]     transB == false : B[n * K + k]
///   - transC == true  : ci     = m * N + n        transC == false : n * M + m
///
/// @tparam computeType  Type used for the products and the running sum.
/// @tparam scaleType    Type of the alpha / beta scalars.
/// @tparam inputType    Element type of A and B.
/// @tparam resultType   Element type of C.
///
/// @param inputA,inputB Base pointers of the input batches (only read).
/// @param resultC       Base pointer of the output batches (read when
///                      beta != 0, always written).
/// @param M,N,K         Problem dimensions as used in the formulas above.
/// @param strideA,strideB,strideC
///                      Element offset between consecutive batches of the
///                      respective buffer (0 re-uses the same matrix).
/// @param batchCount    Number of batches to process.
/// @param alpha,beta    Scaling factors. When beta compares equal to zero the
///                      previous contents of C are NOT read, so C may be
///                      uninitialised or hold NaN/Inf without affecting the
///                      result.
/// @param transA,transB,transC
///                      Layout selectors, see the index formulas above.
template <typename computeType, typename scaleType, typename inputType,
          typename resultType>
void cpuGEMM(inputType *inputA, inputType *inputB, resultType *resultC, int M,
             int N, int K, int strideA, int strideB, int strideC,
             int batchCount, scaleType alpha, scaleType beta, bool transA = true,
             bool transB = false, bool transC = true) {
  // All offsets are computed in a 64-bit type: products such as m * K or
  // batch * stride overflow `int` for large problems.
  using index_t = long long;

  // With beta == 0 the old C must not contribute at all; multiplying it by
  // zero would still propagate NaN/Inf from an uninitialised output buffer.
  const bool ignoreOldC = (beta == static_cast<scaleType>(0));

  const index_t rows = M;
  const index_t cols = N;
  const index_t depth = K;

  for (int batch = 0; batch < batchCount; batch++) {
    const inputType *A = inputA + static_cast<index_t>(batch) * strideA;
    const inputType *B = inputB + static_cast<index_t>(batch) * strideB;
    resultType *C = resultC + static_cast<index_t>(batch) * strideC;

    for (index_t m = 0; m < rows; m++) {
      for (index_t n = 0; n < cols; n++) {
        computeType sum = 0;
        for (index_t k = 0; k < depth; k++) {
          const inputType a = transA ? A[m * depth + k] : A[k * rows + m];
          const inputType b = transB ? B[k * cols + n] : B[n * depth + k];
          sum += static_cast<computeType>(a) * static_cast<computeType>(b);
        }

        const index_t ci = transC ? m * cols + n : n * rows + m;
        if (ignoreOldC) {
          C[ci] = static_cast<resultType>(alpha * static_cast<computeType>(sum));
        } else {
          C[ci] = static_cast<resultType>(
              alpha * static_cast<computeType>(sum) +
              beta * static_cast<computeType>(C[ci]));
        }
      }
    }
  }
}