Commit

small cleanup and much better perfs

blefaudeux committed Apr 26, 2022
1 parent a0fb375 commit 3f168e1
Showing 23 changed files with 21 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,13 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix some torchscriptability [#246]
- Fix FourierMix being compatible with AMP [#258]
- Better asserts on QKV dimensions [#264]
- Better perfs for FusedMLP and FusedLinearLayer [#283]

### Added
- Simplicial Embeddings [#259]
- Mem efficient attention, FW pass [#267]
- MHA benchmark
- MLP benchmark
- Move all triton kernels to triton v2 [#272]
- Mem efficient attention, BW pass [#281]

## [0.0.10] - 2022-03-14
### Fixed
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
10 further binary files modified (file names not rendered in this view)
8 changes: 5 additions & 3 deletions tests/test_triton_fused_linear.py
@@ -47,13 +47,13 @@ def test_fused_matmul(shape, dtype):

# Test that not passing any bias is fine
res_torch = a @ b
res_triton, _ = fused_matmul(a, b.transpose(0, 1), None)
res_triton, _ = fused_matmul(a, b.transpose(0, 1).contiguous(), None)
assert torch.allclose(res_torch, res_triton), "Vanilla matmul is broken"

# Now test with a real FMA
c = -torch.rand((shape[-2],), dtype=dtype, device="cuda")
res_torch = torch.addmm(c, a, b)
res_triton, _ = fused_matmul(a, b.transpose(1, 0), c)
res_triton, _ = fused_matmul(a, b.transpose(1, 0).contiguous(), c)

assert torch.allclose(
res_torch, res_triton
@@ -65,7 +65,9 @@ def test_fused_matmul(shape, dtype):
res_torch = torch_activation(torch.addmm(c, a, b))

triton_activation = get_triton_activation_kernel(activation)
res_triton, _ = fused_matmul(a, b.transpose(1, 0), c, triton_activation)
res_triton, _ = fused_matmul(
a, b.transpose(1, 0).contiguous(), c, triton_activation
)

# NOTE: @lefaudeux
# GeLUs are not well handled for now, we use an approximation
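The updated test calls above route the transposed weight through `.contiguous()` before handing it to `fused_matmul`: with this commit the kernel no longer receives an inner weight stride, and the Python wrapper asserts `weight.is_contiguous()` (see the k_fused_matmul_fw.py changes below). The snippet that follows is a minimal PyTorch-only sketch of the reference path the test compares against and of why the transposed view needs the extra copy; the shapes are illustrative rather than the ones the test parametrizes over, and no Triton call is made.

```python
# Minimal PyTorch-only sketch of the reference path used in the test above.
# No Triton call is made here; the point is that a transposed weight view is
# not contiguous, which is why the test adds .contiguous() before fused_matmul.
# Shapes are illustrative, not the ones the test parametrizes over.
import torch

a = torch.rand(16, 32)   # (M, K) inputs
b = torch.rand(32, 8)    # (K, N) second operand
c = -torch.rand(8)       # per-column bias, as in the test

w = b.transpose(0, 1)            # (N, K) view, shares storage with b
print(w.is_contiguous())         # False: strides are (1, 8), not (32, 1)
w = w.contiguous()               # materialize a row-major (N, K) copy
print(w.is_contiguous())         # True; this is what the kernel now expects

# Reference results the fused kernel is checked against:
res_matmul = a @ b                            # plain matmul, no bias
res_fma = torch.addmm(c, a, b)                # fused multiply-add with bias
res_act = torch.nn.functional.relu(res_fma)   # bias + activation epilogue
```

On the Triton side, the equivalent calls are the `fused_matmul(a, w, c, triton_activation)` invocations shown in the diff above.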
26 changes: 14 additions & 12 deletions xformers/triton/k_fused_matmul_fw.py
@@ -40,7 +40,7 @@ def kernel_fma(
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_om, stride_im,
stride_wn, stride_wk,
stride_wn,
# Meta-parameters
BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
@@ -93,8 +93,8 @@ def kernel_fma(
rk = tl.arange(0, BLOCK_K)

# the memory addresses of elements can follow numpy broadcasting
input_ptrs = INPUT + rm[:, None] * stride_im + rk[None, :]
weight_ptrs = WEIGHT + rk[:, None] * stride_wk + rn[None, :] * stride_wn
input_ptrs = INPUT + rm[:, None] * stride_im
weight_ptrs = WEIGHT + rn[None, :] * stride_wn

# initialize and iteratively update accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
@@ -105,27 +105,28 @@

# block level matrix multiplication.
# We fetch a block memory block from both inputs, matmul and accumulate, then repeat
for _ in range(K, 0, -BLOCK_K):
a = tl.load(input_ptrs, mask=((rk[None, :] < K) & (rm[:, None] < M)), other=0.0)
w = tl.load(weight_ptrs, mask=((rk[:, None] < K) & (rn[None, :] < N)), other=0.0)
mask_rn = rn < N
mask_rm = rm < M

acc += tl.dot(a, w).to(tl.float32)
for i in range(0, K, BLOCK_K):
rk = tl.arange(0, BLOCK_K) + i
a = tl.load(input_ptrs + rk[None, :], mask=((rk[None, :] < K) & mask_rm[:, None]), other=0.0)
w = tl.load(weight_ptrs + rk[:, None], mask=((rk[:, None] < K) & mask_rn[None, :]), other=0.0)

input_ptrs += BLOCK_K
weight_ptrs += BLOCK_K * stride_wk
acc += tl.dot(a, w)

# optional: save the activation inputs
if SAVE_ACT_INPUTS:
act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :]
tl.store(act_in_ptrs, acc, mask=(rm[:, None] < M) & (rn[None, :] < N))
tl.store(act_in_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])

# optional: fused activation (while the data is in shared memory)
if ACTIVATION:
acc = ACTIVATION(acc)

# write back result
out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :]
tl.store(out_ptrs, acc, mask=(rm[:, None] < M) & (rn[None, :] < N))
tl.store(out_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])


# Activation needs to be a triton kernel
@@ -153,6 +154,7 @@ def fused_matmul(
assert (
bias is None or bias.shape[0] == weight.shape[0]
), "Incompatible dimensions in between weight and bias"
assert weight.is_contiguous()

M, K = x_.shape
N, K = weight.shape
@@ -169,7 +171,7 @@
bias if bias is not None else x, # auto skip bias if not present
M, N, K, # shapes
outputs.stride(0), x_.stride(0), # strides
weight.stride(0), weight.stride(1),
weight.stride(0),
ACTIVATION=activation, # optional fused activation
BIAS=bias is not None, # optional fused bias
GROUP_M=8, # speed optimization: group the programs
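Taken together, the kernel changes precompute the row and column masks once (`mask_rm`, `mask_rn`), keep the input and weight base pointers fixed and offset them by the current K block inside the loop, and drop `stride_wk` entirely, which is why `fused_matmul` now asserts that the weight is contiguous (unit inner stride). The sketch below is a plain PyTorch rendering of that loop structure, meant as a readability aid rather than a faithful port of the Triton kernel: `blocked_fma_reference` and `block_k` are hypothetical names, the tiling of M and N into per-program blocks is omitted, and the bias and activation handling is simplified.

```python
# Plain PyTorch sketch of the block structure the rewritten kernel follows:
# the K dimension is walked in fixed-size chunks that offset the loads, and
# partial products are accumulated in float32. Readability aid only.
import torch

def blocked_fma_reference(x, weight, bias=None, activation=None, block_k=32):
    # x: (M, K) inputs; weight: (N, K), row-major contiguous, matching the
    # assert added to fused_matmul in this commit; bias: (N,) or None.
    M, K = x.shape
    N, _ = weight.shape
    acc = torch.zeros(M, N, dtype=torch.float32)

    # apply the optional bias up front (simplified stand-in for the BIAS path)
    if bias is not None:
        acc += bias.to(torch.float32)

    # walk K in fixed-size blocks and accumulate partial products,
    # mirroring `for i in range(0, K, BLOCK_K)` in the rewritten kernel
    for i in range(0, K, block_k):
        a = x[:, i : i + block_k].to(torch.float32)       # (M, <=BLOCK_K) tile
        w = weight[:, i : i + block_k].to(torch.float32)  # (N, <=BLOCK_K) tile
        acc += a @ w.t()                                  # tl.dot equivalent

    act_inputs = acc.clone()        # what SAVE_ACT_INPUTS would store
    if activation is not None:
        acc = activation(acc)       # fused activation epilogue
    return acc.to(x.dtype), act_inputs
```

As a quick sanity check with this hypothetical helper, one would expect `torch.allclose(blocked_fma_reference(x, w, b)[0], torch.addmm(b, x, w.t()), atol=1e-3)` for float32 inputs; the real kernel additionally tiles M and N into BLOCK_M by BLOCK_N programs, which is what the precomputed `mask_rm` and `mask_rn` bounds checks are for.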
