From d352673bf486d198aa15359fbd168ad2e3ec528c Mon Sep 17 00:00:00 2001
From: Leonard Lausen
Date: Sat, 15 Feb 2020 06:00:48 +0000
Subject: [PATCH] Fix transformer.cu interleaved matmul for cuda arch < 5
 (#17596)

cublasGemmStridedBatchedEx is only supported for GPUs with architecture
capabilities equal to or greater than 5.0. Fixes a bug in #16408.
---
 src/operator/contrib/transformer.cu | 71 ++++++++++++++++++++++++-----
 1 file changed, 59 insertions(+), 12 deletions(-)

diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu
index e152669478dd..59029eae65c2 100644
--- a/src/operator/contrib/transformer.cu
+++ b/src/operator/contrib/transformer.cu
@@ -50,18 +50,65 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
       << "Must init CuBLAS handle in stream";
 
   cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s);
-  auto err = CUBLAS_STATUS_SUCCESS;
-  // TODO(cfujitsang): handle computation_precision
-  err = cublasGemmStridedBatchedEx(
-      blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
-      static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
-      reinterpret_cast<void*>(&alpha),
-      a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
-      b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
-      reinterpret_cast<void*>(&beta),
-      c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
-      static_cast<int>(batchCount), CUDA_R_32F, algo);
-  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx fail.";
+  // cublasGemmStridedBatchedEx is only supported for GPU with architecture
+  // capabilities equal or greater than 5.0. Fall back to
+  // cublasSgemmStridedBatched, which doesn't support implicit conversion
+  // to half-precision to use TensorCores
+  auto cc_major = (s->prop).major;
+  if (cc_major >= 5) {
+    CUBLAS_CALL(cublasGemmStridedBatchedEx(
+        blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
+        static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+        reinterpret_cast<void*>(&alpha),
+        a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA,
+        b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB,
+        reinterpret_cast<void*>(&beta),
+        c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC,
+        static_cast<int>(batchCount), CUDA_R_32F, algo));
+  } else {
+    if (std::is_same<DType, float>::value) {
+      CUBLAS_CALL(cublasSgemmStridedBatched(
+          blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
+          static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+          reinterpret_cast<float*>(&alpha),
+          reinterpret_cast<const float*>(a),
+          static_cast<int>(lda), strideA,
+          reinterpret_cast<const float*>(b),
+          static_cast<int>(ldb), strideB,
+          reinterpret_cast<float*>(&beta),
+          reinterpret_cast<float*>(c),
+          static_cast<int>(ldc), strideC,
+          static_cast<int>(batchCount)));
+    } else if (std::is_same<DType, double>::value) {
+      CUBLAS_CALL(cublasDgemmStridedBatched(
+          blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
+          static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+          reinterpret_cast<double*>(&alpha),
+          reinterpret_cast<const double*>(a),
+          static_cast<int>(lda), strideA,
+          reinterpret_cast<const double*>(b),
+          static_cast<int>(ldb), strideB,
+          reinterpret_cast<double*>(&beta),
+          reinterpret_cast<double*>(c),
+          static_cast<int>(ldc), strideC,
+          static_cast<int>(batchCount)));
+    } else if (std::is_same<DType, mshadow::half::half_t>::value) {
+      CUBLAS_CALL(cublasHgemmStridedBatched(
+          blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB),
+          static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+          reinterpret_cast<__half*>(&alpha),
+          reinterpret_cast<const __half*>(a),
+          static_cast<int>(lda), strideA,
+          reinterpret_cast<const __half*>(b),
+          static_cast<int>(ldb), strideB,
+          reinterpret_cast<__half*>(&beta),
+          reinterpret_cast<__half*>(c),
+          static_cast<int>(ldc), strideC,
+          static_cast<int>(batchCount)));
+    } else {
+      LOG(FATAL) << "Unsupported DType in CublasStridedBatchedGemm.";
+    }
+  }
 #else
   LOG(FATAL) << "Not implemented with CUDA < 9.1";
 #endif