[CINN] Revise CINN fp16 matmul cuBLAS API to GemmEx (PaddlePaddle#56845)
* change CINN fp16 matmul cuBLAS API to GemmEx

* fix flag error

* remove flag

* fix flags

* fix test

* fix test

* fix fp16 cublasGemmStridedBatchedEx
GGBond8488 authored and SecretXV committed Nov 28, 2023
1 parent 6ecda61 commit 209ae9d
Showing 3 changed files with 74 additions and 3 deletions.
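What the commit does: the fp16 branches of the cuBLAS wrappers below now call cublasGemmEx (and its batched variants) with CUBLAS_COMPUTE_32F on CUDA 11.0+, so half-precision inputs are multiplied with fp32 accumulation; the old cublasHgemm calls remain only as a fallback for older toolkits. A minimal sketch of the adopted pattern follows; HalfGemmFp32Acc is a hypothetical helper name, not code from the commit.

// Sketch only: fp16 GEMM with fp32 accumulation, the pattern this commit
// switches the fp16 path to (CUDA >= 11.0). Error handling trimmed.
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Computes C = A * B for column-major __half matrices on the device.
cublasStatus_t HalfGemmFp32Acc(cublasHandle_t handle, int m, int n, int k,
                               const __half *A, const __half *B, __half *C) {
  const float alpha = 1.0f;  // scalars are float because the compute type is 32F
  const float beta = 0.0f;
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha,
                      A, CUDA_R_16F, m,  // lda = m for untransposed column-major
                      B, CUDA_R_16F, k,  // ldb = k
                      &beta,
                      C, CUDA_R_16F, m,  // ldc = m
                      CUBLAS_COMPUTE_32F,              // accumulate in fp32
                      CUBLAS_GEMM_DEFAULT_TENSOR_OP);  // allow Tensor Cores
}

Note that with CUBLAS_COMPUTE_32F the alpha/beta scalars are passed as float, which is why the wrappers in the diff only construct float16 scalars on the pre-11.0 cublasHgemm fallback path.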
71 changes: 71 additions & 0 deletions paddle/cinn/runtime/cuda/cublas_util.h
@@ -70,6 +70,27 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype,
reinterpret_cast<double *>(C),
ldc);
} else if (dtype == CUDA_R_16F) {
#if CUDA_VERSION >= 11000
return cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
&alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
&beta,
C,
CUDA_R_16F,
ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#else
common::float16 alpha_fp16{alpha};
common::float16 beta_fp16{beta};
return cublasHgemm(handle,
@@ -86,6 +107,7 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype,
reinterpret_cast<const __half *>(&beta_fp16),
reinterpret_cast<__half *>(C),
ldc);
#endif
} else if (dtype == CUDA_R_16BF) {
#if CUDA_VERSION >= 11000
return cublasGemmEx(handle,
@@ -174,6 +196,31 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype,
strideC,
batchCount);
} else if (dtype == CUDA_R_16F) {
#if CUDA_VERSION >= 11000
return cublasGemmStridedBatchedEx(handle,
transa,
transb,
m,
n,
k,
&alpha,
A,
CUDA_R_16F,
lda,
strideA,
B,
CUDA_R_16F,
ldb,
strideB,
&beta,
C,
CUDA_R_16F,
ldc,
strideC,
batchCount,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#else
common::float16 alpha_fp16{alpha};
common::float16 beta_fp16{beta};
return cublasHgemmStridedBatched(
@@ -195,6 +242,7 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype,
ldc,
strideC,
batchCount);
#endif
} else if (dtype == CUDA_R_16BF) {
#if CUDA_VERSION >= 11000
return cublasGemmStridedBatchedEx(handle,
@@ -279,6 +327,28 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype,
ldc,
batchCount);
} else if (dtype == CUDA_R_16F) {
#if CUDA_VERSION >= 11000
return cublasGemmBatchedEx(handle,
transa,
transb,
m,
n,
k,
&alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
&beta,
C,
CUDA_R_16F,
ldc,
batchCount,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#else
__half alpha_fp16{alpha};
__half beta_fp16{beta};
return cublasHgemmBatched(handle,
@@ -296,6 +366,7 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype,
reinterpret_cast<__half **>(C),
ldc,
batchCount);
#endif
} else if (dtype == CUDA_R_16BF) {
#if CUDA_VERSION >= 11000
return cublasGemmBatchedEx(handle,
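The strided-batched and batched fp16 paths get the same treatment. A hedged, self-contained sketch of the strided-batched case, assuming each batch element is a contiguous column-major slab (the helper name is illustrative, not from the commit):

// Sketch only: strided-batched fp16 GEMM with fp32 accumulation, mirroring
// the cublasGemmStridedBatchedEx call added in this file.
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t HalfGemmStridedBatchedFp32Acc(cublasHandle_t handle,
                                             int m, int n, int k, int batch,
                                             const __half *A, const __half *B,
                                             __half *C) {
  const float alpha = 1.0f, beta = 0.0f;
  return cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
      &alpha,
      A, CUDA_R_16F, m, static_cast<long long>(m) * k,  // stride to next A slab
      B, CUDA_R_16F, k, static_cast<long long>(k) * n,  // stride to next B slab
      &beta,
      C, CUDA_R_16F, m, static_cast<long long>(m) * n,  // stride to next C slab
      batch,
      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}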
4 changes: 3 additions & 1 deletion paddle/cinn/runtime/cuda/cuda_util.cc
@@ -162,6 +162,8 @@ void cinn_call_cublas(void *v_args,
int n = trans_o ? (trans_b ? b3 : b4) : (trans_a ? a4 : a3);
int k = trans_a ? a3 : a4;

VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k;

cublasOperation_t trans_op_l = trans_o
? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T)
: (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N);
@@ -245,7 +247,7 @@ void cinn_call_cublas(void *v_args,
int batch = std::max(a2, b2);
VLOG(3) << "call cublasGemmStridedBatched with a1*b1 = 1, stride_l = "
<< stride_l << ", stride_r = " << stride_r
<< ", batch = " << batch;
<< ", batch = " << batch << ", dtype = " << cuda_dtype;
cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched",
cinn::utils::EventType::kInstruction);
CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype,
2 changes: 0 additions & 2 deletions test/cinn/ops/test_matmul_op.py
@@ -136,8 +136,6 @@ def init_attrs(self):
# },
{
"dtype": "float16",
"max_relative_error": 1e-2,
"max_absolute_error": 1e-2,
},
{
"dtype": "float32",
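The loosened fp16 tolerances are dropped here, consistent with the switch to fp32 accumulation above; presumably the fp16 matmul now passes under the suite's default error bounds. As an illustration (not part of the commit) of why accumulation width matters, this host-side sketch re-rounds a running sum to fp16 after every add, the way an all-half GEMM accumulates, and compares it with a plain fp32 sum. Compile with nvcc; only host-callable half<->float conversions are used.

// Illustrative sketch: fp16 accumulation rounds after every add and can
// stall, while an fp32 accumulator of the same fp16 inputs keeps its precision.
#include <cuda_fp16.h>
#include <cstdio>

int main() {
  float acc16 = 0.0f;  // emulates an fp16 accumulator by re-rounding each step
  float acc32 = 0.0f;  // plain fp32 accumulator
  for (int i = 0; i < 4096; ++i) {
    const float x = __half2float(__float2half(0.01f));  // an fp16 input value
    acc16 = __half2float(__float2half(acc16 + x));      // round to fp16 each add
    acc32 += x;                                         // keep the full fp32 sum
  }
  printf("fp16-style acc = %f, fp32 acc = %f\n", acc16, acc32);
  return 0;
}

On this input the fp16-rounded sum stalls near 32 once the addend drops below half an ulp of the running total, while the fp32 sum stays near the true 40.97.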
