Memory efficient attention - backward pass #281

Merged: 45 commits merged into main on Apr 25, 2022

Conversation

fmassa (Contributor, Author) commented Apr 22, 2022:

What does this PR do?

This PR implements the memory-efficient attention mechanism from https://arxiv.org/pdf/2112.05682v2.pdf, with both CPU and CUDA kernels, targeting the backward implementation. For now, only fp32 is supported.

The CPU implementation is naive and not meant to be fast, and is there only as a reference.

The CUDA implementation has competitive runtimes compared to a vanilla PyTorch implementation, while using 10x+ less memory.

Contrary to the forward implementation, the backward supports inputs with an arbitrary embedding size. I'll probably update the forward implementation to use a similar approach in the future.

In order to keep the backward somewhat efficient, we need to return the logsumexp from the forward as well. For some reason, I needed to make it a template parameter in CUDA, otherwise I would face performance slowdowns.

In the same vein, in order to support arbitrary sequence lengths, I use a masking approach. But the masking (and in particular the min(index, M) operations) is very slow, so I templated this as well so that we don't need to run it when the sequence length is a multiple of 32.
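For readers unfamiliar with the algorithm, here is a compact PyTorch-level sketch of the chunked forward pass from the paper, meant only to show why the logsumexp falls out of the forward almost for free. This is not the CUDA kernel's actual structure, and the 1/sqrt(K) scaling is omitted for brevity.

```python
import torch

def chunked_attention_forward(q, k, v, chunk=256):
    # q, k, v: [B, M, K]; returns the attention output and a per-row logsumexp,
    # mirroring `logsumexp[i][j] = m_prime + log(s_prime)` in the CPU kernel.
    B, M, K = q.shape
    out = torch.zeros_like(q)
    m = q.new_full((B, M, 1), float("-inf"))  # running row-wise max
    s = q.new_zeros((B, M, 1))                # running row-wise sum of exp
    for start in range(0, M, chunk):
        kc = k[:, start:start + chunk]
        vc = v[:, start:start + chunk]
        scores = q @ kc.transpose(-2, -1)                      # [B, M, chunk]
        m_new = torch.maximum(m, scores.max(-1, keepdim=True).values)
        p = torch.exp(scores - m_new)
        correction = torch.exp(m - m_new)                      # rescale old stats
        s = s * correction + p.sum(-1, keepdim=True)
        out = out * correction + p @ vc
        m = m_new
    logsumexp = (m + s.log()).squeeze(-1)                      # [B, M]
    return out / s, logsumexp
```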

Full details for the benchmark I ran can be found here:
Backward
===== (1, 127, 16) =====
Optimized Memory used: 0.07861328125 MB
Vanilla Memory used: 0.310546875 MB
===== (1, 127, 32) =====
Optimized Memory used: 0.15673828125 MB
Vanilla Memory used: 0.373046875 MB
===== (1, 128, 16) =====
Optimized Memory used: 0.07861328125 MB
Vanilla Memory used: 0.3125 MB
===== (1, 128, 32) =====
Optimized Memory used: 0.15673828125 MB
Vanilla Memory used: 0.375 MB
===== (1, 512, 16) =====
Optimized Memory used: 0.314453125 MB
Vanilla Memory used: 4.25 MB
===== (1, 512, 32) =====
Optimized Memory used: 0.626953125 MB
Vanilla Memory used: 4.5 MB
===== (1, 513, 16) =====
Optimized Memory used: 0.31982421875 MB
Vanilla Memory used: 4.271484375 MB
===== (1, 513, 32) =====
Optimized Memory used: 0.63232421875 MB
Vanilla Memory used: 4.521484375 MB
===== (1, 1023, 16) =====
Optimized Memory used: 0.62890625 MB
Vanilla Memory used: 16.470703125 MB
===== (1, 1023, 32) =====
Optimized Memory used: 1.25390625 MB
Vanilla Memory used: 16.970703125 MB
===== (1, 1024, 16) =====
Optimized Memory used: 0.62890625 MB
Vanilla Memory used: 16.5 MB
===== (1, 1024, 32) =====
Optimized Memory used: 1.25390625 MB
Vanilla Memory used: 17.0 MB
===== (8, 127, 16) =====
Optimized Memory used: 0.6240234375 MB
Vanilla Memory used: 2.466796875 MB
===== (8, 127, 32) =====
Optimized Memory used: 1.244140625 MB
Vanilla Memory used: 2.962890625 MB
===== (8, 128, 16) =====
Optimized Memory used: 0.62890625 MB
Vanilla Memory used: 2.5 MB
===== (8, 128, 32) =====
Optimized Memory used: 1.25390625 MB
Vanilla Memory used: 3.0 MB
===== (8, 512, 16) =====
Optimized Memory used: 2.515625 MB
Vanilla Memory used: 34.0 MB
===== (8, 512, 32) =====
Optimized Memory used: 5.015625 MB
Vanilla Memory used: 36.0 MB
===== (8, 513, 16) =====
Optimized Memory used: 2.52099609375 MB
Vanilla Memory used: 34.130859375 MB
===== (8, 513, 32) =====
Optimized Memory used: 5.02587890625 MB
Vanilla Memory used: 36.134765625 MB
===== (8, 1023, 16) =====
Optimized Memory used: 5.0263671875 MB
Vanilla Memory used: 131.99609375 MB
===== (8, 1023, 32) =====
Optimized Memory used: 10.021484375 MB
Vanilla Memory used: 135.9921875 MB
===== (8, 1024, 16) =====
Optimized Memory used: 5.03125 MB
Vanilla Memory used: 132.0 MB
===== (8, 1024, 32) =====
Optimized Memory used: 10.03125 MB
Vanilla Memory used: 136.0 MB
===== (32, 127, 16) =====
Optimized Memory used: 2.49609375 MB
Vanilla Memory used: 9.861328125 MB
===== (32, 127, 32) =====
Optimized Memory used: 4.9765625 MB
Vanilla Memory used: 11.845703125 MB
===== (32, 128, 16) =====
Optimized Memory used: 2.515625 MB
Vanilla Memory used: 10.0 MB
===== (32, 128, 32) =====
Optimized Memory used: 5.015625 MB
Vanilla Memory used: 12.0 MB
===== (32, 512, 16) =====
Optimized Memory used: 10.0625 MB
Vanilla Memory used: 136.0 MB
===== (32, 512, 32) =====
Optimized Memory used: 20.0625 MB
Vanilla Memory used: 144.0 MB
===== (32, 513, 16) =====
Optimized Memory used: 13.9453125 MB
Vanilla Memory used: 140.2548828125 MB
===== (32, 513, 32) =====
Optimized Memory used: 23.08056640625 MB
Vanilla Memory used: 147.51171875 MB
===== (32, 1023, 16) =====
Optimized Memory used: 21.130859375 MB
Vanilla Memory used: 529.001953125 MB
===== (32, 1023, 32) =====
Optimized Memory used: 40.11328125 MB
Vanilla Memory used: 543.9765625 MB
===== (32, 1024, 16) =====
Optimized Memory used: 21.138671875 MB
Vanilla Memory used: 529.013671875 MB
===== (32, 1024, 32) =====
Optimized Memory used: 40.125 MB
Vanilla Memory used: 544.0 MB
===== (256, 127, 16) =====
Optimized Memory used: 20.0615234375 MB
Vanilla Memory used: 79.470703125 MB
===== (256, 127, 32) =====
Optimized Memory used: 40.216796875 MB
Vanilla Memory used: 95.6572265625 MB
===== (256, 128, 16) =====
Optimized Memory used: 20.234375 MB
Vanilla Memory used: 80.109375 MB
===== (256, 128, 32) =====
Optimized Memory used: 40.125 MB
Vanilla Memory used: 96.0 MB
===== (256, 512, 16) =====
Optimized Memory used: 80.5 MB
Vanilla Memory used: 1088.0 MB
===== (256, 512, 32) =====
Optimized Memory used: 160.5 MB
Vanilla Memory used: 1152.0 MB
===== (256, 513, 16) =====
Optimized Memory used: 80.6572265625 MB
Vanilla Memory used: 1096.125 MB
===== (256, 513, 32) =====
Optimized Memory used: 160.8134765625 MB
Vanilla Memory used: 1160.25 MB
===== (256, 1023, 16) =====
Optimized Memory used: 160.9990234375 MB
Vanilla Memory used: 4216.03515625 MB
===== (256, 1023, 32) =====
Optimized Memory used: 320.8427734375 MB
Vanilla Memory used: 4343.91015625 MB
===== (256, 1024, 16) =====
Optimized Memory used: 161.0 MB
Vanilla Memory used: 4224.0 MB
===== (256, 1024, 32) =====
Optimized Memory used: 321.0 MB
Vanilla Memory used: 4352.0 MB

[------------------- attention -------------------]
                           |  optimized  |  vanilla
1 threads: ----------------------------------------
      B=1, M=127, K=16     |     183.7   |    237.8
      B=1, M=127, K=32     |     183.9   |    239.5
      B=1, M=128, K=16     |     146.0   |    168.1
      B=1, M=128, K=32     |     144.5   |    167.1
      B=1, M=512, K=16     |     142.1   |    165.5
      B=1, M=512, K=32     |     142.1   |    165.0
      B=1, M=513, K=16     |     146.7   |    171.5
      B=1, M=513, K=32     |     146.8   |    171.6
      B=1, M=1023, K=16    |     156.5   |    186.3
      B=1, M=1023, K=32    |     229.5   |    211.8
      B=1, M=1024, K=16    |     166.0   |    195.8
      B=1, M=1024, K=32    |     208.0   |    184.9
      B=8, M=127, K=16     |     143.0   |    165.8
      B=8, M=127, K=32     |     141.2   |    165.9
      B=8, M=128, K=16     |     141.2   |    174.7
      B=8, M=128, K=32     |     149.3   |    206.2
      B=8, M=512, K=16     |     223.7   |    399.3
      B=8, M=512, K=32     |     380.1   |    408.9
      B=8, M=513, K=16     |     331.0   |    419.5
      B=8, M=513, K=32     |     527.2   |    432.3
      B=8, M=1023, K=16    |    1018.1   |   1220.1
      B=8, M=1023, K=32    |    1552.4   |   1243.4
      B=8, M=1024, K=16    |     752.7   |   1206.7
      B=8, M=1024, K=32    |    1324.3   |   1234.4
      B=32, M=127, K=16    |     142.0   |    167.4
      B=32, M=127, K=32    |     141.5   |    169.9
      B=32, M=128, K=16    |     146.4   |    168.7
      B=32, M=128, K=32    |     141.8   |    165.5
      B=32, M=512, K=16    |     766.8   |   1180.1
      B=32, M=512, K=32    |    1362.6   |   1227.0
      B=32, M=513, K=16    |    1201.3   |   1288.6
      B=32, M=513, K=32    |    1950.0   |   1359.9
      B=32, M=1023, K=16   |    3996.6   |   4392.1
      B=32, M=1023, K=32   |    5977.2   |   4546.1
      B=32, M=1024, K=16   |    2939.9   |   4347.4
      B=32, M=1024, K=32   |    5110.6   |   4508.7
      B=256, M=127, K=16   |     581.3   |    629.5
      B=256, M=127, K=32   |     946.7   |    706.3
      B=256, M=128, K=16   |     446.4   |    622.7
      B=256, M=128, K=32   |     828.6   |    695.7
      B=256, M=512, K=16   |    6017.2   |   8183.5
      B=256, M=512, K=32   |   10432.4   |   8744.9
      B=256, M=513, K=16   |    9561.7   |   9208.7
      B=256, M=513, K=32   |   15071.8   |   9893.5
      B=256, M=1023, K=16  |   31757.3   |  32366.8
      B=256, M=1023, K=32  |   47725.4   |  33994.0
      B=256, M=1024, K=16  |   23353.4   |  32033.4
      B=256, M=1024, K=32  |   40533.0   |  33811.4

Times are in microseconds (us).

Fixes #161.

fmassa added 30 commits April 13, 2022 02:27
But we still have a long way to go
Merge two loops together and use local buffers for accumulation and grad_q. The use of local buffers as currently done introduces limitations on the size of dimension K
Brings an extra 2x improvement
Makes it another 20% faster, and doesn't use extra memory
Use all threads to compute grad_q
This brings a 50% speedup compared to the previous approach, despite redundant computation. The benefit comes from the fact that we are using better block sizes for the matmul computation of grad_q, which doesn't involve the transpose of the attention matrix
Brings an additional 12% speedup despite duplicate computation
Potentially due to avoiding bank conflicts?
Brings 10% improvement, being better than my previous best version
This is now significantly faster than what we had before, and is even faster than the vanilla implementation
This is now 18% faster than the vanilla implementation
Remove previous implementation
facebook-github-bot added the CLA Signed label on Apr 22, 2022.
@blefaudeux (Contributor) commented:

This is great, thanks @fmassa! Open question: what would make the most sense for integrating it with the rest, for people who build from the registry? I can add a follow-up PR to add it to the existing attention mechanisms, or it could just be a flag for "scaled dot product" (or something else?)

ref = ref_attention(query, key, value)
ref.backward(torch.ones_like(query))

# there is some extra precision loss in the CPU implementation due to an
blefaudeux (Contributor):

Thanks for the comment, helpful!
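(For context, ref_attention is the vanilla PyTorch baseline used in the tests and benchmarks above; a minimal sketch of such a reference, assuming the standard scaled-dot-product formulation, might look like this.)

```python
import torch

# Hedged sketch of a vanilla reference attention; the repository's actual
# ref_attention may differ in details such as the 1/sqrt(K) scaling.
def ref_attention(q, k, v):
    p = (q @ k.transpose(-2, -1)).softmax(-1)  # full [B, M, M] attention matrix
    return p @ v
```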

q = torch.rand(shape, device=device)
sub_label = f"B={B}, M={M}, K={K}"

if True:
blefaudeux (Contributor):

debug ?

fmassa (Author):

Yeah, it's a debug flag that is sometimes helpful: sometimes I "break" the kernel by removing parts of the computation to see what speedup I would get. Doing so means the computation won't be correct anymore, so it was useful to be able to disable the correctness checks.

I can remove it if you want, but since I expect to do some more performance tuning, I'd like to keep it around for a bit longer if that's OK with you.

blefaudeux (Contributor):

It's totally OK, just flagging it in case, but understood, no worries.

pprint.pprint(mem_use)


benchmark_forward()
blefaudeux (Contributor):

[nit] It might be possible to factor the two out, but it's not super important; a good tool to have already!

fmassa (Author):

Yeah, it's totally possible. I've also added a benchmark_forward_and_backward case in a separate branch, and it started to have quite a bit of duplication. I can look into refactoring this in a follow-up PR.

blefaudeux (Contributor):

Not urgent and not blocking; same as above, more of a mental note. Sounds good.

at::TensorAccessor<scalar_t, 3> query,
at::TensorAccessor<scalar_t, 3> key,
at::TensorAccessor<scalar_t, 3> value,
at::TensorAccessor<scalar_t, 3> buffer //,
at::TensorAccessor<scalar_t, 3> buffer,
bool compute_logsumexp
// at::TensorAccessor<int64_t, 2> mask
blefaudeux (Contributor):

Haha, next step is to make it sparse? :D (not that much of a joke...)

fmassa (Author):

Yes, integrating sparsity is in the plans :-) But before that, I'll look into the K > 32 case.

@@ -90,15 +92,18 @@ void attention_kernel(
for (int64_t k = 0; k < K; k++) {
oo[k] = buf[k] / s_prime;
}
if (compute_logsumexp)
logsumexp[i][j] = m_prime + std::log(s_prime);
blefaudeux (Contributor):

Ah, I hadn't thought of that; this is why your backward is so fast. And it doesn't weigh that much actually (one value per row, not the whole attention map).

fmassa (Author):

Yeah, keeping this temporary around made the backward kernel faster and also easier to implement, so I decided it was worth the extra memory.

Also, in a follow-up PR I'll avoid allocating the logsumexp buffer in the function if compute_logsumexp is false, so the extra memory cost is only paid during training.
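To put rough numbers on that trade-off (illustrative back-of-the-envelope arithmetic, not measured values):

```python
# One fp32 logsumexp value per query row versus the full attention matrix,
# e.g. for B=32, M=1024:
B, M = 32, 1024
print(B * M * 4 / 2**20)      # logsumexp:        ~0.125 MB
print(B * M * M * 4 / 2**20)  # attention matrix: ~128 MB
```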

@@ -58,6 +59,25 @@ __device__ __forceinline__ void iDiv(scalar_t x1, float* out) {
out[0] /= x1;
}

template <typename scalar_t>
__device__ __forceinline__ void myGpuAtomicAdd(scalar_t* address, float4 val) {
blefaudeux (Contributor):

Ah, interesting. I thought this was a little costly and not too good an idea in practice; looks like my intuition was wrong. We're doing this (accumulating across threads) on the Triton side for layernorm, but it could be a good idea to extend it to the linear layer / bias gradient. I'll have a look when I get the time.

fmassa (Author):

I was also a bit afraid of the atomicAdds in the beginning, but it turned out not to be too slow. Plus it made it easier to parallelize the kernel, so why not.

@@ -393,6 +416,7 @@ __global__ void attention_kernel(
output_block[q_item_idx] =
reinterpret_cast<vec_t*>(output[batch_idx][index].data());
m_prime[q_item_idx] = -std::numeric_limits<scalar_t>::infinity();
logsumexp_block[q_item_idx] = &logsumexp[batch_idx][index];
blefaudeux (Contributor):

I know it was already there in the forward PR, but I missed this; could you elaborate? You're collecting all the pointers as a first step? Is it necessary for the #pragma unroll down the line? Trying to understand this CUDA trick :)

fmassa (Author):

The main benefit of this preamble is to handle sequence lengths which are not a multiple of 32 (or whatever block size I'm using). Instead of handling the out-of-bounds cases directly in the hot paths of the kernel, I handle them beforehand and repeat the last element if needed, so the kernel never indexes out of bounds.
There are probably other / better ways of handling generic sequence lengths, but this was the easiest to implement, so I decided to go for it.
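A hedged sketch of that clamping idea, in Python rather than the kernel's actual CUDA; BLOCK and the write-back rule below are illustrative:

```python
BLOCK = 32
M = 127  # sequence length not a multiple of the block size

for block_start in range(0, M, BLOCK):
    # Clamp indices up front: out-of-range rows just repeat the last row, so
    # the hot loop never branches on bounds and never reads past the tensors.
    rows = [min(block_start + i, M - 1) for i in range(BLOCK)]
    # ... load query[rows] and compute the block ...
    # only rows with block_start + i < M are written back at the end
```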

blefaudeux (Contributor):

Ahh, I didn't see that at first sight; makes sense. I don't know of a better way myself (except that CUDA may have similar concepts to the graphics interfaces, where you can clamp memory/texture fetches automatically so they don't go out of bounds, and repeat or pad instead). It's been a long time since I wrote CUDA, so by now I just don't know.

@@ -473,56 +497,30 @@ __global__ void attention_kernel(
output_block[q_item_idx][k] = tmp;
}
}

if (compute_logsumexp) {
#pragma unroll
blefaudeux (Contributor):

My guess is that the pointer array is because of this?

fmassa (Author):

I think it might help as well, yes, although it was not the main reason why I did it.

at::Tensor logsumexp = at::empty({B, M}, query.options());

// have to pass compute_logsumexp as a template parameter
// otherwise there is a slowdown in the kernel...
blefaudeux (Contributor):

Register spilling? (As in, it inlines the logsumexp computation plus a branch, and that takes too much space?)

fmassa (Author):

Might be. Maybe I should try this as a non-inlined function call to see if it makes things better. Lots of improvements to be done in the future! :-)

vec_t tt = __ldg(vb[k_item_idx] + k);
#pragma unroll
for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) {
sputnik::VectorCompute<vec_t>::Dot(
blefaudeux (Contributor):

(I guess that computeDot is the most interesting place for that, but I cannot comment there.) It could be interesting to pull in tensor cores for the .dot(): for fp16, and for fp32 on newer hardware (A100+). I don't know whether there are existing primitives that would be preferable there (from CUDA/cuDNN/cuBLAS, like in this example, or from torch).

Maybe it's not needed actually; sorry for bringing up tensor cores all the time, just trying to think ahead about possible next steps (and possible hidden caveats if you developed on a P100).

fmassa (Author):

Tensor cores are going to be a very important next step indeed, and I'm already looking into how to use them. But I think it might be better to support the K > 32 case first, wdyt?

blefaudeux (Contributor):

Agreed on bigger K first. At this point the perf impact is not super clear to me (for fp32 at least; on the 3080, which has tensor cores, this was very competitive vs. PyTorch with tensor cores), while the limitation of a small K is very well defined.

r = xformers.ops.memory_efficient_attention(q, q, q)

rr = ref_attention(q, q, q)
assert (r - rr).abs().max() < 1e-5
blefaudeux (Contributor):

This does not pass on a 3080 / CUDA 11.6; it could be interesting to test with the T4s on CircleCI. It could well be because of TF32 (you would need to flip the torch flag forcing fp32 computations). Implicitly this probably means the torch implementation switched to tensor cores, which changes the time difference between the two implementations (but it's not a fundamental issue).

fmassa (Author):

Good point. I should probably change those defaults, or disable TF32 in the benchmarks (but that makes for slower baselines), or just disable this correctness check by default. Which one would you prefer?

blefaudeux (Contributor):

I would switch TF32 off here; I think that's the best correctness check: you assume fp32 in the kernel, so let's check correctness against fp32 (torch.backends.cuda.matmul.allow_tf32 = False).
Good to keep in mind that the benchmark comparison is not iso-accuracy, by the way: your implementation is actually more precise :)
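A sketch of what the TF32-off check could look like (the shape is illustrative, the flag is the one quoted above, and ref_attention is sketched inline rather than imported from the test file):

```python
import torch
import xformers.ops

# Disable TF32 so the fp32 kernel is compared against a true fp32 baseline.
torch.backends.cuda.matmul.allow_tf32 = False

# Vanilla baseline (sketch; the test file's ref_attention may differ in scaling).
ref_attention = lambda q, k, v: (q @ k.transpose(-2, -1)).softmax(-1) @ v

q = torch.rand(8, 128, 16, device="cuda")
r = xformers.ops.memory_efficient_attention(q, q, q)
rr = ref_attention(q, q, q)
assert (r - rr).abs().max() < 1e-5
```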


rr = ref_attention(q, q, q)
rr.backward(torch.ones_like(q))
assert (grad - q.grad).abs().max() < 1e-5
blefaudeux (Contributor) commented Apr 23, 2022:

Same as above: this does not pass on a 3080; my guess is that it's because of TF32 vs. float32 (it would be the same on an A100; not sure about TF32 on a V100).

int TILE_SIZEQ,
int TILE_SIZEK,
bool check_bounds>
__global__ void attention_backward_grad_v_kernel(
blefaudeux (Contributor):

A bit harder for me to follow since it's not something I dived into (vs. the forward pass), so I'm forced to skim a little... looks clean as always.

fmassa (Author):

This is a good point; I should probably write some comments at the top of the functions describing what the implementation is actually doing, in PyTorch-like code for ease of reading. If you don't mind, I'd like to get this PR merged now, but I can add more comments with the K > 32 PR that I'm working on.

blefaudeux (Contributor):

Done > wild plans; sounds good to me!
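In the meantime, here is a hedged PyTorch-level sketch of the math such backward kernels compute (the standard softmax-attention gradients, with the probabilities rebuilt from the saved logsumexp); this is not the kernels' actual code, and scaling is omitted:

```python
import torch

def attention_backward_reference(q, k, v, grad_out, logsumexp):
    # Rebuild the attention probabilities from q, k and the saved per-row
    # logsumexp instead of keeping the full matrix from the forward pass.
    p = torch.exp(q @ k.transpose(-2, -1) - logsumexp[..., None])
    grad_v = p.transpose(-2, -1) @ grad_out
    grad_p = grad_out @ v.transpose(-2, -1)
    # Softmax backward: subtract the row-wise weighted sum of grad_p.
    grad_s = p * (grad_p - (grad_p * p).sum(-1, keepdim=True))
    grad_q = grad_s @ k
    grad_k = grad_s.transpose(-2, -1) @ q
    return grad_q, grad_k, grad_v
```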

#pragma unroll
for (int q_item_idx = 0; q_item_idx < kBlockSizeQ; q_item_idx++) {
fact[kBlockSizeQ * threadIdx.x + q_item_idx]
[kBlockSizeK * threadIdx.y + k_item_idx] = 0;
blefaudeux (Contributor):

Is there no faster call than that to zero the buffer?

fmassa (Author):

You mean with memset, or with a default initializer (like = {};)?

It turns out that shared memory doesn't allow default initializers, so I let all threads in the block collaboratively zero the buffer. From some basic profiling, this part of the code doesn't seem to take any noticeable time, so I decided to leave it like that.

Also worth noting that I only need to zero this when the input sequence length is not a multiple of 32; otherwise uninitialized values would be left in there. Now that we have a template parameter in the kernel for this case, I could probably also put this under an if (check_bounds), but I wouldn't expect it to bring noticeable improvements.

blefaudeux (Contributor):

Yes, I was thinking about memset or similar; it's super weird that CUDA does not offer a primitive for that. OK on the timing; it could be that the compiler optimizes this away actually...

scalar_t normalizer[kBlockSizeQ];
scalar_t tmp_sum[kBlockSizeQ] = {0};

vec_t *qb[kBlockSizeQ], *kb[kBlockSizeK], *vb[kBlockSizeK], *gb[kBlockSizeQ],
blefaudeux (Contributor) commented Apr 23, 2022:

Same as the other one: collect all the pointers first, then proceed, OK. It's typical, I guess, but I'm probably not familiar enough with it.

fmassa (Author):

Yeah, it was to simplify the handling of the code in other parts of the kernel a bit. It may also save a couple of instructions, so it might be slightly faster to do it this way.

blefaudeux (Contributor) left a review comment:

Looks massively great to me, thanks @fmassa! A couple of questions, mostly to understand better on my side and to ask about possible follow-ups down the line.

blefaudeux (Contributor) commented Apr 23, 2022:

If that helps, here is a speed report on a 3080:

[------------------- attention -------------------]
                           |  optimized  |  vanilla
1 threads: ----------------------------------------
      B=1, M=127, K=16     |      19.9   |     54.1
      B=1, M=127, K=32     |      25.2   |     49.5
      B=1, M=128, K=16     |      19.3   |     52.2
      B=1, M=128, K=32     |      19.9   |     50.1
      B=1, M=512, K=16     |      35.3   |     55.2
      B=1, M=512, K=32     |      63.7   |     63.0
      B=1, M=513, K=16     |      35.9   |     58.6
      B=1, M=513, K=32     |      64.4   |     63.7
      B=1, M=1023, K=16    |      69.6   |     55.1
      B=1, M=1023, K=32    |     127.3   |     59.6
      B=1, M=1024, K=16    |      66.0   |     58.8
      B=1, M=1024, K=32    |     121.7   |     65.7
      B=8, M=127, K=16     |      18.7   |     55.3
      B=8, M=127, K=32     |      25.4   |     49.9
      B=8, M=128, K=16     |      21.5   |     58.7
      B=8, M=128, K=32     |      20.1   |     59.3
      B=8, M=512, K=16     |      66.4   |    116.1
      B=8, M=512, K=32     |     126.5   |    121.6
      B=8, M=513, K=16     |      67.5   |    132.5
      B=8, M=513, K=32     |     130.3   |    145.7
      B=8, M=1023, K=16    |     256.2   |    427.2
      B=8, M=1023, K=32    |     537.2   |    452.1
      B=8, M=1024, K=16    |     249.7   |    397.4
      B=8, M=1024, K=32    |     518.8   |    410.4
      B=32, M=127, K=16    |      25.8   |     55.8
      B=32, M=127, K=32    |      48.9   |     50.7
      B=32, M=128, K=16    |      22.1   |     59.6
      B=32, M=128, K=32    |      39.5   |     59.2
      B=32, M=512, K=16    |     259.7   |    406.6
      B=32, M=512, K=32    |     534.8   |    425.6
      B=32, M=513, K=16    |     263.9   |    464.9
      B=32, M=513, K=32    |     544.2   |    510.8
      B=32, M=1023, K=16   |     948.7   |   1629.7
      B=32, M=1023, K=32   |    2027.9   |   1729.1
      B=32, M=1024, K=16   |     916.9   |   1510.2
      B=32, M=1024, K=32   |    1945.0   |   1573.7
      B=256, M=127, K=16   |     238.6   |    261.0
      B=256, M=127, K=32   |     403.1   |    339.7
      B=256, M=128, K=16   |     165.9   |    236.1
      B=256, M=128, K=32   |     316.9   |    261.1
      B=256, M=512, K=16   |    1870.8   |   3078.6
      B=256, M=512, K=32   |    3908.3   |   3236.4
      B=256, M=513, K=16   |    1927.2   |   3638.1
      B=256, M=513, K=32   |    4016.3   |   3922.3
      B=256, M=1023, K=16  |    7383.7   |  12697.6
      B=256, M=1023, K=32  |   15664.0   |  13622.2
      B=256, M=1024, K=16  |    7172.9   |  11887.2
      B=256, M=1024, K=32  |   15038.3   |  12385.9

Times are in microseconds (us).

@blefaudeux (Contributor) commented:

Another follow-up question is about the supported K: I think 256 is typical in NLP / big models, for instance, and 64 / 128 for the GPT series. Certainly not blocking, but it's a request bound to come up :)

fmassa (Author) commented Apr 25, 2022:

Thanks for the timely review @blefaudeux !

Yes, the K > 32 case is definitely super important and I'm looking into how to address this. Naively bumping up BUFFER_SIZE is not ideal, and our kernels end up being 2x slower than the baseline, so there is definitely some room for improvement here.

Using TensorCores is also next in my optimization list, which I hope to get to after I'm back from holidays.

@fmassa fmassa merged commit a0fb375 into main Apr 25, 2022
@fmassa fmassa deleted the mem-eff-attn-backward-t2 branch April 25, 2022 09:05
fmassa (Author) commented Apr 25, 2022:

BTW, I forgot to answer one of your comments:

"what would make the most sense for integrating it with the rest, for people who build from the registry?"

My thinking was to enable this directly in scaled_dot_product_attention, maybe through a global flag, so that all other attention mechanisms can benefit from it if enabled. I'll do it in a follow-up PR as well (once I get to the other tasks)

@blefaudeux (Contributor) commented:
Sounds good to me, I would have done that too!
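For readers who want to try it before that integration lands, the op merged in this PR can already be called directly; a usage sketch under the constraints discussed in this thread (fp32, K ≤ 32; shapes are illustrative):

```python
import torch
import xformers.ops

B, M, K = 8, 512, 32
q = torch.rand(B, M, K, device="cuda")
k = torch.rand(B, M, K, device="cuda")
v = torch.rand(B, M, K, device="cuda")

# Drop-in replacement for a vanilla softmax(q @ k^T) @ v attention.
out = xformers.ops.memory_efficient_attention(q, k, v)
```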

Labels: CLA Signed

Successfully merging this pull request may close these issues:

[feat] Add a fast implementation of Rabe and Staats algorithm (mem efficient attention) on GPU

3 participants