Memory efficient attention - backward pass (#281)

* Add naive CPU implementation for memory-efficient attention backward

Needs cleanup

* Optimize (at least) by a factor of 2

* More cleanups

* A few more comments

* Add very naive CUDA implementation

It's super slow!

* Speed up the CUDA kernel by 5x

But we still have a long way to go

* Make logsumexp an argument

* Make it 30% faster

Merge two loops together and use local buffers for accumulation and grad_q. As currently written, the use of local buffers introduces limitations on the size of dimension K

* 3.5x speedup by blocking strategy

* Use vector loads and improve tile selection

Brings an extra 2x improvement

* Recompute attention for grad_q computation

Makes it another 20% faster, and doesn't use extra memory (see the recomputation sketch after the commit messages)

* Small cleanups

* clang-format

* Make it 0.5% faster

Use all threads to compute grad_q

* Make it 1% faster by caching the loads

* Make it 6% faster with better hyperparameters

* Slightly better hyperparameter

* axpy == FMA

* Separate grad_q into its own kernel

This brings a 50% speedup compared to the previous approach, despite redundant computation. The benefit comes from using better block sizes for the matmul that computes grad_q, which doesn't involve the transpose of the attention matrix (see the gradient sketch after the commit messages)

* Avoid additional global writes by recomputing grad_aatn_v in grad_k

Brings an additional 12% speedup despite duplicate computation

* Trying out new idea

* Almost on par with my previous best implementation

* Improve perf by 5%

Potentially due to avoiding bank conflicts?

* Remove query-key from shared memory and increase tile size

Brings a 10% improvement, beating my previous best version

* Make it 20% faster with better hyperparameters

This is now significantly faster than what we had before, and is even faster than the vanilla implementation

* Make it another 12% faster

This is now 18% faster than the vanilla implementation

* Code cleanup

* Further cleanups

Remove previous implementation

* Variable rename

* clang-format

* Add alternative implementation for grad_v

So far it has exactly the same speed as the previous kernel, but is much more similar to the grad_q and grad_k kernels

* Speed it up by 10% with better hyperparameters

* Delete old implementation

* Centralize all input accesses in the beginning

This will make it easier to support inputs whose sizes are not a multiple of 32. Plus, this seems to give a small performance improvement, on the order of 1%

* Bugfix

Only shows up for certain sizes for some reason

* Make kernels generic wrt sequence length

This introduces a 25% slowdown, mostly due to the index computation in the preamble of each kernel. In a follow-up commit I'll try to optimize this out

* Add template argument to skip bound checking

Brings speed back to where it was for the cases where bound checking can safely be skipped

* Make it support all use-cases

* Let logsumexp be returned by forward

Also add an autograd Function for backward (see the wiring sketch after the commit messages)

* clang-format

* Add scaling factor

* Add tests + silly bugfix

* Add benchmark function for backward

* Add comment

* clang-format
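
The attention recomputation used for grad_q can be summarized with a minimal PyTorch sketch (an illustration of the math only, not the CUDA kernel, which works tile by tile; the helper name is hypothetical). Given the logsumexp saved by the forward pass, each softmax row is rebuilt as exp(scores - logsumexp), so the forward pass never has to save the full attention matrix for backward.

import torch

def recompute_attention(query, key, logsumexp):
    # query: (B, M, K), key: (B, N, K), logsumexp: (B, M)
    # scores[b, i, j] = <q_i, k_j> / sqrt(K), matching the scaling used in the tests
    scores = (query / query.shape[-1] ** 0.5) @ key.transpose(-2, -1)
    # softmax row i equals exp(scores[i] - logsumexp[i]); no renormalization pass needed
    return torch.exp(scores - logsumexp.unsqueeze(-1))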
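
Why separating grad_q into its own kernel can pay off is easiest to see from the reference math: grad_q consumes the gradient of the scores directly, while grad_k consumes its transpose, so the two matmuls favor different block sizes. A naive PyTorch reference for all three gradients (illustrative only; the function name is hypothetical, and attn is the softmax of the scaled scores, e.g. as recomputed above):

def naive_attention_backward(grad_out, query, key, value, attn):
    scale = query.shape[-1] ** -0.5
    # out = attn @ value, so grad_v is a plain transposed matmul
    grad_v = attn.transpose(-2, -1) @ grad_out
    grad_attn = grad_out @ value.transpose(-2, -1)
    # softmax backward: subtract the row-wise weighted sum, then rescale by attn
    grad_scores = attn * (grad_attn - (grad_attn * attn).sum(-1, keepdim=True))
    grad_q = grad_scores @ key * scale
    grad_k = grad_scores.transpose(-2, -1) @ query * scale
    return grad_q, grad_k, grad_v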
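
A rough sketch of how forward and backward can be tied together through an autograd Function, based on the operator schemas registered in attention.cpp further down (the class name is hypothetical; the real wiring in xformers.ops may differ in detail):

import torch

class _MemEffAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, value):
        # compute_logsumexp=True so the backward pass can reuse it
        out, logsumexp = torch.ops.xformers.efficient_attention(
            query, key, value, True
        )
        ctx.save_for_backward(query, key, value, logsumexp)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value, logsumexp = ctx.saved_tensors
        # returns (grad_q, grad_k, grad_v), matching the three forward inputs
        return torch.ops.xformers.efficient_attention_backward(
            grad_out, query, key, value, logsumexp
        )
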
fmassa authored Apr 25, 2022
1 parent ad94fd1 commit a0fb375
Showing 6 changed files with 1,037 additions and 121 deletions.
55 changes: 55 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -49,3 +49,58 @@ def test_key_query_all_ones(device, q_len, kv_len, batch_size, k_len):
ref = value.mean(1, keepdim=True).expand_as(query)

assert torch.allclose(out, ref, atol=1e-5)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 3, 5])
@pytest.mark.parametrize("device", _devices)
def test_logsumexp(device, q_len, kv_len, batch_size, k_len):
scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

_, lse = torch.ops.xformers.efficient_attention(query, key, value, True)
ref_lse = ((query / k_len ** 0.5) @ key.transpose(-2, -1)).logsumexp(-1)

assert torch.allclose(lse, ref_lse, atol=2e-4)


@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 3, 5])
@pytest.mark.parametrize("device", _devices)
def test_memory_efficient_attention_backward(device, q_len, kv_len, batch_size, k_len):
scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)

out = xformers.ops.memory_efficient_attention(query, key, value)
out.backward(torch.ones_like(query))

grad_q = query.grad
grad_k = key.grad
grad_v = value.grad

query.grad = None
key.grad = None
value.grad = None

ref = ref_attention(query, key, value)
ref.backward(torch.ones_like(query))

# there is some extra precision loss in the CPU implementation due to an
# extra accumulation step in grad_q, which is not present in the CUDA
# implementation
atol = 3e-4 if device == "cuda" else 4e-4
assert torch.allclose(grad_q, query.grad, atol=atol), "grad_q doesn't match"
assert torch.allclose(grad_k, key.grad, atol=atol), "grad_k doesn't match"
assert torch.allclose(grad_v, value.grad, atol=atol), "grad_v doesn't match"
207 changes: 144 additions & 63 deletions xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -30,66 +30,147 @@ def ref_attention(q, k, v):
results = []
mem_use: Dict[str, Dict[str, float]] = dict(optimized={}, vanilla={})

print(f"Processing {len(SHAPES)} cases")
for num_threads in NUM_THREADS:
for shape in SHAPES:
print(f"===== {shape} =====")
B, M, K = shape
q = torch.rand(shape, device=device)
sub_label = f"B={B}, M={M}, K={K}"

if True:
r = xformers.ops.memory_efficient_attention(q, q, q)

rr = ref_attention(q, q, q)
assert (r - rr).abs().max() < 1e-5

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q)",
globals={
"q": q,
"fn": torch.ops.xformers.efficient_attention,
},
label="attention",
description="optimized",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)
torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["optimized"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"

print("Optimized", memory_str)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q)",
globals={
"q": q,
"fn": ref_attention,
},
label="attention",
description="vanilla",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)

torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["vanilla"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"
print("Vanilla", memory_str)


compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)

def benchmark_forward():
print(f"Processing {len(SHAPES)} cases")
print("Forward")
for num_threads in NUM_THREADS:
for shape in SHAPES:
print(f"===== {shape} =====")
B, M, K = shape
q = torch.rand(shape, device=device)
sub_label = f"B={B}, M={M}, K={K}"

if True:
r = xformers.ops.memory_efficient_attention(q, q, q)

rr = ref_attention(q, q, q)
assert (r - rr).abs().max() < 1e-5

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q)",
globals={
"q": q,
"fn": xformers.ops.memory_efficient_attention,
},
label="attention",
description="optimized",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)
torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["optimized"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"

print("Optimized", memory_str)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="fn(q, q, q)",
globals={
"q": q,
"fn": ref_attention,
},
label="attention",
description="vanilla",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)

torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["vanilla"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"
print("Vanilla", memory_str)

compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)


def benchmark_backward():
print(f"Processing {len(SHAPES)} cases")
print("Backward")
for num_threads in NUM_THREADS:
for shape in SHAPES:
print(f"===== {shape} =====")
B, M, K = shape
q = torch.rand(shape, device=device, requires_grad=True)
sub_label = f"B={B}, M={M}, K={K}"

if True:
r = xformers.ops.memory_efficient_attention(q, q, q)
r.backward(torch.ones_like(q))

grad = q.grad
q.grad = None

rr = ref_attention(q, q, q)
rr.backward(torch.ones_like(q))
assert (grad - q.grad).abs().max() < 1e-5

out = xformers.ops.memory_efficient_attention(q, q, q)
grad = torch.ones_like(q)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": out,
"grad": grad,
},
label="attention",
description="optimized",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)
torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["optimized"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"

print("Optimized", memory_str)

out = ref_attention(q, q, q)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
results.append(
benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": out,
"grad": grad,
},
label="attention",
description="vanilla",
sub_label=sub_label,
num_threads=num_threads,
).blocked_autorange(min_run_time=min_run_time)
)

torch.cuda.synchronize()
memory = torch.cuda.max_memory_allocated() / 2 ** 20
mem_use["vanilla"][sub_label] = memory
memory_str = f"Memory used: {memory} MB"
print("Vanilla", memory_str)

compare = benchmark.Compare(results)
compare.print()

pprint.pprint(mem_use)


benchmark_forward()
benchmark_backward()
4 changes: 3 additions & 1 deletion xformers/components/attention/csrc/attention.cpp
@@ -2,5 +2,7 @@

 TORCH_LIBRARY_FRAGMENT(xformers, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "xformers::efficient_attention(Tensor query, Tensor key, Tensor value) -> Tensor"));
+      "xformers::efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_logsumexp) -> (Tensor, Tensor)"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp) -> (Tensor, Tensor, Tensor)"));
 }
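
For reference, a hypothetical direct invocation of the two operators registered above (shapes and device are arbitrary; in normal use the backward op is reached through autograd rather than called by hand):

import torch
import xformers.ops  # assumed to load the extension that registers the ops

q = torch.randn(2, 8, 32, device="cuda")
k = torch.randn(2, 16, 32, device="cuda")
v = torch.randn(2, 16, 32, device="cuda")
grad_out = torch.ones_like(q)

out, lse = torch.ops.xformers.efficient_attention(q, k, v, True)
grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention_backward(
    grad_out, q, k, v, lse
)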