[perf] Fused linear: small FW cleanup and much better perfs #283
Conversation
Force-pushed from 0904eb1 to cbc97eb
There's a lot of speed left in the BW pass; with a fresher pair of eyes it's obvious. This is a very small PR which already brings a decent speedup, all from the FW pass. Fist bump @dianaml0, I know you were hoping for more perf there recently. I looked into this while contemplating a dedicated MHA self-attention projection kernel, which would be very similar (without the activation, and dispatching into 3 buffers).
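For reference, a minimal plain-PyTorch sketch of the op being fused here, and of what a dedicated self-attention projection kernel would dispatch into three buffers; the shapes and the GELU choice are assumptions for illustration, and this is the math only, not the Triton kernel in this PR.

```python
import torch
import torch.nn as nn

# Reference semantics of the fused linear + activation forward pass:
# y = activation(x @ W^T + b), computed in a single kernel by this PR.
def fused_linear_reference(x: torch.Tensor, weight: torch.Tensor,
                           bias: torch.Tensor, activation=nn.GELU()):
    return activation(torch.addmm(bias, x, weight.t()))

# What a dedicated MHA self-attention projection kernel would compute:
# the same wx + b, no activation, results dispatched into three buffers (q, k, v).
def qkv_projection_reference(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    qkv = torch.addmm(bias, x, weight.t())   # [tokens, 3 * model_dim]
    return qkv.chunk(3, dim=-1)              # q, k, v
```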
Force-pushed from cbc97eb to 3f168e1
Ideally @dianaml0 the backward pass should be improved, and FusedMLP could revert to using this instead of the not-super-impactful fused dropout.
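To make the two fusion options concrete, here is the unfused MLP baseline they are both compared against; the class and comments are purely illustrative, not the actual xformers FusedMLP code.

```python
import torch.nn as nn

# Option A (what the comment above suggests): fuse linear + bias + activation
# into one kernel, then plain dropout and a second linear.
# Option B (roughly what FusedMLP relies on today): a plain linear, then a
# fused bias + activation + dropout kernel.
class VanillaMLP(nn.Module):
    """Unfused baseline that either fusion strategy would replace piece by piece."""
    def __init__(self, dim: int, hidden: int, p: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
```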
Codecov Report

@@           Coverage Diff           @@
##             main     #283   +/-   ##
=======================================
  Coverage   92.72%   92.72%
=======================================
  Files          61       61
  Lines        3407     3407
=======================================
  Hits         3159     3159
  Misses        248      248

Flags with carried forward coverage won't be shown.
Awesome results!
@blefaudeux Hey Ben! I'm curious: did you run this on an end-to-end benchmark? I've been able to recreate the increased FLOPs in a timing measurement, but it performs roughly the same as a vanilla PyTorch linear + GeLU layer on BERT-Base (110M) on an A100 with 80GB VRAM. Is this expected? I did install … As always, great work! I expect to be a heavy user.
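For anyone reproducing this comparison, a sketch of the kind of CUDA-event timing harness such a measurement needs (warm-up, synchronization, fp16 inputs); the BERT-Base-ish shapes are assumptions, and the xformers fused layer is only referenced in a comment since the exact import path depends on the version.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def bench(layer, x, iters=100, warmup=20):
    # Warm up kernels and autotuning before measuring.
    for _ in range(warmup):
        layer(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        layer(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

if __name__ == "__main__":
    B, S, D, H = 32, 128, 768, 3072  # BERT-Base-ish MLP expansion, illustrative only
    x = torch.randn(B * S, D, device="cuda", dtype=torch.float16)
    vanilla = nn.Sequential(nn.Linear(D, H), nn.GELU()).to("cuda", torch.float16)
    print("vanilla linear + GELU:", bench(vanilla, x), "ms")
    # Swap in the xformers fused linear layer here to compare; its import path
    # varies across xformers versions, so it is left out of this sketch.
```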
Hey @moinnadeem, thanks for the check and the message! I've been meaning to follow up on that for some time (this branch and a few others), but could not find the time. There are a couple of caveats, some of which I discovered after this PR; I should probably update the curves.
#296 added to make sure I don't forget to update the curves. Ideally I would love to check in a small perf bump at the same time.
I'm also beginning to wonder whether this is related to pytorch/pytorch#76509.
#76509 landed yesterday, so it couldn't affect earlier results. Also, it affects only fp32 matmul ops, not fp16.
Fair enough, I was wondering whether in some of the curves I was comparing Triton (tf32 internally) vs. PyTorch (fp32 at some point, then tf32, now back to fp32). I meant to reference the issue to say that I could have tripped over one of these changes (I was not always tracking PyTorch dev, although our CI is).
PyTorch was always tf32 until yesterday, when it became fp32. But for fp16 inputs it doesn't matter at all, and I see only fp16 curves.
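For reference, the switch being discussed; fp16 inputs bypass it entirely, which is the point above.

```python
import torch

# TF32 only applies to float32 matmuls on Ampere+ GPUs; fp16 inputs are unaffected.
# pytorch/pytorch#76509 flipped the matmul default below from True to False.
torch.backends.cuda.matmul.allow_tf32 = False  # strict fp32 matmuls (new default)
torch.backends.cudnn.allow_tf32 = True         # separate switch for cuDNN convolutions
```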
What does this PR do?
Improves the perf of the fused linear layer, mostly in the forward pass, but with measurable effects all around. The BW pass could be greatly improved, there's tons of perf on the table; I'll try to work a bit on that in the coming days. This would be applicable to both the MLP and the MHA projection, if we end up writing a dedicated self-attention projection kernel (the op is wx + b, the same as fused linear without the activation).
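For reference on where the remaining backward-pass work sits, these are the gradients a fused backward kernel has to produce for y = activation(x Wᵀ + b); a plain-PyTorch sketch of the math, not the kernel, with illustrative names.

```python
import torch

# Gradients of y = act(x @ W.T + b) that a fused backward kernel must produce.
# grad_y is the incoming gradient; act_grad is the derivative of the activation
# evaluated at the pre-activation z = x @ W.T + b.
def fused_linear_backward_reference(x, weight, grad_y, act_grad):
    grad_z = grad_y * act_grad    # chain rule through the activation
    grad_x = grad_z @ weight      # [tokens, in_features]
    grad_w = grad_z.t() @ x       # [out_features, in_features]
    grad_b = grad_z.sum(dim=0)    # [out_features]
    return grad_x, grad_w, grad_b
```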
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.