From 59249e22087906395bd6e4900c7e38c664593a7d Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Tue, 25 Oct 2022 14:17:29 +0000 Subject: [PATCH] SwiGLU optimized fw/bw ghstack-source-id: 7b874c69561bf1756e95ccfad9407e4ea9d18e85 Pull Request resolved: https://github.com/facebookresearch/xformers/pull/490 --- setup.py | 22 +- tests/test_swiglu.py | 23 +- xformers/benchmarks/benchmark_swiglu.py | 21 +- .../attention/csrc/cuda/sddmm2_cuda.cu | 2 +- .../swiglu/cuda/43_dual_gemm/CMakeLists.txt | 35 + .../cuda/43_dual_gemm/device/dual_gemm.h | 457 ++++++++++ .../swiglu/cuda/43_dual_gemm/dual_gemm_run.h | 829 ++++++++++++++++++ .../cuda/43_dual_gemm/kernel/dual_gemm.h | 487 ++++++++++ .../swiglu/cuda/43_dual_gemm/test_run.h | 94 ++ .../43_dual_gemm/thread/left_silu_and_mul.h | 150 ++++ .../43_dual_gemm/threadblock/dual_epilogue.h | 430 +++++++++ .../43_dual_gemm/threadblock/dual_mma_base.h | 218 +++++ .../threadblock/dual_mma_multistage.h | 760 ++++++++++++++++ .../cuda/dual_gemm_silu_identity_mul.cu | 150 ++++ .../swiglu/cuda/gemm_fused_operand_sum.cu | 212 +++++ .../components/swiglu/cuda/silu_bw_fused.cu | 98 +++ xformers/components/swiglu/swiglu.cpp | 10 + xformers/ops/__init__.py | 2 +- xformers/ops/swiglu.py | 89 +- xformers/ops/unbind.py | 19 +- 20 files changed, 4072 insertions(+), 36 deletions(-) create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/CMakeLists.txt create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/device/dual_gemm.h create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/dual_gemm_run.h create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/kernel/dual_gemm.h create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/test_run.h create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/thread/left_silu_and_mul.h create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_epilogue.h create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_base.h create mode 100644 xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_multistage.h create mode 100644 xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu create mode 100644 xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu create mode 100644 xformers/components/swiglu/cuda/silu_bw_fused.cu create mode 100644 xformers/components/swiglu/swiglu.cpp diff --git a/setup.py b/setup.py index f3ddc246f6..6efb94b9f4 100644 --- a/setup.py +++ b/setup.py @@ -145,24 +145,21 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args): def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) - extensions_dir = os.path.join( - this_dir, "xformers", "components", "attention", "csrc" - ) + extensions_dir = os.path.join(this_dir, "xformers", "components") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) - source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob( - os.path.join(extensions_dir, "autograd", "*.cpp") - ) + source_cpu = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True) sources = main_file + source_cpu source_cuda = glob.glob( - os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True + os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True ) sputnik_dir = os.path.join(this_dir, "third_party", "sputnik") cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include") + cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples") if not 
os.path.exists(cutlass_dir): raise RuntimeError( f"CUTLASS submodule not found at {cutlass_dir}. " @@ -189,8 +186,15 @@ def get_extensions(): ) == "1": extension = CUDAExtension sources += source_cuda - include_dirs += [sputnik_dir, cutlass_dir] - nvcc_flags = ["-DHAS_PYTORCH", "--use_fast_math", "--generate-line-info"] + include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir] + nvcc_flags = [ + "-DHAS_PYTORCH", + "--use_fast_math", + "--generate-line-info", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--extended-lambda", + ] if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1": nvcc_flags.append("-DNDEBUG") nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", "")) diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 39538cf269..da9a1ca7a5 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -83,16 +83,19 @@ def generate_test_shapes(): # Add some random shapes r = random.Random(0) for _ in range(20): - shapes.append((r.randint(1, 5000), r.randint(1, 5000), r.randint(1, 512) * 8)) + shapes.append( + (r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8) + ) return shapes _test_shapes = list(generate_test_shapes()) _test_shapes_ids = [str(s) for s in _test_shapes] -_dtypes = [torch.float, torch.float16] +_dtypes = [torch.float16] @pytest.mark.parametrize("autocast", [False]) # TODO: Enable autocast testing +@pytest.mark.parametrize("pack_weights", [True, False]) @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes]) @pytest.mark.parametrize("device", _devices) @pytest.mark.parametrize( @@ -105,8 +108,11 @@ def test_forward_backward( device, dtype, autocast: bool, + pack_weights: bool, ): - FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-3} + torch.manual_seed(shape[0] * shape[1] * shape[2]) + FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2} + FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3} BACKWARD_ATOL = { torch.float: 3e-4, torch.half: 0.5, @@ -124,8 +130,11 @@ def test_forward_backward( inp_model_dtype = torch.float if autocast else dtype x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype) op = xsw._SwiGLUDecomposedOp + op = xsw.SwiGLUFusedOp - module = xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2]) + module = xsw._SwiGLUModule( + in_features=shape[1], hidden_features=shape[2], pack_weights=pack_weights + ) x_f32: Optional[torch.Tensor] ref_f32: Optional[torch.Tensor] module_f32: Optional[torch.nn.Module] @@ -150,11 +159,7 @@ def test_forward_backward( ref_f32 = ref assert_allclose( - out, - ref, - ref_f32, - "fw", - atol=FORWARD_ATOL[dtype], + out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype] ) # Backward diff --git a/xformers/benchmarks/benchmark_swiglu.py b/xformers/benchmarks/benchmark_swiglu.py index eedf75d17b..631f7c90bd 100644 --- a/xformers/benchmarks/benchmark_swiglu.py +++ b/xformers/benchmarks/benchmark_swiglu.py @@ -12,6 +12,7 @@ from utils import benchmark_main_helper import xformers.ops.swiglu as xsw +from xformers.ops import unbind as xunbind min_run_time = 0.5 device = torch.device("cuda") @@ -22,10 +23,13 @@ (9456, 1536, 4096), (4440, 1536, 4096), (4728, 1536, 4096), + # Some smaller shapes as well + (4728, 1536, 1024), ] -OP = xsw._SwiGLUDecomposedOp +# OP = xsw._SwiGLUDecomposedOp +OP = xsw.SwiGLUFusedOp def product_dict(**kwargs): @@ -38,7 +42,7 @@ def product_dict(**kwargs): CASES = list( product_dict( shape=SHAPES, - dtype=[torch.half, torch.float], + dtype=[torch.half], ) ) @@ -61,11 +65,16 @@ def 
benchmark_swiglu(shape, dtype): sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}" assert not autocast + + params = module._ordered_params_for_op() + # w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape]) + # params[0], params[2] = w1w2.unbind(dim=0) + yield benchmark.Timer( stmt="fn(x, *args)", globals={ "x": x, - "args": module._ordered_params_for_op(), + "args": params, "fn": partial(xsw.functional_swiglu, op=OP), }, label="swiglu_fw", @@ -103,7 +112,11 @@ def benchmark_swiglu_bw(shape, dtype): sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}" assert not autocast - out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=OP) + params = module._ordered_params_for_op() + w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape]).detach() + w1w2.requires_grad_() + params[0], params[2] = xunbind(w1w2, dim=0) + out = xsw.functional_swiglu(x, *params, op=OP) grad = torch.zeros_like(out) yield benchmark.Timer( stmt="out.backward(grad, retain_graph=True)", diff --git a/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu b/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu index a6179b6193..c41ac83ab4 100644 --- a/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu +++ b/xformers/components/attention/csrc/cuda/sddmm2_cuda.cu @@ -5,7 +5,7 @@ #include #include #include -#include "computeUtil.h" +#include "../computeUtil.h" namespace ge_spmm { diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/CMakeLists.txt b/xformers/components/swiglu/cuda/43_dual_gemm/CMakeLists.txt new file mode 100644 index 0000000000..e5eeb2eb42 --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
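+
+# Note: cutlass_example_add_executable is supplied by the CUTLASS examples CMake build; this
+# file appears to mirror the upstream example layout so the dual GEMM can also be built as a
+# standalone CUTLASS example (43_dual_gemm).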
+ + + +cutlass_example_add_executable( + 43_dual_gemm + dual_gemm.cu + ) diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/device/dual_gemm.h b/xformers/components/swiglu/cuda/43_dual_gemm/device/dual_gemm.h new file mode 100644 index 0000000000..0211f50e36 --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/device/dual_gemm.h @@ -0,0 +1,457 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Performs a dual gemm: +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue0(X @ B1, C1) +D2 = element_wise(D0, D1) +``` +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +#include "../kernel/dual_gemm.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp0_, + typename EpilogueOutputOp1_, + typename EpilogueOutputOp2_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + bool StoreD0 = true, + bool StoreD1 = true, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class DualGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using EpilogueOutputOp2 = EpilogueOutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static 
int const kAlignmentC = EpilogueOutputOp1::kCount; + static bool const kSplitKSerial = SplitKSerial; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + using LayoutScaleBias = layout::RowMajor; + /// Define the kernel + /// Define the threadblock-scoped matrix multiply-accumulate + static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented"); + static_assert(kStages >= 3, "Only multistage is implemented"); + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using DualMma = threadblock::DualMmaMultistage< + typename Mma::Shape, + typename Mma::IteratorA, + typename Mma::SmemIteratorA, + Mma::kCacheOpA, + typename Mma::IteratorB, + typename Mma::SmemIteratorB, + Mma::kCacheOpB, + typename Mma::ElementC, + typename Mma::LayoutC, + typename Mma::Policy, + Mma::kStages, + SharedMemoryClearOption::kNone + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue0 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0, + EpilogueOutputOp0::kCount>::Epilogue; + using Epilogue1 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using DualGemmKernel = kernel::DualGemm< + DualMma, + Epilogue0, Epilogue1, EpilogueOutputOp2, + ThreadblockSwizzle, kSplitKSerial, + kStoreD0, kStoreD1>; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A0; + TensorRef ref_B0; + TensorRef ref_C0; + TensorRef ref_D0; + TensorRef ref_B1; + TensorRef ref_C1; + TensorRef ref_D1; + TensorRef ref_D2; + typename EpilogueOutputOp0::Params epilogue0; + typename EpilogueOutputOp1::Params epilogue1; + typename EpilogueOutputOp2::Params epilogue2; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A0_, + TensorRef ref_B0_, + TensorRef ref_C0_, + TensorRef ref_D0_, + TensorRef ref_B1_, + TensorRef ref_C1_, + TensorRef ref_D1_, + TensorRef ref_D2_, + typename EpilogueOutputOp0::Params epilogue0_ = + typename EpilogueOutputOp0::Params(), + typename EpilogueOutputOp1::Params epilogue1_ = + typename EpilogueOutputOp1::Params(), + typename EpilogueOutputOp2::Params epilogue2_ = + typename EpilogueOutputOp2::Params(), + int split_k_slices_ = 1 + ): + problem_size(problem_size_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_D0(ref_D0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + ref_D2(ref_D2_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + epilogue2(epilogue2_), + split_k_slices(split_k_slices_) { + + } + }; + +private: + + /// Kernel parameters object + typename DualGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. 
+ DualGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + if (kStoreD0 != (args.ref_D0.data() != nullptr)) { + return Status::kErrorInternal; + } + if (kStoreD1 != (args.ref_D1.data() != nullptr)) { + return Status::kErrorInternal; + } + + Status status = DualGemmKernel::can_implement( + args.problem_size, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2 + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename DualGemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2, + args.epilogue0, + args.epilogue1, + args.epilogue2, + static_cast(workspace), + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); + params_.ref_D0.reset(args.ref_D0.data()); + params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D1.data()); + params_.ref_D2.reset(args.ref_D2.data()); + params_.output_op_0 = args.epilogue0; + params_.output_op_1 = args.epilogue1; + params_.output_op_2 = args.epilogue2; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
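+  /// A single fused kernel computes both GEMMs (sharing operand A) and the element-wise
+  /// combination; when the multistage pipeline needs more than 48KB of shared memory, the
+  /// launch first opts in to the larger dynamic shared-memory limit via cudaFuncSetAttribute.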
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(DualGemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/dual_gemm_run.h b/xformers/components/swiglu/cuda/43_dual_gemm/dual_gemm_run.h new file mode 100644 index 0000000000..70dbb9ff39 --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/dual_gemm_run.h @@ -0,0 +1,829 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template < + typename OutputOp, + typename Element, + typename Layout> +struct TensorEpilogueForEachFunc { + /// View type + using TensorView = cutlass::TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view_x0; + TensorView view_x1; + TensorView view_y; + OutputOp output_op; + + + // + // Methods + // + + Params( + TensorView view_x0_ = TensorView(), + TensorView view_x1_ = TensorView(), + TensorView view_y_ = TensorView(), + OutputOp output_op_ = OutputOp(typename OutputOp::Params{}) + ): + view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { + } + }; + + Params params; + + CUTLASS_DEVICE + TensorEpilogueForEachFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + Element const & x0 = params.view_x0.at(coord); + Element const & x1 = params.view_x1.at(coord); + Element& y = params.view_y.at(coord); + y = params.output_op(x0, x1); + } +}; + +template < + typename OutputOp, + typename Element, + typename Layout> +void TensorEpilogueForEach( + cutlass::TensorView x0, + cutlass::TensorView x1, + cutlass::TensorView y) { + + using Func = TensorEpilogueForEachFunc; + using Params = typename Func::Params; + + cutlass::reference::device::TensorForEach( + y.extent(), + Params(x0, x1, y) + ); +} + +//////////////////////////////////////////////////////////////////////////////// + +template +struct NonFusedDualGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + NonFusedDualGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + 
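+      // Uniform fill in [-2, 2] with 0 fractional bits (i.e. small integer values), so the
+      // device results can be compared bit-exactly against the reference GEMM, even in fp16.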
cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + reference_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D1.sync_device(); + + 
// + // Initialize the GEMM operator + // + + int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1; + typename Gemm0::Arguments arguments_0{ + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0}, + split_k_slices + }; + + split_k_slices = Gemm1::kSplitKSerial ? 2 : 1; + typename Gemm1::Arguments arguments_1{ + problem_size, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1}, + split_k_slices + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + // Allocate workspace memory + cutlass::device_memory::allocation workspace0(gemm_op_0.get_workspace_size(arguments_0)); + cutlass::device_memory::allocation workspace1(gemm_op_1.get_workspace_size(arguments_1)); + + cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get()); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1, workspace1.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } +#ifdef IS_PROFILING + return true; +#endif + // + // Run the GEMM + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size, + alpha1, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 
0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed0 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed0); + + bool passed1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed1); + if (!passed0 || !passed1) { + + std::stringstream fname; + + fname << "error_DualGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed0 && passed1; + } +}; + +template +struct DualFusedGemmRun +{ + + using DualGemm = DualGemm_; + using ElementAccumulator = typename DualGemm::ElementAccumulator; + using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; + using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + DualFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(1), + ElementCompute 
alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(1), + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename DualGemm::ElementA, + typename DualGemm::LayoutA> tensor_A0(problem_size.mk()); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB> tensor_B0(problem_size.kn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C0(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D0(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D0(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB> tensor_B1(problem_size.kn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C1(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D1(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D2(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D1(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D2(problem_size.mn()); + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + tensor_D2.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D2.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D0.sync_device(); + tensor_D1.sync_device(); + tensor_D2.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + reference_D2.sync_device(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = DualGemm::kSplitKSerial ? 
2 : 1; + typename cutlass::TensorRef nullptr_ref{}; + decltype(nullptr_ref) ref_B0, ref_B1; + if (beta0 != ElementCompute(0)) { + ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + if (beta1 != ElementCompute(0)) { + ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + typename DualGemm::Arguments arguments{ + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + ref_B0, + DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, + tensor_B1.device_ref(), + ref_B1, + DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref, + tensor_D2.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + {}, + split_k_slices + }; + + DualGemm b2b_gemm_op; + + cutlass::device_memory::allocation workspace(b2b_gemm_op.get_workspace_size(arguments)); + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + CUTLASS_CHECK(status); + + status = b2b_gemm_op.initialize(arguments, workspace.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + +#ifdef IS_PROFILING + return true; +#endif + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + tensor_D2.sync_host(); + + // + // Verify + // + + cutlass::reference::device::Gemm< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator, ElementAccumulator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB, + typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute, + ElementAccumulator, typename DualGemm::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size, + alpha1, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + reference_D2.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0); + + bool passed_out0 = true; + if (DualGemm::kStoreD0) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + passed_out0 = 
cutlass::reference::host::TensorEquals( + reference_D0.host_view(), + tensor_D0.host_view()); + } + CHECK_TRUE(passed_out0); + + bool passed_out1 = true; + if (DualGemm::kStoreD1) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + passed_out1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + } + CHECK_TRUE(passed_out1); + + bool passed_out2 = cutlass::reference::host::TensorEquals( + reference_D2.host_view(), + tensor_D2.host_view()); + CHECK_TRUE(passed_out2); + + bool passed = passed_out0 && passed_out1 && passed_out2; + if (!passed) + { + + std::stringstream fname; + + fname << "error_DualGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference0 =\n" << reference_D0.host_view() + << "\nComputed0 =\n" << tensor_D0.host_view() + << "\n\nReference1 =\n" << reference_D1.host_view() + << "\nComputed1 =\n" << tensor_D1.host_view() + << "\n\nReference2 =\n" << reference_D2.host_view() + << "\nComputed2 =\n" << tensor_D2.host_view(); + } + //std::cout << "A0 " << tensor_A0.host_view() << std::endl; + // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl; + // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl; + //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/kernel/dual_gemm.h b/xformers/components/swiglu/cuda/43_dual_gemm/kernel/dual_gemm.h new file mode 100644 index 0000000000..7a66ca750a --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/kernel/dual_gemm.h @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "../threadblock/dual_mma_multistage.h" +#include "../threadblock/dual_epilogue.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue0_, ///! Epilogue + typename Epilogue1_, ///! Epilogue + typename OutputOp2_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. + bool StoreD0, + bool StoreD1 +> +struct DualGemm { + + using DualMma = DualMma_; + + using Epilogue0 = Epilogue0_; + using Epilogue1 = Epilogue1_; + using OutputOp0 = typename Epilogue0::OutputOp; + using OutputOp1 = typename Epilogue1::OutputOp; + using OutputOp2 = OutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static constexpr bool kStoreD0 = StoreD0; + static constexpr bool kStoreD1 = StoreD1; + + using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue< + typename Epilogue0::Shape, + typename Epilogue0::WarpMmaOperator, + Epilogue0::kPartitionsK, + typename Epilogue0::OutputTileIterator, + typename Epilogue0::AccumulatorFragmentIterator, + typename Epilogue0::WarpTileIterator, + typename Epilogue0::SharedLoadIterator, + OutputOp0, + OutputOp1, + OutputOp2, + typename Epilogue0::Padding, + kStoreD0, + kStoreD1, + Epilogue0::kFragmentsPerIteration, + true // IterationsUnroll + >; + + static bool const kSplitKSerial = SplitKSerial; + static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), + "Split-K serial requires buffers for D0/D1 for reduction"); + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename DualMma::WarpCount; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + // Mma0 + typename DualMma::IteratorA::Params params_A0; + typename DualMma::IteratorA::TensorRef ref_A0; + typename DualMma::IteratorB::Params params_B0; + typename DualMma::IteratorB::TensorRef ref_B0; + typename Epilogue0::OutputTileIterator::Params params_C0; + typename Epilogue0::OutputTileIterator::TensorRef ref_C0; + typename Epilogue0::OutputTileIterator::Params params_D0; + typename Epilogue0::OutputTileIterator::TensorRef ref_D0; + typename OutputOp0::Params output_op_0; + + // Mma1 + typename 
DualMma::IteratorB::Params params_B1; + typename DualMma::IteratorB::TensorRef ref_B1; + typename Epilogue1::OutputTileIterator::Params params_C1; + typename Epilogue1::OutputTileIterator::TensorRef ref_C1; + typename Epilogue1::OutputTileIterator::Params params_D1; + typename Epilogue1::OutputTileIterator::TensorRef ref_D1; + typename OutputOp1::Params output_op_1; + + typename Epilogue1::OutputTileIterator::Params params_D2; + typename Epilogue1::OutputTileIterator::TensorRef ref_D2; + typename OutputOp2::Params output_op_2; + + int *semaphore; + int gemm_k_size; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + // Mma0: D0 = A @ B0 + C0 + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + // Mma1: D1 = A @ B1 + C1 + typename DualMma::IteratorB::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + + typename Epilogue1::OutputTileIterator::TensorRef ref_D2, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(), + int *workspace = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + // Mma0 + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + params_D0(ref_D0.layout()), + ref_D0(ref_D0), + // Mma1 + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + params_D2(ref_D2.layout()), + ref_D2(ref_D2), + output_op_0(output_op_0), + output_op_1(output_op_1), + output_op_2(output_op_2) { + + int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename DualMma::SharedStorage main_loop; + typename DualEpilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DualGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + typename DualMma::IteratorB::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D2) { + + static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess; 
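+
+    // Every operand reference (A0, B0, C0/D0, B1, C1/D1, D2) must satisfy the vectorized
+    // access alignment of its iterator; misaligned pointers or strides are rejected here
+    // rather than faulting inside the kernel.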
+ + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D2, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * DualMma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B0{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_B1{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename DualMma::IteratorA iterator_A0( + params.params_A0, + params.ref_A0.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A0); + + typename DualMma::IteratorB iterator_B0( + params.params_B0, + params.ref_B0.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B0); + + typename DualMma::IteratorB iterator_B1( + params.params_B1, + params.ref_B1.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B1); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + + // Construct thread-scoped matrix multiply + typename DualMma::FragmentC accum0; + typename DualMma::FragmentC accum1; + accum0.clear(); + accum1.clear(); + + DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accum0, accum1, + iterator_A0, iterator_B0, iterator_B1, + accum0, accum1); + } + + // + // Epilogue + // + + OutputOp0 output_op_0(params.output_op_0); + OutputOp1 output_op_1(params.output_op_1); + OutputOp2 output_op_2(params.output_op_2); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * DualMma::Shape::kM, + threadblock_tile_offset.n() * DualMma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue0::OutputTileIterator iterator_C0( + params.params_C0, + params.ref_C0.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_C1( + params.params_C1, + params.ref_C1.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue0::OutputTileIterator iterator_D0( + params.params_D0, + params.ref_D0.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D1( + params.params_D1, + params.ref_D1.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D2( + params.params_D2, + params.ref_D2.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + DualEpilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C0 = iterator_D0; + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. 
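Before the epilogue runs, the serial split-K handshake above (Semaphore::fetch, then wait, with set_k_partition marking each slice's position in the reduction) orders the K slices into a chain: slice k proceeds once the counter reaches k, accumulates its partial D0/D1 on top of the previous slice's output, and afterwards publishes k + 1; only the final slice writes D2 and resets the counter to zero. A minimal host-side model of that ordering (grid_k is an illustrative value, not taken from the kernel):

#include <cassert>
#include <cstdio>

int main() {
  int const grid_k = 4;  // hypothetical grid_tiled_shape.k()
  int semaphore = 0;     // the workspace backing the semaphore starts zeroed

  for (int k = 0; k < grid_k; ++k) {
    assert(semaphore == k);        // Semaphore::wait(k): run strictly in slice order
    bool const last = (k + 1 == grid_k);
    std::printf("slice %d: accumulate D0/D1%s\n", k, last ? ", write D2" : "");
    semaphore = last ? 0 : k + 1;  // Semaphore::release(lock)
  }
  return 0;
}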
+ typename Epilogue0::OutputTileIterator source_iters[] = { + iterator_C0, iterator_C1 + }; + const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1); + epilogue( + output_op_0, output_op_1, output_op_2, + iterator_D0, iterator_D1, iterator_D2, + accum0, accum1, + source_iters, + writeToD2 + ); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/test_run.h b/xformers/components/swiglu/cuda/43_dual_gemm/test_run.h new file mode 100644 index 0000000000..4c0787fcfb --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/test_run.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include + +// Run tests on GPUs + +int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { + + bool supported = false; + + int arch_major = arch / 10; + int arch_minor = arch - arch / 10 * 10; + + if(arch_major >= 8) { + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. 
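Taken end to end, the kernel completed above computes two GEMMs that share the A operand and combines them element-wise in the dual epilogue. For the SwiGLU instantiation set up later in dual_gemm_silu_identity_mul.cu (LinearCombination for OutputOp0/OutputOp1, LeftSiLUAndMul for OutputOp2), the intended result is equivalent to the unfused ATen reference below. This is a sketch of the semantics only, assuming the bias rows are what gets bound to the C0/C1 source operands; the helper name is made up for illustration:

#include <ATen/ATen.h>
#include <tuple>

// Unfused reference for D0, D1, D2 as produced by the fused dual GEMM.
std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_reference(
    const at::Tensor& x,   // [B, I]
    const at::Tensor& w0,  // [H, I]
    const at::Tensor& b0,  // [H]
    const at::Tensor& w1,  // [H, I]
    const at::Tensor& b1   // [H]
) {
  at::Tensor d0 = at::addmm(b0, x, w0.t());  // Mma0 + OutputOp0
  at::Tensor d1 = at::addmm(b1, x, w1.t());  // Mma1 + OutputOp1
  at::Tensor d2 = at::silu(d0) * d1;         // OutputOp2: SiLU(left) * right
  return std::make_tuple(d0, d1, d2);
}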
+ if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { + supported = true; + } + } + else if(arch_major >= 7) { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { + supported = true; + } + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major == arch_major && props.minor == arch_minor)) { + supported = false; + } + + if (!supported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + std::cout << "This example isn't supported on current architecture" << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "Device: " << props.name << std::endl; + std::cout << "Arch: SM" << arch << std::endl; + std::cout << "Test: " << test_name << std::endl; + for(auto func : test_funcs) { + pass &= func(); + } + + + if(pass) + return 0; + else + return -1; + +} diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/thread/left_silu_and_mul.h b/xformers/components/swiglu/cuda/43_dual_gemm/thread/left_silu_and_mul.h new file mode 100644 index 0000000000..b3c0560a9e --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/thread/left_silu_and_mul.h @@ -0,0 +1,150 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LeftSiLUAndMul { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params{}; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const ¶ms) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + return compute_to_output(mul(silu_lhs, converted_rhs)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()( + ElementAccumulator const& lhs, + ElementAccumulator const& rhs + ) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + return ElementOutput(mul(silu_lhs, convert_rhs)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_epilogue.h 
b/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_epilogue.h new file mode 100644 index 0000000000..d4b8ef7689 --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_epilogue.h @@ -0,0 +1,430 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + ///< Output operator + typename OutputOp0_, + typename OutputOp1_, + typename OutputOp2_, + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + bool StoreD0 = true, + bool StoreD1 = true, + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class DualEpilogue { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp0 = OutputOp0_; + using OutputOp1 = OutputOp1_; + using OutputOp2 = OutputOp2_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type 
used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + struct SharedStorage { + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Logical shape of the shared memory tile written to by all warps. + using Shape = typename Base::Shape; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = typename Base::SharedStorage::StorageShape; + + // + // Data members + // + + AlignedBuffer storage[2]; + + // + // Methods + // + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference(int i) { + return TensorRef( + storage[i].data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; + +public: + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator0_; + SharedLoadIterator shared_load_iterator1_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator0_; + WarpTileIterator warp_tile_iterator1_; + +public: + + /// Constructor + CUTLASS_DEVICE + DualEpilogue( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + shared_load_iterator0_(shared_storage.reference(0), thread_idx), + shared_load_iterator1_(shared_storage.reference(1), thread_idx), + warp_tile_iterator0_(shared_storage.reference(0), lane_idx), + warp_tile_iterator1_(shared_storage.reference(1), lane_idx) + { + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator0_.add_tile_offset(warp_offset); + warp_tile_iterator1_.add_tile_offset(warp_offset); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + OutputTileIterator dest0, + OutputTileIterator dest1, + OutputTileIterator dest2, + AccumulatorTile const &accumulator0, + AccumulatorTile const &accumulator1, + OutputTileIterator source_iterator[2], + bool writeToD2 // true if it's the final split-k + ) { + // TODO: Implement when no source is needed + + typename OutputTileIterator::Fragment source_fragment[2]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_fragment[i].clear(); + } + + // + // Iterator 
over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1}; + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_iterator[i].load(source_fragment[i]); + ++source_iterator[i]; + } + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[0], this->warp_tile_iterator0_); + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[1], this->warp_tile_iterator1_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK]; + typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK]; + + shared_load_iterator0_.load(aligned_accum_fragment0[0]); + shared_load_iterator1_.load(aligned_accum_fragment1[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator0_.load(aligned_accum_fragment0[i]); + shared_load_iterator1_.load(aligned_accum_fragment1[i]); + aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]); + aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]); + } + + shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment[3]; + + apply_output_operator_(output_fragment, + output_op0, output_op1, output_op2, + aligned_accum_fragment0[0], aligned_accum_fragment1[0], + source_fragment); + + + // + // Store the final result + // + + if (kStoreD0) { + dest0.store(output_fragment[0]); + ++dest0; + } + if (kStoreD1) { + dest1.store(output_fragment[1]); + ++dest1; + } + if (writeToD2) { + dest2.store(output_fragment[2]); + ++dest2; + } + } + } + +private: + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment (&output_fragment)[3], + OutputOp0 const &output_op0, + OutputOp1 const 
&output_op1, + OutputOp2 const &output_op2, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment0, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment1, + typename OutputTileIterator::Fragment const (&source_fragment)[2]) { + + OutputAccessType* output_frag_ptr[3] = { + reinterpret_cast(&output_fragment[0]), + reinterpret_cast(&output_fragment[1]), + reinterpret_cast(&output_fragment[2]) + }; + + AccumulatorAccessType const *compute_frag_ptr[2] = { + reinterpret_cast(&aligned_accum_fragment0), + reinterpret_cast(&aligned_accum_fragment1) + }; + + OutputAccessType const *source_frag_ptr[2] = { + reinterpret_cast(&source_fragment[0]), + reinterpret_cast(&source_fragment[1]) + }; + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operators + output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); + output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); + output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_base.h b/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_base.h new file mode 100644 index 0000000000..975eb137bc --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_base.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DualMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B0; + AlignedBuffer operand_B1; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B0_ref() { + return TensorRefB{operand_B0.data(), LayoutB()}; + } + CUTLASS_HOST_DEVICE + TensorRefB operand_B1_ref() { + return TensorRefB{operand_B1.data(), LayoutB()}; + } + }; + + 
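For a sense of scale, this shared storage holds one A buffer and two B buffers. Assuming the usual CUTLASS MmaBase shapes of Shape::kM x (Shape::kK * kStages) for A and (Shape::kK * kStages) x Shape::kN for each B, and ignoring the padding terms, the tile configuration chosen later in dual_gemm_silu_identity_mul.cu (128x64x32 threadblock, 3 stages, with half-precision operands taken as an example) works out to roughly 48 KiB:

#include <cstddef>
#include <cstdio>

int main() {
  int const kM = 128, kN = 64, kK = 32;  // ThreadblockShape
  int const kStages = 3;
  int const kBytesPerElement = 2;        // cutlass::half_t

  std::size_t const a_bytes = std::size_t(kM) * (kK * kStages) * kBytesPerElement;  // operand_A
  std::size_t const b_bytes = std::size_t(kK * kStages) * kN * kBytesPerElement;    // operand_B0 (B1 is identical)
  std::size_t const total = a_bytes + 2 * b_bytes;

  std::printf("A: %zu KiB, B0+B1: %zu KiB, total: %zu KiB\n",
              a_bytes / 1024, (2 * b_bytes) / 1024, total / 1024);  // 24 + 24 = 48 KiB
  return 0;
}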
protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B0_; + typename Operator::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_multistage.h b/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_multistage.h new file mode 100644 index 0000000000..dae6ec4d7b --- /dev/null +++ b/xformers/components/swiglu/cuda/43_dual_gemm/threadblock/dual_mma_multistage.h @@ -0,0 +1,760 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "dual_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class DualMmaMultistage : + public DualMmaBase { +public: + ///< Base class + using Base = DualMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B0_; + SmemIteratorB smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx), + smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + 
this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B0.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + iterator_B1.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B); + this->smem_iterator_B1_.set_iteration_index(group_start_B); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum0, + FragmentC &accum1, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B0, + IteratorB iterator_B1, + ///< initial value of accumulator + FragmentC const &src_accum0, + FragmentC const &src_accum1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + 
typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B0.set_iteration_index(0); + iterator_B1.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum0 = src_accum0; + accum1 = src_accum1; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + typename IteratorB::AccessType zero_B; + zero_B.clear(); + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_); + last_smem_iterator_B0.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B0.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B0; + } + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_); + last_smem_iterator_B1.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B1.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B1; + } + } + + // Waits until stages up to the previous (kStages-2)th stage have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B0[2]; + WarpLoadedFragmentB warp_loaded_frag_B1[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B0[2]; + WarpTransformedFragmentB warp_transformed_frag_B1[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
+ plus plus_accum; + + FragmentC tmp_accum0, tmp_accum1; + + if (platform::is_same::value + || platform::is_same::value) { + + tmp_accum0.clear(); + tmp_accum1.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + } + + if (platform::is_same::value + || platform::is_same::value) { + + warp_mma( + tmp_accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + tmp_accum0 + ); + warp_mma( + tmp_accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum1 + ); + + if (warp_mma_k == 0) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + tmp_accum0.clear(); + tmp_accum1.clear(); + } + } else { + warp_mma( + accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + warp_mma( + accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum1 + ); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until stages up to the previous (kStages-2)th stage have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + } + + if (platform::is_same::value + || platform::is_same::value) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu b/xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu new file mode 100644 index 0000000000..79f1ac3413 --- /dev/null +++ b/xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu @@ -0,0 +1,150 @@ +#include +#include +#include +#include +#include + +#include "43_dual_gemm/device/dual_gemm.h" +#include "43_dual_gemm/thread/left_silu_and_mul.h" + +namespace { +template +std::tuple dual_gemm_silu_identity_mul_( + const at::Tensor& x, + const at::Tensor& w0, + const at::Tensor& b0, + const at::Tensor& w1, + const at::Tensor& b1 +) { + TORCH_CHECK(x.stride(-1) == 1); + TORCH_CHECK(w0.stride(-1) == 1); + TORCH_CHECK(w1.stride(-1) == 1); + + at::cuda::CUDAGuard device_guard(x.device()); + cudaStream_t stream = 
at::cuda::getCurrentCUDAStream(); + + int64_t B = x.size(0); + int64_t I = x.size(1); + int64_t H = w0.size(0); + + at::Tensor d0 = at::empty({B, H}, x.options()); + at::Tensor d1 = at::empty({B, H}, x.options()); + at::Tensor d2 = at::empty({B, H}, x.options()); + + // templati-ze the cutlass kernel + cutlass::gemm::GemmCoord problem_size(B, H, I); + + constexpr int kStages = 3; + constexpr bool kSplitKSerial = false; + + using ElementOutput = scalar_t; + using ElementAccumulator = float; + using ElementCompute = float; + using EpilogueOutputOp01 = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >; + using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementOutput, + ElementCompute + >; + + const ElementCompute alpha0 = ElementCompute(1); + const ElementCompute beta0 = ElementCompute(1); + const ElementCompute alpha1 = ElementCompute(1); + const ElementCompute beta1 = ElementCompute(1); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + scalar_t, + cutlass::layout::RowMajor, + scalar_t, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp01, + EpilogueOutputOp01, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + int split_k_slices = DualGemm::kSplitKSerial ? 
2 : 1; + using RefA = typename cutlass::TensorRef; + using RefB = typename cutlass::TensorRef; + using RefC = typename cutlass::TensorRef; + typename DualGemm::Arguments arguments{ + problem_size, + RefA{(scalar_t*)x.data_ptr(), typename DualGemm::LayoutA::Stride(x.stride(0))}, + RefB{(scalar_t*)w0.data_ptr(), typename DualGemm::LayoutB::Stride(w0.stride(0))}, + RefC{(scalar_t*)b0.data_ptr(), typename DualGemm::LayoutC::Stride(0)}, + RefC{(scalar_t*)d0.data_ptr(), typename DualGemm::LayoutC::Stride(d0.stride(0))}, + RefB{(scalar_t*)w1.data_ptr(), typename DualGemm::LayoutB::Stride(w1.stride(0))}, + RefC{(scalar_t*)b1.data_ptr(), typename DualGemm::LayoutC::Stride(0)}, + RefC{(scalar_t*)d1.data_ptr(), typename DualGemm::LayoutC::Stride(d1.stride(0))}, + RefC{(scalar_t*)d2.data_ptr(), typename DualGemm::LayoutC::Stride(d2.stride(0))}, + typename DualGemm::EpilogueOutputOp0::Params{alpha0, beta0}, + typename DualGemm::EpilogueOutputOp1::Params{alpha1, beta1}, + typename DualGemm::EpilogueOutputOp2::Params{}, + split_k_slices + }; + DualGemm dual_gemm; + at::Tensor workspace = at::empty({int64_t(dual_gemm.get_workspace_size(arguments))}, x.options().dtype(at::ScalarType::Byte)); + cutlass::Status status = dual_gemm.can_implement(arguments); + TORCH_CHECK(status == cutlass::Status::kSuccess, "not supported by this kernel"); + status = dual_gemm.initialize(arguments, (uint8_t*)workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel initialize failed"); + status = dual_gemm(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel run failed"); + return std::make_tuple(d0, d1, d2); +} + +std::tuple dual_gemm_silu_identity_mul( + const at::Tensor& x, + const at::Tensor& w0, + const at::Tensor& b0, + const at::Tensor& w1, + const at::Tensor& b1 +) { + // TODO: Check all params. This would take a lot of lines of code... 
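+  /*
+  Reference semantics (PyTorch sketch added for documentation only; the Python
+  names below are illustrative and mirror the epilogues selected above):
+
+  def dual_gemm_silu_identity_mul(x, w0, b0, w1, b1):
+      d0 = x @ w0.t() + b0   # first GEMM, LinearCombination epilogue (bias, no beta scaling)
+      d1 = x @ w1.t() + b1   # second GEMM, shares the A operand (x)
+      d2 = F.silu(d0) * d1   # fused LeftSiLUAndMul epilogue
+      return d0, d1, d2      # d0/d1 are stored (kStoreD0/kStoreD1) for the backward pass
+  */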
+ TORCH_CHECK(x.dim() == 2); + TORCH_CHECK(w0.dim() == 2); + TORCH_CHECK(w1.dim() == 2); + + #define FWD_PARAMS x,w0,b0,w1,b1 + + if (x.scalar_type() == at::ScalarType::Half) { + return dual_gemm_silu_identity_mul_(FWD_PARAMS); + } else { + TORCH_CHECK(x.scalar_type() == at::ScalarType::BFloat16, "Only supports bf16/f16"); + return dual_gemm_silu_identity_mul_(FWD_PARAMS); + } +} +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::dual_gemm_silu_identity_mul"), + TORCH_FN(dual_gemm_silu_identity_mul)); +} diff --git a/xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu b/xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu new file mode 100644 index 0000000000..8ac77a4187 --- /dev/null +++ b/xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu @@ -0,0 +1,212 @@ +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/kernel/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/matrix_coord.h" + +#define WORKAROUND_CUTLASS_BUG + +namespace { +template +void gemm_fused_operand_sum_( + const at::Tensor& a, // col-major + const at::Tensor& b, // row-major + at::Tensor& out_mm, // row-major + at::Tensor& out_sum // row-major +) { + at::cuda::CUDAGuard device_guard(a.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t M = a.size(0); + int64_t N = b.size(0); + int64_t K = a.size(1); + + // templati-ze the cutlass kernel + cutlass::gemm::GemmCoord problem_size(M, N, K); + using ElementAccumulator = float; // Data type of accumulator + using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation + using ElementInputA = scalar_t; + using ElementInputB = scalar_t; + using ElementOutput = scalar_t; + + using LayoutInputA = cutlass::layout::ColumnMajor; + TORCH_CHECK(a.stride(0) == 1); + using LayoutInputB = cutlass::layout::RowMajor; + TORCH_CHECK(b.stride(1) == 1); + using LayoutOutput = cutlass::layout::RowMajor; + TORCH_CHECK(out_mm.stride(1) == 1); + + // Layout of the output vector + using LayoutGemmKReduction = cutlass::layout::PitchLinear; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm80; + + // This code section describes the tile size a thread block will compute + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape + + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + + // Number of pipelines you want to use + constexpr int NumStages = 4; + + // Reduce A or B operand along the K dimension +#ifdef WORKAROUND_CUTLASS_BUG + constexpr bool ReduceKForA = false; +#else + constexpr bool ReduceKForA = true; +#endif + + // This code section describes the epilogue part of the 
kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction< + ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8, + ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + ReduceKForA, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd + >::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Below is the reduction kernel used in the case of parallel split-k + using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; + + using ReduceOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + ElementOutput, + EpilogueOp::kCount + >; + + using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< + ReduceGemmSplitKShape, + EpilogueOp, + ReduceOp + >; + + using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK; + + using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;; + + // This code section describes the epilogue part of the kernel, we use default value + using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; + + using ReduceVectorSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< + ReduceVectorSplitKShape, + DummyEpilogueOp, + ReduceOp + >; + + using ReduceVectorSplitK = cutlass::reduction::device::ReduceSplitK; + auto alpha = ElementComputeEpilogue(1); + auto beta = ElementComputeEpilogue(0); + + using RefA = cutlass::TensorRef; + using RefB = cutlass::TensorRef; + using RefC = cutlass::TensorRef; + int reduce_vector_length = ReduceKForA ? 
problem_size.m() : problem_size.n(); + int split_k_slices = 2; + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + int(split_k_slices), + {alpha, beta}, + (ElementInputA const*)a.data_ptr(), + (ElementInputB const*)b.data_ptr(), + (ElementOutput*)nullptr, + (ElementOutput*)out_mm.data_ptr(), + (ElementOutput*)out_sum.data_ptr(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + reduce_vector_length, + a.stride(1), + b.stride(0), + int64_t(0), // bias + out_mm.stride(0), + int64_t(1) // out_sum + }; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + size_t workspace_size = Gemm::get_workspace_size(arguments); + at::Tensor workspace = at::empty({int64_t(gemm_op.get_workspace_size(arguments))}, a.options().dtype(at::ScalarType::Byte)); + cutlass::Status status = gemm_op.can_implement(arguments); + TORCH_CHECK(status == cutlass::Status::kSuccess, "not supported by this kernel"); + status = gemm_op.initialize(arguments, (uint8_t*)workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel initialize failed"); + status = gemm_op(stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel run failed"); +} + +std::tuple gemm_fused_operand_sum( + const at::Tensor& a, + const at::Tensor& b, + at::Tensor& out_mm, + at::Tensor& out_sum +) { + // TODO: Check all params. This would take a lot of lines of code... + TORCH_CHECK(a.dim() == 2); + TORCH_CHECK(b.dim() == 2); + TORCH_CHECK(out_mm.dim() == 2); + + #define FWD_PARAMS a,b,out_mm,out_sum + + if (a.scalar_type() == at::ScalarType::Half) { + gemm_fused_operand_sum_(FWD_PARAMS); + } else { + TORCH_CHECK(a.scalar_type() == at::ScalarType::BFloat16, "Only supports bf16/f16"); + gemm_fused_operand_sum_(FWD_PARAMS); + } + return std::make_tuple(out_mm, out_sum); +} +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::gemm_fused_operand_sum"), + TORCH_FN(gemm_fused_operand_sum)); +} diff --git a/xformers/components/swiglu/cuda/silu_bw_fused.cu b/xformers/components/swiglu/cuda/silu_bw_fused.cu new file mode 100644 index 0000000000..b03a2e8e07 --- /dev/null +++ b/xformers/components/swiglu/cuda/silu_bw_fused.cu @@ -0,0 +1,98 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace { +/* +Computes the following: + +def silu_bw_fused(x1, x2, dx4): + x3 = F.silu(x1) + dx3 = dx4 * x2 + dx2 = dx4 * x3 + x4 = x2 * x3 # checkpointing + # silu bw + sigm = 1 / (1 + torch.exp(-x1.float())) + dx1 = (dx3.float() * sigm * (1 + x1.float() * (1 - sigm))).to(x1.dtype) + return dx1, dx2, x4 +*/ + +template +struct KernelTraits { + using AccumulationElement = T; +}; + +template <> +struct KernelTraits { + using AccumulationElement = float; +}; + +template <> +struct KernelTraits { + using AccumulationElement = float; +}; + +std::tuple silu_bw_fused( + const at::Tensor& x1, + const at::Tensor& x2, + const at::Tensor& dx4 +) { + // TODO: Check all params. This would take a lot of lines of code... 
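+  // Note (added for clarity): dx1 and dx2 are written into a single contiguous
+  // [B, 2, H] buffer (dx1dx2) and returned as views into it, so the packed-weight
+  // backward in xformers/ops/swiglu.py can reassemble them via stride tricks
+  // (efficient_stack_or_none) without an extra concat/copy. x4 = silu(x1) * x2 is
+  // recomputed here, so the forward pass does not need to keep it around
+  // (cheap activation checkpointing).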
+ TORCH_CHECK(x2.dim() == 2); + TORCH_CHECK(dx4.dim() == 2); + TORCH_CHECK(x2.size(0) == dx4.size(0)); + TORCH_CHECK(x2.size(1) == dx4.size(1)); + + int64_t B = x2.size(0); + int64_t H = x2.size(1); + at::Tensor dx1dx2 = at::empty({B, 2, H}, x2.options()); + at::Tensor dx1 = dx1dx2.select(1, 0); + at::Tensor dx2 = dx1dx2.select(1, 1); + at::Tensor x4 = at::empty({B, H}, x2.options()); + auto iter = at::TensorIteratorConfig() + .add_output(dx1) + .add_output(dx2) + .add_output(x4) + .add_input(x1) + .add_input(x2) + .add_input(dx4) + .check_all_same_dtype(true) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x2.scalar_type(), + "silu_bw_fused", ([&] { + using acc_t = typename KernelTraits::AccumulationElement; + at::native::gpu_kernel_multiple_outputs( + iter, [=] GPU_LAMBDA (scalar_t x1_, scalar_t x2_, scalar_t dx4_) + -> thrust::tuple { + acc_t sigm = acc_t(1) / (acc_t(1) + std::exp(-acc_t(x1_))); + acc_t x3_ = sigm * x1_; + acc_t dx3_ = acc_t(dx4_) * acc_t(x2_); + acc_t dx2_ = acc_t(dx4_) * acc_t(x3_); + acc_t dx1_ = (dx3_ * sigm * (acc_t(1) + acc_t(x1_) * (acc_t(1) - sigm))); + acc_t x4_ = x3_ * x2_; + return thrust::tuple{ + dx1_, + dx2_, + x4_ + }; + }); + })); + return std::make_tuple(dx1, dx2, x4); +} +} // namespace + +TORCH_LIBRARY_IMPL(xformers, CUDA, m) { + m.impl( + TORCH_SELECTIVE_NAME("xformers::silu_bw_fused"), + TORCH_FN(silu_bw_fused)); +} diff --git a/xformers/components/swiglu/swiglu.cpp b/xformers/components/swiglu/swiglu.cpp new file mode 100644 index 0000000000..0f325411cf --- /dev/null +++ b/xformers/components/swiglu/swiglu.cpp @@ -0,0 +1,10 @@ +#include + +TORCH_LIBRARY_FRAGMENT(xformers, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::dual_gemm_silu_identity_mul(Tensor x, Tensor w1, Tensor b1, Tensor w2, Tensor b2) -> (Tensor, Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::silu_bw_fused(Tensor x1, Tensor x2, Tensor dx4) -> (Tensor, Tensor, Tensor)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "xformers::gemm_fused_operand_sum(Tensor a, Tensor b, Tensor out_mm, Tensor out_sum) -> (Tensor, Tensor)")); +} diff --git a/xformers/ops/__init__.py b/xformers/ops/__init__.py index 548f6d2df8..638b41c8ca 100644 --- a/xformers/ops/__init__.py +++ b/xformers/ops/__init__.py @@ -16,7 +16,7 @@ MemoryEfficientAttentionOp, memory_efficient_attention, ) -from .swiglu import functional_swiglu # noqa: F401 +from .swiglu import SwiGLUFusedOp, functional_swiglu # noqa: F401 from .unbind import efficient_stack, get_stack_strides, unbind # noqa: F401 diff --git a/xformers/ops/swiglu.py b/xformers/ops/swiglu.py index 77bbac4f97..2b81ea9e41 100644 --- a/xformers/ops/swiglu.py +++ b/xformers/ops/swiglu.py @@ -9,6 +9,8 @@ import torch.nn.functional as F from torch import nn +from .unbind import efficient_stack_or_none, unbind + class _SwiGLUModule(nn.Module): """ @@ -21,6 +23,7 @@ def __init__( hidden_features: Optional[int] = None, out_features: Optional[int] = None, align_as: int = 8, + pack_weights: bool = False, ) -> None: super().__init__() out_features = out_features or in_features @@ -30,23 +33,47 @@ def __init__( (swiglu_hidden_features + align_as - 1) // align_as * align_as ) - self.w1 = nn.Linear(in_features, swiglu_hidden_features) - self.w2 = nn.Linear(in_features, swiglu_hidden_features) + self.w12: Optional[nn.Linear] + if pack_weights: + self.w12 = nn.Linear(in_features, 2 * swiglu_hidden_features) + else: + self.w12 = None + self.w1 = nn.Linear(in_features, 
swiglu_hidden_features) + self.w2 = nn.Linear(in_features, swiglu_hidden_features) self.w3 = nn.Linear(swiglu_hidden_features, out_features) + self.swiglu_hidden_features = swiglu_hidden_features + self.out_features = out_features + self.in_features = in_features + def forward(self, x: torch.Tensor) -> torch.Tensor: - x1 = self.w1(x) - x2 = self.w2(x) + if self.w12 is not None: + x12 = self.w12(x).view([x.shape[0], 2, self.swiglu_hidden_features]) + x1, x2 = unbind(x12, dim=1) + else: + x1 = self.w1(x) + x2 = self.w2(x) hidden = F.silu(x1) * x2 return self.w3(hidden) def _ordered_params_for_op(self): """Used for testing - returns ordered arguments for operators""" + if self.w12 is not None: + w1w2 = self.w12.weight + b1b2 = self.w12.bias + w1, w2 = unbind( + w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), + dim=0, + ) + b1, b2 = unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) + else: + w1, w2 = self.w1.weight, self.w2.weight + b1, b2 = self.w1.bias, self.w2.bias return [ - self.w1.weight, - self.w1.bias, - self.w2.weight, - self.w2.bias, + w1, + b1, + w2, + b2, self.w3.weight, self.w3.bias, ] @@ -108,6 +135,52 @@ def backward(cls, ctx, dx5): return (dx, dw1, db1, dw2, db2, dw3, db3) +class SwiGLUFusedOp(torch.autograd.Function): + NAME = "fused" + + @classmethod + def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3): + x1, x2, x4 = torch.ops.xformers.dual_gemm_silu_identity_mul(x, w1, b1, w2, b2) + + x5 = F.linear(x4, w3, b3) + ctx.save_for_backward(x, w1, w2, w3, x1, x2) + return x5 + + @classmethod + def backward(cls, ctx, dx5): + x, w1, w2, w3, x1, x2 = ctx.saved_tensors + w1w2 = efficient_stack_or_none([w1, w2], dim=0) + + dx4 = dx5 @ w3 # 255us (nn) + dx1, dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4) + del x1, x2, dx4 + + db3 = dx5.sum(0) # 25us + dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt) + del x4, dx5 + if w1w2 is not None: + dx1dx2 = efficient_stack_or_none([dx1, dx2], dim=1) + assert dx1dx2 is not None + assert dx1dx2.is_contiguous() + assert w1w2.is_contiguous() + w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]]) + dx = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]) @ w1w2 + + # backward of linear1 + linear2 - packed + dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x + db1db2 = dx1dx2.sum(0).view([2, dx1.shape[1]]) + dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0) + db1, db2 = torch.unbind(db1db2, dim=0) + else: + dx = dx2 @ w2 # 260us (nn) + torch.addmm(dx, dx1, w1, beta=1, alpha=1, out=dx) # dx += dx1 @ w1 + dw2 = dx2.transpose(-2, -1) @ x # 245us (nt) + db2 = dx2.sum(0) # 50us + dw1 = dx1.transpose(-2, -1) @ x # 245us (nt) + db1 = dx1.sum(0) # 50us + return (dx, dw1, db1, dw2, db2, dw3, db3) + + def functional_swiglu( x: torch.Tensor, w1: torch.Tensor, diff --git a/xformers/ops/unbind.py b/xformers/ops/unbind.py index 62bfeb7851..6535fb1179 100644 --- a/xformers/ops/unbind.py +++ b/xformers/ops/unbind.py @@ -43,15 +43,26 @@ def get_stack_strides( return tuple(final_stride) -def efficient_stack( - tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int -) -> torch.Tensor: +def efficient_stack_or_none( + tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], + dim: int, +) -> Optional[torch.Tensor]: strides = get_stack_strides(tensors, dim) if strides is not None: input_shape = list(tensors[0].shape) input_shape.insert(dim, len(tensors)) return tensors[0].as_strided(input_shape, strides) - return torch.stack(tensors, dim=dim) + return None + + +def efficient_stack( + tensors: Union[Tuple[torch.Tensor, ...], 
List[torch.Tensor]],
+    dim: int,
+) -> torch.Tensor:
+    out = efficient_stack_or_none(tensors, dim)
+    if out is None:
+        out = torch.stack(tensors, dim=dim)
+    return out
 
 
 class _Unbind(torch.autograd.Function):