From 8cfc64a3eba6822bfd4cd5aa744411ea08e32f90 Mon Sep 17 00:00:00 2001
From: Brenton Chu
Date: Thu, 16 Apr 2020 16:25:25 -0700
Subject: [PATCH] No tensor cores for fp32 interleaved attention, remove div
 by 8 restriction (#17994) (#18085)

(cherry picked from commit afae030beb168f09cf08be101714e059157a9507)
---
 src/operator/contrib/transformer.cu | 53 ++++++++++++++++++++---------
 1 file changed, 37 insertions(+), 16 deletions(-)

diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu
index 44c8ebdbb959..bcbc18525c09 100644
--- a/src/operator/contrib/transformer.cu
+++ b/src/operator/contrib/transformer.cu
@@ -43,7 +43,7 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
                               float alpha, const DType* a, int32_t lda, int32_t strideA,
                               const DType *b, int32_t ldb, int32_t strideB, float beta,
                               DType *c, int32_t ldc, int32_t strideC, int32_t batchCount,
-                              cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
+                              cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT) {
 #if CUDA_VERSION >= 9010
   using namespace mxnet::common::cuda;
   CHECK_EQ(s->blas_handle_ownership_, mshadow::Stream<gpu>::OwnHandle)
@@ -142,9 +142,9 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
                            float alpha, const DType *a, int32_t lda, int32_t strideA,
                            const DType *b, int32_t ldb, int32_t strideB, float beta,
                            DType *c, int32_t ldc,
-                           int32_t strideC, int32_t batchCount) {
+                           int32_t strideC, int32_t batchCount, bool using_fp16) {
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
-  if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
+  if (using_fp16) {
     CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
       strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
   } else {
@@ -175,6 +175,7 @@ void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
   const int32_t batch_stride = 3 * head_dim;
   const float beta = req[0] == kAddTo ? 1.f : 0.f;
   const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
 
   if (req[0] == kNullOp)
     return;
@@ -196,7 +197,8 @@ void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
                           output,
                           qkv_seq_len,
                           qkv_seq_len * qkv_seq_len,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   })
 }
 
@@ -220,7 +222,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
   const int32_t lead_dim = attn_batches * 3 * head_dim;
   const int32_t batch_stride = 3 * head_dim;
   const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
-  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+  const float beta = req[0] == kAddTo ? 1.f : 0.f;
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
 
   if (req[0] == kNullOp)
     return;
@@ -247,7 +250,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
                           queries_keys_values_grads,
                           lead_dim,
                           batch_stride,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
     gemm_switch_fp32accum(s,
                           false,
                           true,
@@ -265,7 +269,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
                           queries_keys_values_grads + head_dim,
                           lead_dim,
                           batch_stride,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   })
 }
 
@@ -290,6 +295,7 @@ void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
   const int32_t batch_stride = 3 * head_dim;
   const float alpha = 1.f;
   const float beta = req[0] == kAddTo ? 1.f : 0.f;
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
 
   if (req[0] == kNullOp)
     return;
@@ -311,7 +317,8 @@ void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
                           output,
                           head_dim * attn_batches,
                           head_dim,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   })
 }
 
@@ -337,6 +344,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
   const int32_t lead_dim = attn_batches * 3 * head_dim;
   const int32_t batch_stride = 3 * head_dim;
   const float alpha = 1.f;
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
+
   if (req[0] != kNullOp) {
     if (req[0] == kWriteTo) {
       cudaMemsetAsync(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType),
@@ -360,7 +369,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
                           queries_keys_values_grads + 2 * head_dim,
                           lead_dim,
                           batch_stride,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   }
   if (req[1] != kNullOp) {
     const float beta = req[1] == kAddTo ? 1.f : 0.f;
@@ -381,7 +391,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
                           attention_maps_grads,
                           qkv_seq_len,
                           qkv_seq_len * qkv_seq_len,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   }
   })
 }
@@ -412,6 +423,7 @@ void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
   const int32_t batch_stride_kv = head_dim * 2;
   const float beta = req[0] == kAddTo ? 1.f : 0.f;
   const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
 
   if (req[0] == kNullOp)
     return;
@@ -433,7 +445,8 @@ void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
                           output,
                           kv_seq_len,
                           kv_seq_len * q_seq_len,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   })
 }
 
@@ -463,6 +476,7 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
   const int32_t batch_stride_q = head_dim;
   const int32_t batch_stride_kv = head_dim * 2;
   const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
 
   if (req[0] != kNullOp) {
     const float beta = req[0] == kAddTo ? 1.f : 0.f;
@@ -483,7 +497,8 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
                           queries_grads,
                           lead_dim_q,
                           batch_stride_q,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   }
   if (req[1] != kNullOp) {
     if (req[1] == kWriteTo) {
@@ -508,7 +523,8 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
                           keys_values_grads,
                           lead_dim_kv,
                           batch_stride_kv,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   }
   })
 }
@@ -535,6 +551,7 @@ void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
   const int32_t batch_stride_kv = 2 * head_dim;
   const float alpha = 1.f;
   const float beta = req[0] == kAddTo ? 1.f : 0.f;
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
 
   if (req[0] == kNullOp)
     return;
@@ -556,7 +573,8 @@ void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
                           output,
                           head_dim * attn_batches,
                           head_dim,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   })
 }
 
@@ -583,6 +601,7 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
   const int32_t lead_dim_kv = attn_batches * head_dim * 2;
   const int32_t batch_stride_kv = 2 * head_dim;
   const float alpha = 1.f;
+  const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
 
   if (req[0] != kNullOp) {
     if (req[0] == kWriteTo) {
@@ -607,7 +626,8 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
                           keys_values_grads + head_dim,
                           lead_dim_kv,
                           batch_stride_kv,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   }
   if (req[1] != kNullOp) {
     const float beta = req[1] == kAddTo ? 1.f : 0.f;
@@ -628,7 +648,8 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
                           attention_maps_grads,
                           kv_seq_len,
                           kv_seq_len * q_seq_len,
-                          attn_batches);
+                          attn_batches,
+                          using_fp16);
   }
   })
 }
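
For context, the behavioral core of this patch is the dispatch inside gemm_switch_fp32accum: the tensor-op GEMM algorithm used to be requested whenever lda, ldb, and ldc were all divisible by 8, regardless of data type; it is now requested only when the operands are fp16, which also removes the divisible-by-8 size restriction for fp32. The short standalone sketch below is not part of the patch and does not touch MXNet APIs; old_rule and new_rule are illustrative names that merely contrast the two selection rules, assuming only the cuBLAS enums from cublas_v2.h.

// Standalone sketch (illustrative, not from the patch): old vs. new algorithm
// selection for the interleaved-attention batched GEMMs.
#include <cstdint>
#include <cstdio>
#include <cublas_v2.h>

// Old rule: request tensor-op GEMM whenever all leading dimensions are
// multiples of 8, independent of the element type.
inline bool old_rule(int32_t lda, int32_t ldb, int32_t ldc) {
  return !(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7);
}

// New rule: request tensor-op GEMM only when the operands are fp16;
// fp32 problems fall back to the default algorithm.
inline cublasGemmAlgo_t new_rule(bool using_fp16) {
  return using_fp16 ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
}

int main() {
  // An fp32 problem with nicely aligned leading dimensions: the old rule
  // would have asked for tensor ops, the new rule does not.
  int32_t lda = 64, ldb = 64, ldc = 64;
  bool using_fp16 = false;
  std::printf("old rule picks tensor ops: %d\n", old_rule(lda, ldb, ldc));
  std::printf("new rule picks tensor ops: %d\n",
              new_rule(using_fp16) == CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  return 0;
}

Under this reading, the using_fp16 flag threaded through every gemm_switch_fp32accum call site is what carries the element-type information down to the algorithm choice, so fp32 interleaved attention no longer depends on tensor cores or on the leading dimensions being multiples of 8.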