Memory efficient attention - backward pass (#281)
* Add naive CPU implementation for memory-efficient attention backward. Needs cleanup.
* Optimize (at least) by a factor of 2.
* More cleanups.
* A few more comments.
* Add very naive CUDA implementation. It's super slow!
* Speed up the CUDA kernel by 5x. But we still have a long way to go.
* Make logsumexp an argument.
* Make it 30% faster. Merge two loops together and use local buffers for accumulation and grad_q. The use of local buffers as currently written introduces limitations on the sizes of dimension K.
* 3.5x speedup from a blocking strategy.
* Use vector loads and improve tile selection. Brings an extra 2x improvement.
* Recompute attention for the grad_q computation. Makes it another 20% faster, and doesn't use extra memory (a sketch of this recompute-from-logsumexp pattern follows after this list).
* Small cleanups.
* clang-format.
* Make it 0.5% faster. Use all threads to compute grad_q.
* Make it 1% faster by caching the loads.
* Make it 6% faster with better hyperparameters.
* Slightly better hyperparameters.
* axpy == FMA.
* Separate grad_q into its own kernel. This brings a 50% speedup compared to the previous approach, despite redundant computation. The benefit comes from the fact that we are using better block sizes for the matmul computation of grad_q, which doesn't involve the transpose of the attention matrix.
* Avoid additional global writes by recomputing grad_aatn_v in grad_k. Brings an additional 12% speedup despite duplicate computation.
* Try out a new idea.
* Almost on par with my previous best implementation.
* Improve perf by 5%. Potentially due to avoiding bank conflicts?
* Remove query-key from shared memory and increase the tile size. Brings a 10% improvement, beating my previous best version.
* Make it 20% faster with better hyperparameters. This is now significantly faster than what we had before, and is even faster than the vanilla implementation.
* Make it another 12% faster. This is now 18% faster than the vanilla implementation.
* Code cleanup.
* Further cleanups. Remove the previous implementation.
* Variable rename.
* clang-format.
* Add alternative implementation for grad_v. So far it has exactly the same speed as the previous kernel, but is much more similar to the grad_q and grad_k kernels.
* Speed it up by 10% with better hyperparameters.
* Delete old implementation.
* Centralize all input accesses at the beginning. This will make it easier to support inputs whose sizes are not a multiple of 32. Plus, this seems to give a small performance improvement, on the order of 1%.
* Bugfix. Only shows up for certain sizes, for some reason.
* Make kernels generic w.r.t. sequence length. This introduces a slowdown of 25%, mostly due to the index computation in the preamble of each kernel. In a next commit I'll try to optimize this out.
* Add a template argument to skip bounds checking. Brings speed back to where it was, for the cases where we can safely skip it.
* Make it support all use cases.
* Let logsumexp be returned by forward. Also add an autograd Function for backward (see the sketch after this list).
* clang-format.
* Add scaling factor.
* Add tests + silly bugfix.
* Add benchmark function for backward (see the timing sketch below).
* Add comment.
* clang-format.
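To make the recompute-from-logsumexp idea concrete, here is a minimal pure-PyTorch sketch of the autograd Function pattern this commit describes; it is not the optimized CUDA kernels from the diff, and the class name is illustrative. Forward saves logsumexp and the inputs instead of the full attention matrix, and backward reconstructs the attention probabilities from them.

```python
import torch


class MemoryEfficientAttention(torch.autograd.Function):
    # Sketch only: in the actual commit logsumexp is returned by
    # forward; here it is stashed on ctx for brevity.

    @staticmethod
    def forward(ctx, q, k, v):
        scale = q.shape[-1] ** -0.5
        attn = q @ k.transpose(-2, -1) * scale             # [B, M, N] scores
        lse = torch.logsumexp(attn, dim=-1, keepdim=True)  # [B, M, 1]
        p = torch.exp(attn - lse)                          # softmax via logsumexp
        out = p @ v
        ctx.save_for_backward(q, k, v, lse)
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, lse = ctx.saved_tensors
        scale = ctx.scale
        # Recompute the attention probabilities instead of having stored
        # them ("Recompute attention for grad_q computation").
        p = torch.exp(q @ k.transpose(-2, -1) * scale - lse)
        grad_v = p.transpose(-2, -1) @ grad_out
        dp = grad_out @ v.transpose(-2, -1)
        # Softmax backward: dS = P * (dP - rowsum(dP * P)).
        ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))
        grad_q = ds @ k * scale
        grad_k = ds.transpose(-2, -1) @ q * scale
        return grad_q, grad_k, grad_v
```

Calling `MemoryEfficientAttention.apply(q, k, v)` followed by `out.sum().backward()` exercises the recomputation path; `torch.autograd.gradcheck` on small double-precision inputs confirms the three gradients match plain autograd.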
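For the "Add benchmark function for backward" step, a timing sketch like the following is one plausible shape; the helper name, tensor sizes, and iteration count are assumptions, not the repository's actual benchmark code.

```python
import torch


def benchmark_attention_backward(B=32, M=1024, K=64, iters=10):
    # Hypothetical benchmark helper: times only the backward pass of the
    # sketch Function above, using CUDA events for device-side timing.
    q, k, v = (torch.randn(B, M, K, device="cuda", requires_grad=True)
               for _ in range(3))
    out = MemoryEfficientAttention.apply(q, k, v)
    grad = torch.randn_like(out)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        out.backward(grad, retain_graph=True)
    end.record()
    torch.cuda.synchronize()
    print(f"backward: {start.elapsed_time(end) / iters:.3f} ms / iteration")
```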
Showing 6 changed files with 1,037 additions and 121 deletions.