-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: use faster version of batched matmul #820
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
avik-pal
force-pushed
the
ap/faster_bmm
branch
2 times, most recently
from
August 4, 2024 20:09
8e0e37a
to
48c6ae5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
Benchmark suite | Current: a1fba59 | Previous: a0095bb | Ratio |
---|---|---|---|
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) |
3651.75 ns |
4210.375 ns |
0.87 |
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) |
8027.416666666667 ns |
7280.857142857143 ns |
1.10 |
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) |
20588 ns |
21329.5 ns |
0.97 |
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) |
9732.2 ns |
9790.4 ns |
0.99 |
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) |
9004.8 ns |
9162.25 ns |
0.98 |
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) |
4500.875 ns |
4474.625 ns |
1.01 |
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) |
4993 ns |
4985.5 ns |
1.00 |
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) |
2351.3 ns |
1582 ns |
1.49 |
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) |
2378.4 ns |
1464.8 ns |
1.62 |
Dense(2 => 2)/cpu/forward/Flux/(2, 128) |
1789.6296296296296 ns |
1802.0566037735848 ns |
0.99 |
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) |
179.91456582633054 ns |
179.46036161335186 ns |
1.00 |
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) |
17172 ns |
17282 ns |
0.99 |
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) |
34624 ns |
13104 ns |
2.64 |
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) |
38982 ns |
37641 ns |
1.04 |
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) |
29796 ns |
29095 ns |
1.02 |
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) |
21520 ns |
21600 ns |
1.00 |
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) |
17312 ns |
17072 ns |
1.01 |
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) |
25668 ns |
25698 ns |
1.00 |
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) |
20498 ns |
2609.3333333333335 ns |
7.86 |
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) |
20639 ns |
2913.222222222222 ns |
7.08 |
Dense(20 => 20)/cpu/forward/Flux/(20, 128) |
4900.571428571428 ns |
4874.857142857143 ns |
1.01 |
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) |
1681.2 ns |
1654.1 ns |
1.02 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) |
102586483.5 ns |
88901650 ns |
1.15 |
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) |
76039377 ns |
76591182 ns |
0.99 |
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) |
174809159.5 ns |
145816642 ns |
1.20 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) |
208167806 ns |
174645570 ns |
1.19 |
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) |
195714546 ns |
162467847 ns |
1.20 |
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) |
11592473 ns |
11901058 ns |
0.97 |
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) |
242481509.5 ns |
195200307.5 ns |
1.24 |
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) |
16238642 ns |
15544763.5 ns |
1.04 |
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) |
16217379.5 ns |
15554299 ns |
1.04 |
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) |
33076614 ns |
42589068.5 ns |
0.78 |
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) |
6376350 ns |
6385176 ns |
1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) |
1065426109.5 ns |
1065642865.5 ns |
1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) |
2936169633 ns |
2933656388 ns |
1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) |
186205517 ns |
169506774 ns |
1.10 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) |
1325428110 ns |
1395702878 ns |
0.95 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) |
4018793221 ns |
3742059235 ns |
1.07 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) |
385413460 ns |
384902400.5 ns |
1.00 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) |
1637511141 ns |
1673204957.5 ns |
0.98 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) |
4312602384 ns |
4683819068 ns |
0.92 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) |
428282138 ns |
459579475 ns |
0.93 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) |
367986266 ns |
361855245 ns |
1.02 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) |
896733901 ns |
884689737 ns |
1.01 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) |
55350471 ns |
56182817 ns |
0.99 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) |
361324315 ns |
364122724 ns |
0.99 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) |
902340480 ns |
880222233.5 ns |
1.03 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) |
54719444 ns |
55637947 ns |
0.98 |
vgg16/cpu/forward/Flux/(32, 32, 3, 16) |
526829519 ns |
562801142 ns |
0.94 |
vgg16/cpu/forward/Flux/(32, 32, 3, 64) |
1429045300 ns |
1517562171 ns |
0.94 |
vgg16/cpu/forward/Flux/(32, 32, 3, 2) |
183593037 ns |
171205515 ns |
1.07 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) |
1335362233.5 ns |
1254741222 ns |
1.06 |
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) |
1603843274.5 ns |
1561957190.5 ns |
1.03 |
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) |
2293438670 ns |
2274530908 ns |
1.01 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) |
2407320619 ns |
2467776572 ns |
0.98 |
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) |
2371461967.5 ns |
2236433184 ns |
1.06 |
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) |
2238199868 ns |
2124313017 ns |
1.05 |
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) |
285054881 ns |
285138498 ns |
1.00 |
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) |
287576903.5 ns |
284816604 ns |
1.01 |
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) |
462667224 ns |
426557087 ns |
1.08 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) |
11786808.5 ns |
11831175 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) |
35054183 ns |
34523223 ns |
1.02 |
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) |
16420838 ns |
16447087 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) |
20992218 ns |
21042971 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) |
15255953.5 ns |
15151300 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) |
1148590 ns |
1151983 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) |
36178330.5 ns |
35724522 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) |
4770600 ns |
4546295 ns |
1.05 |
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) |
4780398.5 ns |
4524386 ns |
1.06 |
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) |
1979994.5 ns |
1970497 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) |
196647 ns |
197294 ns |
1.00 |
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) |
411798 ns |
380051 ns |
1.08 |
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) |
302143 ns |
222526 ns |
1.36 |
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) |
379617 ns |
384974.5 ns |
0.99 |
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) |
553903 ns |
528056 ns |
1.05 |
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) |
292610.5 ns |
294290 ns |
0.99 |
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) |
406167 ns |
408554 ns |
0.99 |
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) |
428739 ns |
429642.5 ns |
1.00 |
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) |
152093 ns |
70512 ns |
2.16 |
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) |
162803.5 ns |
70061 ns |
2.32 |
Dense(200 => 200)/cpu/forward/Flux/(200, 128) |
91871 ns |
92674 ns |
0.99 |
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) |
104465 ns |
104425 ns |
1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) |
332682645.5 ns |
305601445 ns |
1.09 |
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) |
286008512 ns |
288797620.5 ns |
0.99 |
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) |
604381460 ns |
544934785.5 ns |
1.11 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) |
697162939 ns |
589046818 ns |
1.18 |
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) |
600816225 ns |
552663074 ns |
1.09 |
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) |
318902910.5 ns |
319651032.5 ns |
1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) |
651911236 ns |
660000513 ns |
0.99 |
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) |
39491132.5 ns |
39489513.5 ns |
1.00 |
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) |
39424549.5 ns |
39313659 ns |
1.00 |
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) |
117159418 ns |
99796304 ns |
1.17 |
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) |
27898012 ns |
28394041 ns |
0.98 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) |
23232598 ns |
21157343.5 ns |
1.10 |
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) |
17223273 ns |
16678967 ns |
1.03 |
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) |
22762205.5 ns |
22876715 ns |
0.99 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) |
28422179 ns |
28080465 ns |
1.01 |
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) |
19295901.5 ns |
19385408 ns |
1.00 |
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) |
20691216.5 ns |
20728329 ns |
1.00 |
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) |
6119938 ns |
5545671.5 ns |
1.10 |
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) |
6071009.5 ns |
5626649 ns |
1.08 |
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) |
6526233 ns |
6521885 ns |
1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
relaxing the tolerances since Zygote uses more optimized paths.