Fix transformer.cu interleaved matmul for cuda arch < 5 (#17596)
cublasGemmStridedBatchedEx is only supported on GPUs with compute capability 5.0 or greater.

Fixes a bug in #16408
leezu authored Feb 15, 2020
1 parent 9ee4f04 commit d352673
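
For context, the guard pattern this commit introduces can be shown as a small self-contained CUDA sketch. Everything below is illustrative, not the MXNet code: the helper name BatchedGemmF32 is hypothetical, the transpose flags and CUBLAS_GEMM_DEFAULT algorithm are fixed for brevity, and the compute capability is queried with cudaGetDeviceProperties, whereas the committed code reads the cached (s->prop).major from the mshadow stream to avoid a property query per GEMM.

// Illustrative sketch (hypothetical helper, not MXNet code): pick the cuBLAS
// strided-batched GEMM entry point based on the device's compute capability.
#include <cublas_v2.h>
#include <cuda_runtime.h>

cublasStatus_t BatchedGemmF32(cublasHandle_t handle, int m, int n, int k,
                              const float* a, int lda, long long strideA,
                              const float* b, int ldb, long long strideB,
                              float* c, int ldc, long long strideC,
                              int batchCount) {
  float alpha = 1.0f;
  float beta = 0.0f;
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  if (prop.major >= 5) {
    // cublasGemmStridedBatchedEx requires compute capability >= 5.0; it takes
    // explicit data/compute types and an algorithm selector.
    return cublasGemmStridedBatchedEx(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
        a, CUDA_R_32F, lda, strideA,
        b, CUDA_R_32F, ldb, strideB, &beta,
        c, CUDA_R_32F, ldc, strideC,
        batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
  }
  // Pre-5.0 architectures: fall back to the plain FP32 entry point.
  return cublasSgemmStridedBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
      a, lda, strideA, b, ldb, strideB, &beta,
      c, ldc, strideC, batchCount);
}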
Showing 1 changed file with 59 additions and 12 deletions.

src/operator/contrib/transformer.cu
@@ -50,18 +50,65 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
<< "Must init CuBLAS handle in stream";

cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
auto err = CUBLAS_STATUS_SUCCESS;
// TODO(cfujitsang): handle computation_precision
err = cublasGemmStridedBatchedEx(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<void*>(&alpha),
a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
reinterpret_cast<void*>(&beta),
c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
static_cast<int>(batchCount), CUDA_R_32F, algo);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx fail.";
// cublasGemmStridedBatchedEx is only supported for GPU with architecture
// capabilities equal or greater than 5.0. Fall back to
// cublasSgemmStridedBatched, which doesn't support implicit conversion
// to half-precision to use TensorCores
auto cc_major = (s->prop).major;
if (cc_major >= 5) {
CUBLAS_CALL(cublasGemmStridedBatchedEx(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<void*>(&alpha),
a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
reinterpret_cast<void*>(&beta),
c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
static_cast<int>(batchCount), CUDA_R_32F, algo));
} else {
if (std::is_same<DType, float>::value) {
CUBLAS_CALL(cublasSgemmStridedBatched(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<float*>(&alpha),
reinterpret_cast<const float*>(a),
static_cast<int>(lda), strideA,
reinterpret_cast<const float*>(b),
static_cast<int>(ldb), strideB,
reinterpret_cast<float*>(&beta),
reinterpret_cast<float*>(c),
static_cast<int>(ldc), strideC,
static_cast<int>(batchCount)));
} else if (std::is_same<DType, double>::value) {
CUBLAS_CALL(cublasDgemmStridedBatched(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<double*>(&alpha),
reinterpret_cast<const double*>(a),
static_cast<int>(lda), strideA,
reinterpret_cast<const double*>(b),
static_cast<int>(ldb), strideB,
reinterpret_cast<double*>(&beta),
reinterpret_cast<double*>(c),
static_cast<int>(ldc), strideC,
static_cast<int>(batchCount)));
} else if (std::is_same<DType, mshadow::half::half_t>::value) {
CUBLAS_CALL(cublasHgemmStridedBatched(
blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
reinterpret_cast<__half*>(&alpha),
reinterpret_cast<const __half*>(a),
static_cast<int>(lda), strideA,
reinterpret_cast<const __half*>(b),
static_cast<int>(ldb), strideB,
reinterpret_cast<__half*>(&beta),
reinterpret_cast<__half*>(c),
static_cast<int>(ldc), strideC,
static_cast<int>(batchCount)));
} else {
LOG(FATAL) << "Unsupported DType in CublasStridedBatchedGemm.";
}
}
#else
LOG(FATAL) << "Not implemented with CUDA < 9.1";
#endif
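
A note on the fallback branch above: because the dispatch on DType is a plain runtime if, all three branches are compiled for every template instantiation, which is why the pointers are reinterpret_cast to the type each cuBLAS entry point expects. Under C++17 the same dispatch could be written with if constexpr and no casts. A hypothetical sketch, assuming __half as the half-precision type (MXNet uses mshadow::half::half_t) and DType-typed scaling factors:

// Hypothetical C++17 rewrite of the fallback dispatch; discarded if-constexpr
// branches are never instantiated, so no reinterpret_cast is needed.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <type_traits>

template <typename DType>
cublasStatus_t GemmStridedBatchedFallback(
    cublasHandle_t handle, int m, int n, int k,
    DType alpha, const DType* a, int lda, long long strideA,
    const DType* b, int ldb, long long strideB,
    DType beta, DType* c, int ldc, long long strideC, int batchCount) {
  if constexpr (std::is_same_v<DType, float>) {
    return cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                                     &alpha, a, lda, strideA, b, ldb, strideB,
                                     &beta, c, ldc, strideC, batchCount);
  } else if constexpr (std::is_same_v<DType, double>) {
    return cublasDgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                                     &alpha, a, lda, strideA, b, ldb, strideB,
                                     &beta, c, ldc, strideC, batchCount);
  } else {
    static_assert(std::is_same_v<DType, __half>,
                  "Unsupported DType in GemmStridedBatchedFallback");
    return cublasHgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                                     &alpha, a, lda, strideA, b, ldb, strideB,
                                     &beta, c, ldc, strideC, batchCount);
  }
}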
