SwiGLU optimized fw/bw
ghstack-source-id: a87a46d345dcb98dc0c53c56575fcda38cd5bccd
Pull Request resolved: #490
danthe3rd committed Oct 25, 2022
1 parent 5227f2f commit bf91e62
Showing 11 changed files with 614 additions and 36 deletions.
22 changes: 13 additions & 9 deletions setup.py
@@ -145,24 +145,21 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):

 def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
-    extensions_dir = os.path.join(
-        this_dir, "xformers", "components", "attention", "csrc"
-    )
+    extensions_dir = os.path.join(this_dir, "xformers", "components")

     main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))

-    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + glob.glob(
-        os.path.join(extensions_dir, "autograd", "*.cpp")
-    )
+    source_cpu = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

     sources = main_file + source_cpu

     source_cuda = glob.glob(
-        os.path.join(extensions_dir, "cuda", "**", "*.cu"), recursive=True
+        os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
     )

     sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
     cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
+    cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
     if not os.path.exists(cutlass_dir):
         raise RuntimeError(
             f"CUTLASS submodule not found at {cutlass_dir}. "
@@ -189,8 +186,15 @@ def get_extensions():
     ) == "1":
         extension = CUDAExtension
         sources += source_cuda
-        include_dirs += [sputnik_dir, cutlass_dir]
-        nvcc_flags = ["-DHAS_PYTORCH", "--use_fast_math", "--generate-line-info"]
+        include_dirs += [sputnik_dir, cutlass_dir, cutlass_examples_dir]
+        nvcc_flags = [
+            "-DHAS_PYTORCH",
+            "--use_fast_math",
+            "--generate-line-info",
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "--extended-lambda",
+        ]
         if os.getenv("XFORMERS_ENABLE_DEBUG_ASSERTIONS", "0") != "1":
             nvcc_flags.append("-DNDEBUG")
         nvcc_flags += shlex.split(os.getenv("NVCC_FLAGS", ""))
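Aside: the broadened glob patterns above are what let the build pick up the new SwiGLU sources without listing them explicitly. A minimal sketch of what they match, runnable outside the build (the paths in the comments are the ones appearing in this commit):

import glob
import os

extensions_dir = os.path.join("xformers", "components")

# Any C++ translation unit anywhere under xformers/components/ ...
cpp_sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)

# ... and any CUDA file under a cuda/ folder of any component, e.g.
# attention/csrc/cuda/sddmm2_cuda.cu or swiglu/cuda/dual_gemm_silu_identity_mul.cu.
cu_sources = glob.glob(
    os.path.join(extensions_dir, "**", "cuda", "**", "*.cu"), recursive=True
)
print(len(cpp_sources), len(cu_sources))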
23 changes: 14 additions & 9 deletions tests/test_swiglu.py
@@ -83,16 +83,19 @@ def generate_test_shapes():
     # Add some random shapes
     r = random.Random(0)
     for _ in range(20):
-        shapes.append((r.randint(1, 5000), r.randint(1, 5000), r.randint(1, 512) * 8))
+        shapes.append(
+            (r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
+        )
     return shapes


 _test_shapes = list(generate_test_shapes())
 _test_shapes_ids = [str(s) for s in _test_shapes]
-_dtypes = [torch.float, torch.float16]
+_dtypes = [torch.float16]


 @pytest.mark.parametrize("autocast", [False])  # TODO: Enable autocast testing
+@pytest.mark.parametrize("pack_weights", [True, False])
 @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
 @pytest.mark.parametrize("device", _devices)
 @pytest.mark.parametrize(
@@ -105,8 +108,11 @@ def test_forward_backward(
     device,
     dtype,
     autocast: bool,
+    pack_weights: bool,
 ):
-    FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-3}
+    torch.manual_seed(shape[0] * shape[1] * shape[2])
+    FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2}
+    FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3}
     BACKWARD_ATOL = {
         torch.float: 3e-4,
         torch.half: 0.5,
@@ -124,8 +130,11 @@
     inp_model_dtype = torch.float if autocast else dtype
     x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
     op = xsw._SwiGLUDecomposedOp
+    op = xsw._SwiGLUFusedOp

-    module = xsw._SwiGLUModule(in_features=shape[1], hidden_features=shape[2])
+    module = xsw._SwiGLUModule(
+        in_features=shape[1], hidden_features=shape[2], pack_weights=pack_weights
+    )
     x_f32: Optional[torch.Tensor]
     ref_f32: Optional[torch.Tensor]
     module_f32: Optional[torch.nn.Module]
@@ -150,11 +159,7 @@
         ref_f32 = ref

     assert_allclose(
-        out,
-        ref,
-        ref_f32,
-        "fw",
-        atol=FORWARD_ATOL[dtype],
+        out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
     )

     # Backward
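For orientation, here is a hedged eager-mode sketch of the SwiGLU computation the test exercises through xsw.functional_swiglu. The parameter names and their ordering below are illustrative assumptions; the real ordering comes from module._ordered_params_for_op(), which is defined outside this diff:

import torch
import torch.nn.functional as F


def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    """Plain PyTorch SwiGLU MLP, used only as a readable reference."""
    gate = F.linear(x, w1, b1)       # first projection
    value = F.linear(x, w2, b2)      # second projection, same input
    hidden = F.silu(gate) * value    # the SiLU-and-mul that the fused op targets
    return F.linear(hidden, w3, b3)  # output projection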
21 changes: 17 additions & 4 deletions xformers/benchmarks/benchmark_swiglu.py
@@ -12,6 +12,7 @@
 from utils import benchmark_main_helper

 import xformers.ops.swiglu as xsw
+from xformers.ops import unbind as xunbind

 min_run_time = 0.5
 device = torch.device("cuda")
@@ -22,10 +23,13 @@
     (9456, 1536, 4096),
     (4440, 1536, 4096),
     (4728, 1536, 4096),
+    # Some smaller shapes as well
+    (4728, 1536, 1024),
 ]


-OP = xsw._SwiGLUDecomposedOp
+# OP = xsw._SwiGLUDecomposedOp
+OP = xsw._SwiGLUFusedOp


 def product_dict(**kwargs):
@@ -38,7 +42,7 @@ def product_dict(**kwargs):
 CASES = list(
     product_dict(
         shape=SHAPES,
-        dtype=[torch.half, torch.float],
+        dtype=[torch.half],
     )
 )

@@ -61,11 +65,16 @@ def benchmark_swiglu(shape, dtype):
     sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

     assert not autocast
+
+    params = module._ordered_params_for_op()
+    # w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape])
+    # params[0], params[2] = w1w2.unbind(dim=0)
+
     yield benchmark.Timer(
         stmt="fn(x, *args)",
         globals={
             "x": x,
-            "args": module._ordered_params_for_op(),
+            "args": params,
             "fn": partial(xsw.functional_swiglu, op=OP),
         },
         label="swiglu_fw",
@@ -103,7 +112,11 @@ def benchmark_swiglu_bw(shape, dtype):
     sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]}"

     assert not autocast
-    out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=OP)
+    params = module._ordered_params_for_op()
+    w1w2 = torch.cat([params[0], params[2]], dim=0).view([2, *params[0].shape]).detach()
+    w1w2.requires_grad_()
+    params[0], params[2] = xunbind(w1w2, dim=0)
+    out = xsw.functional_swiglu(x, *params, op=OP)
     grad = torch.zeros_like(out)
     yield benchmark.Timer(
         stmt="out.backward(grad, retain_graph=True)",
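The backward benchmark above packs the two input-projection weights into one buffer before calling the fused op, presumably so their gradients land in a single contiguous tensor. A toy version of that trick, using plain torch.unbind in place of xformers.ops.unbind and purely illustrative sizes:

import torch

H, I = 8, 4  # illustrative sizes only
w1 = torch.randn(H, I)
w2 = torch.randn(H, I)

# Pack both projection weights into one contiguous [2, H, I] leaf tensor ...
w1w2 = torch.cat([w1, w2], dim=0).view([2, H, I]).detach().requires_grad_()

# ... then hand out per-weight views of it.
w1_view, w2_view = torch.unbind(w1w2, dim=0)

x = torch.randn(3, I)
loss = (x @ w1_view.t()).sum() + (x @ w2_view.t()).sum()
loss.backward()

# Both gradients land in a single contiguous buffer.
print(w1w2.grad.shape)  # torch.Size([2, 8, 4])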
2 changes: 1 addition & 1 deletion xformers/components/attention/csrc/cuda/sddmm2_cuda.cu
@@ -5,7 +5,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/types.h>
-#include "computeUtil.h"
+#include "../computeUtil.h"

 namespace ge_spmm {

150 changes: 150 additions & 0 deletions xformers/components/swiglu/cuda/dual_gemm_silu_identity_mul.cu
@@ -0,0 +1,150 @@
#include <ATen/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/ScalarOps.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#include "43_dual_gemm/device/dual_gemm.h"
#include "43_dual_gemm/thread/left_silu_and_mul.h"

namespace {
template <typename scalar_t>
std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
    const at::Tensor& x,
    const at::Tensor& w0,
    const at::Tensor& b0,
    const at::Tensor& w1,
    const at::Tensor& b1
) {
  TORCH_CHECK(x.stride(-1) == 1);
  TORCH_CHECK(w0.stride(-1) == 1);
  TORCH_CHECK(w1.stride(-1) == 1);

  at::cuda::CUDAGuard device_guard(x.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int64_t B = x.size(0);
  int64_t I = x.size(1);
  int64_t H = w0.size(0);

  at::Tensor d0 = at::empty({B, H}, x.options());
  at::Tensor d1 = at::empty({B, H}, x.options());
  at::Tensor d2 = at::empty({B, H}, x.options());

  // templati-ze the cutlass kernel
  cutlass::gemm::GemmCoord problem_size(B, H, I);

  constexpr int kStages = 3;
  constexpr bool kSplitKSerial = false;

  using ElementOutput = scalar_t;
  using ElementAccumulator = float;
  using ElementCompute = float;
  using EpilogueOutputOp01 = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementCompute,
      cutlass::epilogue::thread::ScaleType::NoBetaScaling
  >;
  using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementOutput,
      ElementCompute
  >;

  const ElementCompute alpha0 = ElementCompute(1);
  const ElementCompute beta0 = ElementCompute(1);
  const ElementCompute alpha1 = ElementCompute(1);
  const ElementCompute beta1 = ElementCompute(1);

  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Optionally, we might not need intermediate GEMM outputs
  constexpr bool kStoreD0 = true;
  constexpr bool kStoreD1 = true;

  using DualGemm = cutlass::gemm::device::DualGemm<
      scalar_t,
      cutlass::layout::RowMajor,
      scalar_t,
      cutlass::layout::ColumnMajor,
      ElementOutput,
      cutlass::layout::RowMajor,
      ElementAccumulator,
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EpilogueOutputOp01,
      EpilogueOutputOp01,
      EpilogueOutputOp2,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
      kStages,
      kStoreD0,
      kStoreD1,
      kSplitKSerial
  >;

  int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
  using RefA = typename cutlass::TensorRef<typename DualGemm::ElementA, typename DualGemm::LayoutA>;
  using RefB = typename cutlass::TensorRef<typename DualGemm::ElementB, typename DualGemm::LayoutB>;
  using RefC = typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC>;
  typename DualGemm::Arguments arguments{
      problem_size,
      RefA{(scalar_t*)x.data_ptr(), typename DualGemm::LayoutA::Stride(x.stride(0))},
      RefB{(scalar_t*)w0.data_ptr(), typename DualGemm::LayoutB::Stride(w0.stride(0))},
      RefC{(scalar_t*)b0.data_ptr(), typename DualGemm::LayoutC::Stride(0)},
      RefC{(scalar_t*)d0.data_ptr(), typename DualGemm::LayoutC::Stride(d0.stride(0))},
      RefB{(scalar_t*)w1.data_ptr(), typename DualGemm::LayoutB::Stride(w1.stride(0))},
      RefC{(scalar_t*)b1.data_ptr(), typename DualGemm::LayoutC::Stride(0)},
      RefC{(scalar_t*)d1.data_ptr(), typename DualGemm::LayoutC::Stride(d1.stride(0))},
      RefC{(scalar_t*)d2.data_ptr(), typename DualGemm::LayoutC::Stride(d2.stride(0))},
      typename DualGemm::EpilogueOutputOp0::Params{alpha0, beta0},
      typename DualGemm::EpilogueOutputOp1::Params{alpha1, beta1},
      typename DualGemm::EpilogueOutputOp2::Params{},
      split_k_slices
  };
  DualGemm dual_gemm;
  at::Tensor workspace = at::empty({int64_t(dual_gemm.get_workspace_size(arguments))}, x.options().dtype(at::ScalarType::Byte));
  cutlass::Status status = dual_gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "not supported by this kernel");
  status = dual_gemm.initialize(arguments, (uint8_t*)workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel initialize failed");
  status = dual_gemm(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "kernel run failed");
  return std::make_tuple(d0, d1, d2);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul(
    const at::Tensor& x,
    const at::Tensor& w0,
    const at::Tensor& b0,
    const at::Tensor& w1,
    const at::Tensor& b1
) {
  // TODO: Check all params. This would take a lot of lines of code...
  TORCH_CHECK(x.dim() == 2);
  TORCH_CHECK(w0.dim() == 2);
  TORCH_CHECK(w1.dim() == 2);

#define FWD_PARAMS x,w0,b0,w1,b1

  if (x.scalar_type() == at::ScalarType::Half) {
    return dual_gemm_silu_identity_mul_<cutlass::half_t>(FWD_PARAMS);
  } else {
    TORCH_CHECK(x.scalar_type() == at::ScalarType::BFloat16, "Only supports bf16/f16");
    return dual_gemm_silu_identity_mul_<cutlass::bfloat16_t>(FWD_PARAMS);
  }
}
} // namespace

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("xformers::dual_gemm_silu_identity_mul"),
      TORCH_FN(dual_gemm_silu_identity_mul));
}
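For reference, a hedged eager-mode description of what this kernel returns, plus how it would be invoked from Python. The torch.ops call assumes the "xformers::dual_gemm_silu_identity_mul" schema is registered elsewhere in this PR; that registration is not part of this file:

import torch
import torch.nn.functional as F


def dual_gemm_silu_identity_mul_reference(x, w0, b0, w1, b1):
    """Eager-mode equivalent of the outputs (d0, d1, d2) produced above."""
    d0 = F.linear(x, w0, b0)  # first GEMM, bias added in the epilogue
    d1 = F.linear(x, w1, b1)  # second GEMM, shares the A operand x
    d2 = F.silu(d0) * d1      # LeftSiLUAndMul epilogue
    return d0, d1, d2


# Invoking the CUDA op itself (schema registration assumed, see above):
# d0, d1, d2 = torch.ops.xformers.dual_gemm_silu_identity_mul(x, w0, b0, w1, b1)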
(The remaining 6 changed files are not shown here.)
