From b5a30656d8190b5b3e81b933b3e871e1a3dad34e Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Wed, 2 Nov 2022 15:18:29 +0000 Subject: [PATCH 1/3] SwiGLU further optimization in MLP bw [ghstack-poisoned] --- tests/test_swiglu.py | 4 +- .../swiglu/cuda/gemm_fused_operand_sum.cu | 79 +++++++++++++------ xformers/components/swiglu/swiglu_packedw.cpp | 24 ++++-- xformers/ops/swiglu.py | 25 ++++-- 4 files changed, 89 insertions(+), 43 deletions(-) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index fee36ba87f..58c6d3a283 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -102,8 +102,8 @@ def generate_test_shapes(): _ops: Sequence[xsw.SwiGLUOp] = [xsw.SwiGLUFusedOp, xsw.SwiGLUPackedFusedOp] -@pytest.mark.parametrize("autocast", [False, True]) -@pytest.mark.parametrize("pack_weights", [False, True]) +@pytest.mark.parametrize("autocast", [False, True], ids=["regular", "autocast"]) +@pytest.mark.parametrize("pack_weights", [False, True], ids=["regular", "packed"]) @pytest.mark.parametrize("op", _ops, ids=[x.NAME for x in _ops]) @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes]) @pytest.mark.parametrize("device", _devices) diff --git a/xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu b/xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu index 8ac77a4187..9307df42e9 100644 --- a/xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu +++ b/xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu @@ -2,17 +2,17 @@ #include #include #include +#include #include #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/device/gemm_with_k_reduction.h" #include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" #include "cutlass/reduction/device/reduce_split_k.h" #include "cutlass/reduction/kernel/reduce_split_k.h" #include "cutlass/reduction/thread/reduction_operators.h" #include "cutlass/matrix_coord.h" -#define WORKAROUND_CUTLASS_BUG namespace { template @@ -25,12 +25,8 @@ void gemm_fused_operand_sum_( at::cuda::CUDAGuard device_guard(a.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int64_t M = a.size(0); - int64_t N = b.size(0); - int64_t K = a.size(1); - // templati-ze the cutlass kernel - cutlass::gemm::GemmCoord problem_size(M, N, K); + cutlass::gemm::GemmCoord problem_size(a.size(0), b.size(1), a.size(1)); using ElementAccumulator = float; // Data type of accumulator using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation using ElementInputA = scalar_t; @@ -69,11 +65,13 @@ void gemm_fused_operand_sum_( constexpr int NumStages = 4; // Reduce A or B operand along the K dimension -#ifdef WORKAROUND_CUTLASS_BUG - constexpr bool ReduceKForA = false; -#else constexpr bool ReduceKForA = true; -#endif + + // Alignment of A operand + constexpr int AlignmentA = 8; + + // Alignment of B operand + constexpr int AlignmentB = 8; // This code section describes the epilogue part of the kernel, we use default value using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -82,12 +80,11 @@ void gemm_fused_operand_sum_( // memory access. This becomes the vector width of // math instructions in the epilogue too. 
ElementAccumulator, // Data type of accumulator - ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType::NoBetaScaling>; + ElementComputeEpilogue>; - using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction< - ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8, - ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8, + using Gemm = typename cutlass::gemm::device::GemmWithKReduction< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, ElementOutput, LayoutOutput, ElementAccumulator, MMAOp, @@ -99,13 +96,15 @@ void gemm_fused_operand_sum_( EpilogueOp, SwizzleThreadBlock, NumStages, - cutlass::arch::OpMultiplyAdd - >::GemmKernel; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + AlignmentA, + AlignmentB, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; // Below is the reduction kernel used in the case of parallel split-k - using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; + using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>; using ReduceOp = cutlass::reduction::thread::ReduceAdd< ElementAccumulator, @@ -131,7 +130,7 @@ void gemm_fused_operand_sum_( // math instructions in the epilogue too. ElementAccumulator, // Data type of accumulator ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType::NoBetaScaling>; + cutlass::epilogue::thread::ScaleType::Nothing>; using ReduceVectorSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< ReduceVectorSplitKShape, @@ -143,11 +142,8 @@ void gemm_fused_operand_sum_( auto alpha = ElementComputeEpilogue(1); auto beta = ElementComputeEpilogue(0); - using RefA = cutlass::TensorRef; - using RefB = cutlass::TensorRef; - using RefC = cutlass::TensorRef; int reduce_vector_length = ReduceKForA ? 
problem_size.m() : problem_size.n(); - int split_k_slices = 2; + int split_k_slices = 1; typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, @@ -192,17 +188,42 @@ std::tuple gemm_fused_operand_sum( TORCH_CHECK(a.dim() == 2); TORCH_CHECK(b.dim() == 2); TORCH_CHECK(out_mm.dim() == 2); + TORCH_CHECK(out_mm.size(0) == a.size(0)); + TORCH_CHECK(out_mm.size(1) == b.size(1)); + TORCH_CHECK(out_sum.dim() == 1); #define FWD_PARAMS a,b,out_mm,out_sum if (a.scalar_type() == at::ScalarType::Half) { + TORCH_CHECK(b.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(out_mm.scalar_type() == at::ScalarType::Half); + TORCH_CHECK(out_sum.scalar_type() == at::ScalarType::Half); gemm_fused_operand_sum_(FWD_PARAMS); } else { TORCH_CHECK(a.scalar_type() == at::ScalarType::BFloat16, "Only supports bf16/f16"); + TORCH_CHECK(b.scalar_type() == at::ScalarType::BFloat16); + TORCH_CHECK(out_mm.scalar_type() == at::ScalarType::BFloat16); + TORCH_CHECK(out_sum.scalar_type() == at::ScalarType::BFloat16); gemm_fused_operand_sum_(FWD_PARAMS); } return std::make_tuple(out_mm, out_sum); } + +std::tuple gemm_fused_operand_sum_autocast( + const at::Tensor& a, + const at::Tensor& b, + at::Tensor& out_mm, + at::Tensor& out_sum +) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + auto exec_type = at::autocast::get_autocast_gpu_dtype(); + return gemm_fused_operand_sum( + at::autocast::cached_cast(exec_type, a), + at::autocast::cached_cast(exec_type, b), + out_mm, + out_sum + ); +} } // namespace TORCH_LIBRARY_IMPL(xformers, CUDA, m) { @@ -210,3 +231,9 @@ TORCH_LIBRARY_IMPL(xformers, CUDA, m) { TORCH_SELECTIVE_NAME("xformers::gemm_fused_operand_sum"), TORCH_FN(gemm_fused_operand_sum)); } + +TORCH_LIBRARY_IMPL(xformers, Autocast, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::gemm_fused_operand_sum"), + TORCH_FN(gemm_fused_operand_sum_autocast)); +} diff --git a/xformers/components/swiglu/swiglu_packedw.cpp b/xformers/components/swiglu/swiglu_packedw.cpp index d03f0573ed..a59605f1de 100644 --- a/xformers/components/swiglu/swiglu_packedw.cpp +++ b/xformers/components/swiglu/swiglu_packedw.cpp @@ -28,6 +28,17 @@ std::tuple silu_bw_fused( .typed(); return op.call(x1, x2, dx4); } +std::tuple gemm_fused_operand_sum( + const at::Tensor& a, + const at::Tensor& b, + at::Tensor& out_mm, + at::Tensor& out_sum +) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("xformers::gemm_fused_operand_sum", "") + .typed(); + return op.call(a, b, out_mm, out_sum); +} bool shapesMatch(at::Tensor x, std::vector expectedShape) { if (x.dim() != int64_t(expectedShape.size())) { @@ -107,11 +118,10 @@ class SwiGLUPackedWeights : public torch::autograd::Function Tuple[torch.Tensor, torch.Tensor]: + db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device) + dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device) + torch.ops.xformers.gemm_fused_operand_sum(dy.transpose(-2, -1), x, dw, db) + return dw, db + @classmethod @torch.cuda.amp.custom_bwd def backward(cls, ctx, dx5): @@ -169,8 +178,7 @@ def backward(cls, ctx, dx5): dx1, dx2 = dx1dx2.unbind(1) del x1, x2, dx4 - db3 = dx5.sum(0) # 25us - dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt) + dw3, db3 = cls._linear_bw(dx5, x4) del x4, dx5 if w1w2 is not None: assert dx1dx2.is_contiguous() @@ -180,7 +188,10 @@ def backward(cls, ctx, dx5): # backward of linear1 + linear2 - packed dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x - db1db2 = dx1dx2.sum(0).view([2, 
dx1.shape[1]]) + dw1dw2, db1db2 = cls._linear_bw( + dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x + ) + db1db2 = db1db2.view([2, dx1.shape[1]]) dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0) db1, db2 = torch.unbind(db1db2, dim=0) else: @@ -188,10 +199,8 @@ def backward(cls, ctx, dx5): torch.addmm( dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx ) # dx += dx1 @ w1 - dw2 = dx2.transpose(-2, -1) @ x # 245us (nt) - db2 = dx2.sum(0) # 50us - dw1 = dx1.transpose(-2, -1) @ x # 245us (nt) - db1 = dx1.sum(0) # 50us + dw2, db2 = cls._linear_bw(dx2, x) + dw1, db1 = cls._linear_bw(dx1, x) return (dx, dw1, db1, dw2, db2, dw3, db3) From 28e896a054e0021114cde814045624d372aa3365 Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Thu, 3 Nov 2022 08:21:43 +0000 Subject: [PATCH 2/3] Update on "SwiGLU further optimization in MLP bw" ***PERFORMANCE A100** ``` operandfused_all <- THIS PR SwiGLUPackedFusedOp <- previous pr [--------------------------------------- swiglu_bw ---------------------------------------] | operandfused_all | eager | SwiGLUPackedFusedOp 1 threads: -------------------------------------------------------------------------------- b16 B=9456, I=1536, H=4096 | 2227.6 | 2708.3 | 2341.6 f16 B=9456, I=1536, H=4096 | 2337.5 | 2705.8 | 2339.1 f16.ac B=9456, I=1536, H=4096 | 2630.5 | 2998.5 | 2806.6 b16 B=4440, I=1536, H=4096 | 1177.9 | 1424.5 | 1246.4 f16 B=4440, I=1536, H=4096 | 1205.1 | 1418.8 | 1240.6 f16.ac B=4440, I=1536, H=4096 | 1409.0 | 1637.4 | 1541.7 b16 B=4728, I=1536, H=4096 | 1238.6 | 1493.5 | 1397.5 f16 B=4728, I=1536, H=4096 | 1274.8 | 1488.2 | 1392.7 f16.ac B=4728, I=1536, H=4096 | 1478.2 | 1710.3 | 1512.9 b16 B=4728, I=1536, H=1024 | 461.0 | 518.7 | 487.7 f16 B=4728, I=1536, H=1024 | 438.2 | 498.3 | 479.8 f16.ac B=4728, I=1536, H=1024 | 560.9 | 623.2 | 601.4 Times are in microseconds (us). 
``` [ghstack-poisoned] --- .gitmodules | 2 +- third_party/cutlass | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index ab23324aec..6d5c0f8734 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,4 @@ url = https://github.com/HazyResearch/flash-attention.git [submodule "third_party/cutlass"] path = third_party/cutlass - url = https://github.com/NVIDIA/cutlass.git + url = https://github.com/hwu36/cutlass.git diff --git a/third_party/cutlass b/third_party/cutlass index 1b4e24470a..3b30873c6c 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit 1b4e24470a369fc0dfc12987c2a43036207b4f04 +Subproject commit 3b30873c6c31c55cf0ef08a5e55d3f0435b0d958 From c827f1c9149c2bf4af7f2aa9c8b1b452ba0bb0c9 Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Thu, 3 Nov 2022 08:31:21 +0000 Subject: [PATCH 3/3] Update on "SwiGLU further optimization in MLP bw" ***PERFORMANCE A100** ``` operandfused_all <- THIS PR SwiGLUPackedFusedOp <- previous pr [--------------------------------------- swiglu_bw ---------------------------------------] | operandfused_all | eager | SwiGLUPackedFusedOp 1 threads: -------------------------------------------------------------------------------- b16 B=9456, I=1536, H=4096 | 2227.6 | 2708.3 | 2341.6 f16 B=9456, I=1536, H=4096 | 2337.5 | 2705.8 | 2339.1 f16.ac B=9456, I=1536, H=4096 | 2630.5 | 2998.5 | 2806.6 b16 B=4440, I=1536, H=4096 | 1177.9 | 1424.5 | 1246.4 f16 B=4440, I=1536, H=4096 | 1205.1 | 1418.8 | 1240.6 f16.ac B=4440, I=1536, H=4096 | 1409.0 | 1637.4 | 1541.7 b16 B=4728, I=1536, H=4096 | 1238.6 | 1493.5 | 1397.5 f16 B=4728, I=1536, H=4096 | 1274.8 | 1488.2 | 1392.7 f16.ac B=4728, I=1536, H=4096 | 1478.2 | 1710.3 | 1512.9 b16 B=4728, I=1536, H=1024 | 461.0 | 518.7 | 487.7 f16 B=4728, I=1536, H=1024 | 438.2 | 498.3 | 479.8 f16.ac B=4728, I=1536, H=1024 | 560.9 | 623.2 | 601.4 Times are in microseconds (us). ``` [ghstack-poisoned] --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index 3b30873c6c..f9b6c32dcb 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit 3b30873c6c31c55cf0ef08a5e55d3f0435b0d958 +Subproject commit f9b6c32dcb1b19cf184261a34118d271371f26f9
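For reference, below is a minimal sketch of driving the fused GEMM + operand-sum kernel from Python, mirroring the `_linear_bw` helper added in `xformers/ops/swiglu.py`. It is not part of the patch itself: it assumes a CUDA build of xformers that includes this change, that `import xformers.ops` loads the extension registering `xformers::gemm_fused_operand_sum`, and that an fp16-capable GPU is available. The call pattern, output shapes, and the eager reference expressions are taken from the diff above; the tensor sizes are in the spirit of the benchmark.

```python
import torch

import xformers.ops  # noqa: F401  -- assumed to load the extension that registers the op

# The kernel is CUDA-only, fp16/bf16-only, and uses 8-element aligned operands,
# so keep all dimensions multiples of 8.
B, in_features, out_features = 4728, 4096, 1536
x = torch.randn([B, in_features], device="cuda", dtype=torch.half)    # input of the linear
dy = torch.randn([B, out_features], device="cuda", dtype=torch.half)  # grad of its output

# Same call pattern as the _linear_bw helper: a single kernel writes both
#   dw = dy^T @ x   (weight gradient, the GEMM output)
#   db = dy.sum(0)  (bias gradient, i.e. the A operand reduced along K)
dw = torch.empty([out_features, in_features], device=dy.device, dtype=dy.dtype)
db = torch.empty([out_features], device=dy.device, dtype=dy.dtype)
torch.ops.xformers.gemm_fused_operand_sum(dy.transpose(-2, -1), x, dw, db)

# Sanity-check against the eager ops this patch replaces.
torch.testing.assert_close(dw, dy.transpose(-2, -1) @ x, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(db, dy.sum(0), atol=1e-2, rtol=1e-2)
```

Folding the operand reduction into the GEMM removes the standalone `dy.sum(0)` launches (the 25-50 µs per linear noted in the comments deleted by this patch), which is where the backward-pass savings in the tables above come from.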