SwiGLU further optimization in MLP bw #502

Merged Nov 10, 2022 (9 commits)
4 changes: 2 additions & 2 deletions tests/test_swiglu.py
@@ -106,8 +106,8 @@ def generate_test_shapes():
_ops: Sequence[xsw.SwiGLUOp] = [xsw.SwiGLUFusedOp, xsw.SwiGLUPackedFusedOp]


@pytest.mark.parametrize("autocast", [False, True])
@pytest.mark.parametrize("pack_weights", [False, True])
@pytest.mark.parametrize("autocast", [False, True], ids=["regular", "autocast"])
@pytest.mark.parametrize("pack_weights", [False, True], ids=["regular", "packed"])
@pytest.mark.parametrize("op", _ops, ids=[x.NAME for x in _ops])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
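A side note on the test change above: the new `ids` arguments only change how pytest labels the generated test cases, not what is run. A minimal sketch of the pattern, with an illustrative test body that is not part of this PR:

```python
import pytest

# ids give each parameter value a readable label in the generated test id,
# e.g. "[packed-autocast]" instead of "[True-True]".
@pytest.mark.parametrize("autocast", [False, True], ids=["regular", "autocast"])
@pytest.mark.parametrize("pack_weights", [False, True], ids=["regular", "packed"])
def test_flag_labels(pack_weights: bool, autocast: bool) -> None:
    assert isinstance(pack_weights, bool)
    assert isinstance(autocast, bool)
```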
79 changes: 53 additions & 26 deletions xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu
@@ -2,17 +2,17 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/ScalarOps.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/autocast_mode.h>
#include <torch/library.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/device/gemm_with_k_reduction.h"
#include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/matrix_coord.h"

#define WORKAROUND_CUTLASS_BUG

namespace {
template <typename scalar_t>
@@ -25,12 +25,8 @@ void gemm_fused_operand_sum_(
at::cuda::CUDAGuard device_guard(a.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int64_t M = a.size(0);
int64_t N = b.size(0);
int64_t K = a.size(1);

// templati-ze the cutlass kernel
cutlass::gemm::GemmCoord problem_size(M, N, K);
cutlass::gemm::GemmCoord problem_size(a.size(0), b.size(1), a.size(1));
using ElementAccumulator = float; // Data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation
using ElementInputA = scalar_t;
@@ -69,11 +65,13 @@ void gemm_fused_operand_sum_(
constexpr int NumStages = 4;

// Reduce A or B operand along the K dimension
#ifdef WORKAROUND_CUTLASS_BUG
constexpr bool ReduceKForA = false;
#else
constexpr bool ReduceKForA = true;
#endif

// Alignment of A operand
constexpr int AlignmentA = 8;

// Alignment of B operand
constexpr int AlignmentB = 8;

// This code section describes the epilogue part of the kernel, we use default value
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -82,12 +80,11 @@ void gemm_fused_operand_sum_(
// memory access. This becomes the vector width of
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
ElementComputeEpilogue>;

using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction<
ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8,
ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8,
using Gemm = typename cutlass::gemm::device::GemmWithKReduction<
ElementInputA, LayoutInputA,
ElementInputB, LayoutInputB,
ElementOutput, LayoutOutput,
ElementAccumulator,
MMAOp,
@@ -99,13 +96,15 @@ void gemm_fused_operand_sum_(
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
AlignmentA,
AlignmentB,
cutlass::arch::OpMultiplyAdd,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;

// Below is the reduction kernel used in the case of parallel split-k
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;;
using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;

using ReduceOp = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator,
@@ -131,7 +130,7 @@ void gemm_fused_operand_sum_(
// math instructions in the epilogue too.
ElementAccumulator, // Data type of accumulator
ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
cutlass::epilogue::thread::ScaleType::Nothing>;

using ReduceVectorSplitKKernel = cutlass::reduction::kernel::ReduceSplitK<
ReduceVectorSplitKShape,
@@ -143,11 +142,8 @@ void gemm_fused_operand_sum_(
auto alpha = ElementComputeEpilogue(1);
auto beta = ElementComputeEpilogue(0);

using RefA = cutlass::TensorRef<ElementInputA, LayoutInputA>;
using RefB = cutlass::TensorRef<ElementInputB, LayoutInputB>;
using RefC = cutlass::TensorRef<ElementOutput, LayoutOutput>;
int reduce_vector_length = ReduceKForA ? problem_size.m() : problem_size.n();
int split_k_slices = 2;
int split_k_slices = 1;
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
@@ -192,21 +188,52 @@ std::tuple<at::Tensor, at::Tensor> gemm_fused_operand_sum(
TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(b.dim() == 2);
TORCH_CHECK(out_mm.dim() == 2);
TORCH_CHECK(out_mm.size(0) == a.size(0));
TORCH_CHECK(out_mm.size(1) == b.size(1));
TORCH_CHECK(out_sum.dim() == 1);

#define FWD_PARAMS a,b,out_mm,out_sum

if (a.scalar_type() == at::ScalarType::Half) {
TORCH_CHECK(b.scalar_type() == at::ScalarType::Half);
TORCH_CHECK(out_mm.scalar_type() == at::ScalarType::Half);
TORCH_CHECK(out_sum.scalar_type() == at::ScalarType::Half);
gemm_fused_operand_sum_<cutlass::half_t>(FWD_PARAMS);
} else {
TORCH_CHECK(a.scalar_type() == at::ScalarType::BFloat16, "Only supports bf16/f16");
TORCH_CHECK(b.scalar_type() == at::ScalarType::BFloat16);
TORCH_CHECK(out_mm.scalar_type() == at::ScalarType::BFloat16);
TORCH_CHECK(out_sum.scalar_type() == at::ScalarType::BFloat16);
gemm_fused_operand_sum_<cutlass::bfloat16_t>(FWD_PARAMS);
}
return std::make_tuple(out_mm, out_sum);
}

std::tuple<at::Tensor, at::Tensor> gemm_fused_operand_sum_autocast(
const at::Tensor& a,
const at::Tensor& b,
at::Tensor& out_mm,
at::Tensor& out_sum
) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = at::autocast::get_autocast_gpu_dtype();
return gemm_fused_operand_sum(
at::autocast::cached_cast(exec_type, a),
at::autocast::cached_cast(exec_type, b),
out_mm,
out_sum
);
}
} // namespace

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::gemm_fused_operand_sum"),
TORCH_FN(gemm_fused_operand_sum));
}

TORCH_LIBRARY_IMPL(xformers, Autocast, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::gemm_fused_operand_sum"),
TORCH_FN(gemm_fused_operand_sum_autocast));
}
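To make the new op's contract concrete: judging from the shape checks above and from how the op is invoked in the backward passes below, `xformers::gemm_fused_operand_sum(a, b, out_mm, out_sum)` writes the GEMM result `a @ b` into `out_mm` and the reduction of operand `a` over its K dimension into `out_sum`, in a single fused kernel launch. A minimal unfused PyTorch reference of those semantics (for clarity only; this is not the CUTLASS implementation):

```python
import torch
from typing import Tuple

def gemm_fused_operand_sum_reference(
    a: torch.Tensor, b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Unfused reference: a is [M, K], b is [K, N]."""
    out_mm = a @ b           # [M, N], the regular GEMM output
    out_sum = a.sum(dim=1)   # [M], operand A reduced along K
    return out_mm, out_sum
```

The `Autocast` registration above re-dispatches with both inputs cast to the current autocast dtype via `at::autocast::cached_cast`, so under `torch.autocast` the op receives fp16/bf16 inputs just like the surrounding matmuls.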
24 changes: 17 additions & 7 deletions xformers/components/swiglu/swiglu_packedw.cpp
@@ -28,6 +28,17 @@ std::tuple<at::Tensor, at::Tensor> silu_bw_fused(
.typed<decltype(silu_bw_fused)>();
return op.call(x1, x2, dx4);
}
std::tuple<at::Tensor, at::Tensor> gemm_fused_operand_sum(
const at::Tensor& a,
const at::Tensor& b,
at::Tensor& out_mm,
at::Tensor& out_sum
) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("xformers::gemm_fused_operand_sum", "")
.typed<decltype(gemm_fused_operand_sum)>();
return op.call(a, b, out_mm, out_sum);
}

bool shapesMatch(at::Tensor x, std::vector<int64_t> expectedShape) {
if (x.dim() != int64_t(expectedShape.size())) {
@@ -107,11 +118,10 @@ class SwiGLUPackedWeights : public torch::autograd::Function<SwiGLUPackedWeights
x2.reset();
dx4.reset();

auto db3 = dx5.sum(0);
auto db3 = torch::empty({O}, w3.options());
auto dw3 = torch::empty({O, H}, w3.options());
TORCH_INTERNAL_ASSERT(dx5.size(0) == x4.size(0));
auto dw3 = torch::mm(dx5.transpose(-2, -1), x4);
TORCH_INTERNAL_ASSERT_SHAPE(db3, O);
TORCH_INTERNAL_ASSERT_SHAPE(dw3, O, H);
gemm_fused_operand_sum(dx5.transpose(-2, -1), x4, dw3, db3);
x4.reset();
dx5.reset();

@@ -123,10 +133,10 @@ class SwiGLUPackedWeights : public torch::autograd::Function<SwiGLUPackedWeights
auto dx = torch::mm(dx1dx2, w1w2);

// backward of linear1 + linear2 - packed
auto dw1dw2 = torch::mm(dx1dx2.transpose(-2, -1), x);
auto db1db2 = dx1dx2.sum(0);
auto dw1dw2 = torch::empty({2 * H, I}, w1w2.options());
auto db1db2 = torch::empty({2 * H}, w1w2.options());
gemm_fused_operand_sum(dx1dx2.transpose(-2, -1), x, dw1dw2, db1db2);

auto p = db1db2.view({2, H});
return {dx, dw1dw2.view({2, H, I}), db1db2.view({2, H}), dw3, db3};
}
};
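For intuition on the packed path above: one fused call now produces the concatenated gradients of both input projections, which are then split by plain views. A small PyTorch sketch of that bookkeeping, with the fused op written out unfused and purely illustrative sizes:

```python
import torch

B, I, H = 32, 512, 1024                  # illustrative batch / input / hidden sizes
x = torch.randn(B, I)
dx1dx2 = torch.randn(B, 2 * H)           # grads of the two packed projections, concatenated

# What gemm_fused_operand_sum computes in one kernel (written unfused here):
dw1dw2 = dx1dx2.transpose(-2, -1) @ x    # [2*H, I]
db1db2 = dx1dx2.sum(0)                   # [2*H]

# Split back into per-projection gradients, mirroring the .view() calls above:
dw1, dw2 = dw1dw2.view(2, H, I).unbind(0)
db1, db2 = db1db2.view(2, H).unbind(0)
assert dw1.shape == (H, I) and db1.shape == (H,)
```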
25 changes: 17 additions & 8 deletions xformers/ops/swiglu.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
@@ -153,6 +153,15 @@ def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
ctx.save_for_backward(x, w1, w2, w3, x1, x2)
return x5

@staticmethod
def _linear_bw(
dy: torch.Tensor, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
torch.ops.xformers.gemm_fused_operand_sum(dy.transpose(-2, -1), x, dw, db)
return dw, db

@classmethod
@torch.cuda.amp.custom_bwd
def backward(cls, ctx, dx5):
@@ -164,8 +173,7 @@ def backward(cls, ctx, dx5):
dx1, dx2 = dx1dx2.unbind(1)
del x1, x2, dx4

db3 = dx5.sum(0) # 25us
dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt)
dw3, db3 = cls._linear_bw(dx5, x4)
del x4, dx5
if w1w2 is not None:
assert dx1dx2.is_contiguous()
@@ -175,18 +183,19 @@ def backward(cls, ctx, dx5):

# backward of linear1 + linear2 - packed
dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x
db1db2 = dx1dx2.sum(0).view([2, dx1.shape[1]])
dw1dw2, db1db2 = cls._linear_bw(
dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x
)
db1db2 = db1db2.view([2, dx1.shape[1]])
dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0)
db1, db2 = torch.unbind(db1db2, dim=0)
else:
dx = dx2 @ w2 # 260us (nn)
torch.addmm(
dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx
) # dx += dx1 @ w1
dw2 = dx2.transpose(-2, -1) @ x # 245us (nt)
db2 = dx2.sum(0) # 50us
dw1 = dx1.transpose(-2, -1) @ x # 245us (nt)
db1 = dx1.sum(0) # 50us
dw2, db2 = cls._linear_bw(dx2, x)
dw1, db1 = cls._linear_bw(dx1, x)
return (dx, dw1, db1, dw2, db2, dw3, db3)


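As a quick sanity check of the Python-side change, the fused op can be compared against the two separate kernels it replaces (`dy.T @ x` and `dy.sum(0)`). A hedged sketch, assuming a CUDA build of xformers that registers `xformers::gemm_fused_operand_sum` as in this PR; shapes and tolerances are illustrative:

```python
import torch

dy = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)
x = torch.randn(1024, 1536, device="cuda", dtype=torch.float16)

# Pre-allocate outputs, matching _linear_bw above.
dw = torch.empty(dy.shape[1], x.shape[1], device="cuda", dtype=dy.dtype)
db = torch.empty(dy.shape[1], device="cuda", dtype=dy.dtype)
torch.ops.xformers.gemm_fused_operand_sum(dy.transpose(-2, -1), x, dw, db)

# Compare against the unfused reference that the old backward used.
torch.testing.assert_close(dw, dy.transpose(-2, -1) @ x, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(db, dy.sum(0), rtol=2e-2, atol=2e-2)
```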