Add Triton Flash Attention #479
Conversation
Force-pushed from 11c3524 to bbff21d
Thanks for the PR Diana!
Can you add some tests to this backend? You basically only need to add the TritonFlashAttentionOp in xformers/tests/test_mem_eff_attention.py (line 81 at e31c571, next to xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp).
Also, can you run the benchmarks in https://github.com/facebookresearch/xformers/blob/main/xformers/benchmarks/benchmark_mem_eff_attention.py with this backend so that we can compare to the existing ones?
Force-pushed from 6ec9fc0 to 6280055
Hey @dianaml0, this is great, but you would need a modern triton, right? I've had a branch up on that for a while; some of the API changed and we need to adapt a lot of the layers. Raising that just in case you bumped into it.
Thanks @blefaudeux, that's a good point. I have the updated Triton locally and am still facing some errors, but that's definitely needed. Do you think you'll push the changes you have, or should I look into it?
#483 should help; it should accept any modern triton pip package!
Force-pushed from 79eeda0 to 3f362bf
Thanks a lot! The results really look amazing!
Would be curious to see if we can use Triton's fwd with Flash's bwd, for instance, to get the best of both worlds (not in this PR).
Updated Numbers Performance Compared to Vanilla FWD
Performance Compared to Vanilla BWD
Force-pushed from 3c5496f to d8c610c
Thanks for the PR Diana!
I have some comments. Also, if there were no changes to the triton implementation from the one present in flashattention, would it make sense to just use it directly, instead of copying its code?
We already have triton compiled and installed (for the CUDA dependency), so we could directly call into its Python API. What do you think?
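For what it's worth, a rough sketch of that idea: import the Triton kernels that ship with the flash-attn package instead of keeping a vendored copy. The module path flash_attn.flash_attn_triton, the flash_attn_func signature, and the (batch, seqlen, num_heads, head_dim) layout are assumptions about the flash-attn version in use here, not something this PR pins down.

```python
# Sketch only: reuse flash-attn's own Triton implementation rather than a copy.
# The import path and call below are assumptions about the flash-attn package
# layout at the time and may differ between versions.
import torch

try:
    from flash_attn.flash_attn_triton import flash_attn_func
    has_triton_flashattention = True
except ImportError:
    has_triton_flashattention = False

if has_triton_flashattention:
    # Assumed layout: (batch, seqlen, num_heads, head_dim), fp16 on CUDA.
    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.half)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = flash_attn_func(q, k, v)
```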
def supports(cls, d: "AttentionOpDispatch") -> bool:
    if not has_triton_flashattention:
        return False
    device_capability = torch.cuda.get_device_capability(d.device)
    is_sm80 = device_capability[0] >= 8
    if not is_sm80:
        return False
    return super(TritonFlashAttentionOp, cls).supports(d)
Looks like the triton implementation supports all values of K, is that right? Can you make the test_mem_eff_attention run on this operator so that we double-check that this all works as expected?
Yes, it supports all values of K up to 128. I ran the tests and they're all passing.
Interesting, so it also supports K which is not a multiple of 8? cc @danthe3rd
Yes, I ran the tests locally and they pass for all values of K less than 128, even those that are not a multiple of 8.
Force-pushed from ebbee1f to a5786d3
Thanks a lot for the reviews @fmassa and @danthe3rd! I've made some changes and added an op for Triton fwd with Flash bwd.
Forwards for Triton fwd and Flash bwd:
Backwards for Triton fwd and Flash bwd:
Times are in microseconds (us).
This LGTM, thanks a lot Diana!
I've added a few more comments that can be addressed in the future, but we can get started with this for now!
requirements-test.txt (Outdated)
# Dependency for triton flash attn
flash-attn
I wonder if this is needed, given that we have flash-attn as a submodule, which is set up in our build system?
Looks like we would need to change our build scripts to also install flash-attn. Otherwise we might have this override our version of _C_flashattn that we compile.
Actually, all the Triton tests in CI seem to have been skipped. Worth looking into this.
I think they're being skipped because I'm requiring sm80 for the Triton implementation.
I added flash-attention as a dependency in setup.py, is that what you had in mind?
I think they're skipping because I'm requiring sm80 for the Triton implementation
Indeed, we don't have sm80 on the CI, but it should work with sm75 (we have it in the CI).
I added sm75 in #556, will check if it works!
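For illustration, a minimal sketch of the relaxed capability gate; the helper name below is hypothetical, and the actual change lives in #556.

```python
import torch

def _device_is_sm75_or_newer(device: torch.device) -> bool:
    # Accept Turing (sm75) and newer instead of requiring Ampere (sm80),
    # mirroring the supports() check quoted above.
    return torch.cuda.get_device_capability(device) >= (7, 5)
```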
@@ -80,6 +80,8 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
    xformers.ops.MemoryEfficientAttentionCutlassOp,
    xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
    xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
    xformers.ops.TritonFlashAttentionOp,
For the future, it might be worth adding a test that checks for race conditions. This has been illustrated in this comment in flashattention, and I know that @danthe3rd has had issues with race conditions in the past, so it might be good to extend our tests to cover this case.
They already have testing for race conditions in the HazyResearch repo here, but maybe it makes sense to add it if we want the same testing for other implementations as well?
We check that a bit by running with very large batches, to ensure the GPU is saturated (so missing __syncthreads() calls cause wrong results). We could also have something similar to what flash is doing, though.
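As a sketch of that large-batch idea (not the repo's actual test; the shapes, tolerances, and reference implementation below are made up for illustration):

```python
import torch

def reference_attention(q, k, v):
    # Plain PyTorch attention used as ground truth.
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    return attn @ v

def check_saturated_gpu(op_fn, batch=1024, seqlen=256, dim=64):
    # A batch large enough to saturate the GPU, so a missing __syncthreads()
    # in the kernel is likely to surface as a mismatch against the reference.
    q = torch.randn(batch, seqlen, dim, device="cuda", dtype=torch.half)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = op_fn(q, k, v)
    ref = reference_attention(q.float(), k.float(), v.float())
    assert torch.allclose(out.float(), ref, atol=2e-2, rtol=2e-2)
```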
MemoryEfficientAttentionTritonFwdFlashBwOp,
TritonFlashAttentionOp,
Should we change the order of priority in here so that TritonFwdFlashBwdOp is dispatched more often? Maybe by writing another _is_triton_faster_than_..?
Sounds good! I added a method for now but haven't filled it out. It seems like triton is faster than cutlass for most shapes, except for the following:
B=64, M=128, H=16, K=128
B=64, M=128, H=16, K=64
B=64, M=128, H=16, K=16
B=16, M=128, H=16, K=128
B=16, M=128, H=16, K=64
B=16, M=128, H=16, K=32
B=16, M=128, H=16, K=16
B=1024, M=82, H=8, K=64
B=16, M=197, H=16, K=128
B=16, M=197, H=16, K=64
B=32, M=197, H=16, K=128
B=32, M=197, H=16, K=64
B=384, M=197, H=1, K=64
@danthe3rd do you know for which cases in general we should expect cutlass to be faster?
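For illustration only, a hypothetical version of such a heuristic derived from the shapes listed above; the function name, its inputs, and the M < 256 cutoff are assumptions, not what the PR actually implements.

```python
def _is_triton_fwd_flash_bwd_faster_than_cutlass(B: int, M: int, H: int, K: int) -> bool:
    # Hypothetical rule of thumb: in the numbers above, cutlass only wins on
    # short sequences (M around 128-197); everywhere else the Triton fwd +
    # Flash bwd combination came out ahead.
    return M >= 256
```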
I don't know - we could be conservative for now (e.g. keep the old behavior) and enable it one by one if we see opportunities.
Sounds good, I'll post in the xFormers group so people know how to try it out if they want
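For anyone who wants to try it out, a hedged usage sketch: it assumes this xformers build's memory_efficient_attention accepts an op= argument and the input layout shown below, so check the signature in your install before copying it.

```python
import torch
import xformers.ops as xops

# Assumed layout: (batch, seqlen, embedding), fp16 on CUDA.
q = torch.randn(32, 1024, 64, device="cuda", dtype=torch.half)
k, v = torch.randn_like(q), torch.randn_like(q)

# Explicitly pick the new Triton fwd + Flash bwd op instead of relying
# on the automatic dispatcher.
out = xops.memory_efficient_attention(
    q, k, v, op=xops.MemoryEfficientAttentionTritonFwdFlashBwOp
)
```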
@fmassa Thanks a lot for the helpful comments and for taking another pass! I've updated with the related changes. Okay to merge for now?
Codecov Report: Base: 89.79% // Head: 89.06% // Decreases project coverage by -0.74% (see the coverage diff below).
Additional details and impacted files
@@ Coverage Diff @@
## main #479 +/- ##
==========================================
- Coverage 89.79% 89.06% -0.74%
==========================================
Files 80 80
Lines 4839 4927 +88
==========================================
+ Hits 4345 4388 +43
- Misses 494 539 +45
Flags with carried forward coverage won't be shown. ☔ View full report at Codecov.
Thanks a lot Diana! This looks great. A few nits that can be addressed later - let's get this merged :)
del flatten_diff
assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
    f"{msg}: "
    f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
    f"/ atol={atol}, rtol={rtol}"
    f"/ total failing elements: {num_different}, percentage={percentage}"
thanks!
MemoryEfficientAttentionCutlassOp,
TritonFlashAttentionOp,
I'm not sure we should include the Triton backward in the priority list, unless we are confident it works fine without correctness issues (Tri wasn't really sure about that). Since it seems properly covered by the tests, I guess we should be good to go.
Makes sense, I removed it in #556 for now
Let's get this merged.
Thanks Diana!
@@ -267,6 +267,7 @@ def run(self):
    version=version,
    install_requires=fetch_requirements(),
    packages=setuptools.find_packages(exclude=("tests", "tests.*")),
    dependency_links=["file:///./third_party/flash-attention#egg=flash-attention"],
@danthe3rd let's test this afterwards. I think we will need to change the way we depend on flashattention in our code
What does this PR do?
Adds Triton Flash Attention
Performance Compared to Vanilla
Performance Compared to MemoryEfficientAttentionCutlassFwdFlashBwOp
TODO:
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.