feat: use faster version of batched matmul #820

Merged
merged 2 commits into from
Aug 4, 2024
Conversation

@avik-pal (Member) commented Aug 4, 2024

Relaxing the tolerances, since Zygote uses more optimized paths.
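One common reason a different code path needs looser tolerances is that floating-point addition is not associative, so a kernel that accumulates in another order can differ in the last few bits. The sketch below illustrates that general effect in plain Python; it is not the code from this PR, and the `reverse` flag, the shapes, and the `rtol`/`atol` values are assumptions chosen for illustration.

```python
import math
import random


def batched_matmul(a, b, reverse=False):
    """Multiply a[k] @ b[k] for every batch index k (nested lists of floats).

    reverse=True accumulates each dot product in the opposite order, a
    hypothetical stand-in for an optimized kernel that sums differently.
    """
    out = []
    for a_k, b_k in zip(a, b):
        inner = range(len(b_k))
        order = list(reversed(inner)) if reverse else list(inner)
        out.append([
            [sum(row[p] * b_k[p][j] for p in order) for j in range(len(b_k[0]))]
            for row in a_k
        ])
    return out


def all_close(x, y, rtol, atol):
    """Element-wise comparison of two batch x rows x cols nested lists."""
    return all(
        math.isclose(u, v, rel_tol=rtol, abs_tol=atol)
        for x_k, y_k in zip(x, y)
        for x_row, y_row in zip(x_k, y_k)
        for u, v in zip(x_row, y_row)
    )


def random_batch(rng, batch, rows, cols):
    """Build a batch of rows x cols matrices with entries in [-1, 1]."""
    return [[[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]
            for _ in range(batch)]


rng = random.Random(0)
a = random_batch(rng, 4, 3, 50)   # 4 matrices of shape 3 x 50
b = random_batch(rng, 4, 50, 2)   # 4 matrices of shape 50 x 2

reference = batched_matmul(a, b)
reordered = batched_matmul(a, b, reverse=True)

# Compare with a tolerance rather than exact equality (illustrative values).
agrees = all_close(reference, reordered, rtol=1e-9, atol=1e-12)
print("results agree within tolerance:", agrees)
```

Comparing with a tolerance keeps the check meaningful when two implementations are mathematically equal but round differently.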

@avik-pal force-pushed the ap/faster_bmm branch 2 times, most recently from 8e0e37a to 48c6ae5, on August 4, 2024 20:09
@github-actions bot (Contributor) left a comment

Benchmark Results

| Benchmark suite | Current: a1fba59 | Previous: a0095bb | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3651.75 ns | 4210.375 ns | 0.87 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 8027.416666666667 ns | 7280.857142857143 ns | 1.10 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 20588 ns | 21329.5 ns | 0.97 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9732.2 ns | 9790.4 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 9004.8 ns | 9162.25 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4500.875 ns | 4474.625 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 4993 ns | 4985.5 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 2351.3 ns | 1582 ns | 1.49 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 2378.4 ns | 1464.8 ns | 1.62 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1789.6296296296296 ns | 1802.0566037735848 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 179.91456582633054 ns | 179.46036161335186 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17172 ns | 17282 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 34624 ns | 13104 ns | 2.64 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 38982 ns | 37641 ns | 1.04 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 29796 ns | 29095 ns | 1.02 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 21520 ns | 21600 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 17312 ns | 17072 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 25668 ns | 25698 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 20498 ns | 2609.3333333333335 ns | 7.86 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 20639 ns | 2913.222222222222 ns | 7.08 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4900.571428571428 ns | 4874.857142857143 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1681.2 ns | 1654.1 ns | 1.02 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 102586483.5 ns | 88901650 ns | 1.15 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 76039377 ns | 76591182 ns | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 174809159.5 ns | 145816642 ns | 1.20 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 208167806 ns | 174645570 ns | 1.19 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 195714546 ns | 162467847 ns | 1.20 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11592473 ns | 11901058 ns | 0.97 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 242481509.5 ns | 195200307.5 ns | 1.24 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 16238642 ns | 15544763.5 ns | 1.04 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 16217379.5 ns | 15554299 ns | 1.04 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 33076614 ns | 42589068.5 ns | 0.78 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6376350 ns | 6385176 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 1065426109.5 ns | 1065642865.5 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2936169633 ns | 2933656388 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 186205517 ns | 169506774 ns | 1.10 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 1325428110 ns | 1395702878 ns | 0.95 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 4018793221 ns | 3742059235 ns | 1.07 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 385413460 ns | 384902400.5 ns | 1.00 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 1637511141 ns | 1673204957.5 ns | 0.98 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 4312602384 ns | 4683819068 ns | 0.92 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 428282138 ns | 459579475 ns | 0.93 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 367986266 ns | 361855245 ns | 1.02 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 896733901 ns | 884689737 ns | 1.01 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 55350471 ns | 56182817 ns | 0.99 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 361324315 ns | 364122724 ns | 0.99 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 902340480 ns | 880222233.5 ns | 1.03 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 54719444 ns | 55637947 ns | 0.98 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 526829519 ns | 562801142 ns | 0.94 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 1429045300 ns | 1517562171 ns | 0.94 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 183593037 ns | 171205515 ns | 1.07 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1335362233.5 ns | 1254741222 ns | 1.06 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1603843274.5 ns | 1561957190.5 ns | 1.03 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2293438670 ns | 2274530908 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2407320619 ns | 2467776572 ns | 0.98 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 2371461967.5 ns | 2236433184 ns | 1.06 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 2238199868 ns | 2124313017 ns | 1.05 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 285054881 ns | 285138498 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 287576903.5 ns | 284816604 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 462667224 ns | 426557087 ns | 1.08 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11786808.5 ns | 11831175 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 35054183 ns | 34523223 ns | 1.02 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 16420838 ns | 16447087 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 20992218 ns | 21042971 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 15255953.5 ns | 15151300 ns | 1.01 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1148590 ns | 1151983 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 36178330.5 ns | 35724522 ns | 1.01 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 4770600 ns | 4546295 ns | 1.05 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 4780398.5 ns | 4524386 ns | 1.06 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 1979994.5 ns | 1970497 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 196647 ns | 197294 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 411798 ns | 380051 ns | 1.08 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 302143 ns | 222526 ns | 1.36 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 379617 ns | 384974.5 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 553903 ns | 528056 ns | 1.05 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 292610.5 ns | 294290 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 406167 ns | 408554 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 428739 ns | 429642.5 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 152093 ns | 70512 ns | 2.16 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 162803.5 ns | 70061 ns | 2.32 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 91871 ns | 92674 ns | 0.99 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104465 ns | 104425 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 332682645.5 ns | 305601445 ns | 1.09 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 286008512 ns | 288797620.5 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 604381460 ns | 544934785.5 ns | 1.11 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 697162939 ns | 589046818 ns | 1.18 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 600816225 ns | 552663074 ns | 1.09 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 318902910.5 ns | 319651032.5 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 651911236 ns | 660000513 ns | 0.99 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 39491132.5 ns | 39489513.5 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 39424549.5 ns | 39313659 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 117159418 ns | 99796304 ns | 1.17 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 27898012 ns | 28394041 ns | 0.98 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 23232598 ns | 21157343.5 ns | 1.10 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 17223273 ns | 16678967 ns | 1.03 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 22762205.5 ns | 22876715 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 28422179 ns | 28080465 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19295901.5 ns | 19385408 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 20691216.5 ns | 20728329 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6119938 ns | 5545671.5 ns | 1.10 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6071009.5 ns | 5626649 ns | 1.08 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6526233 ns | 6521885 ns | 1.00 |

This comment was automatically generated by workflow using github-action-benchmark.
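The Ratio column is consistent with Current divided by Previous, so a value above 1 means the current commit took longer on that benchmark. As a rough sketch of how such rows could be screened, the snippet below recomputes the ratio for three rows copied from the results and flags those above a cutoff. The `flag_slowdowns` helper and the 1.5 threshold are hypothetical, chosen only for illustration; they are not part of github-action-benchmark.

```python
# (name, current_ns, previous_ns) copied from three rows of the results above.
ROWS = [
    ("Dense(20 => 20)/cpu/reverse/Zygote/(20, 128)", 34624.0, 13104.0),
    ("Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128)", 20498.0, 2609.3333333333335),
    ("Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128)", 33076614.0, 42589068.5),
]


def ratio(current_ns, previous_ns):
    """Current time over previous time; above 1.0 means the current run is slower."""
    return current_ns / previous_ns


def flag_slowdowns(rows, threshold=1.5):
    """Return (name, ratio rounded to 2 places) for rows whose ratio exceeds threshold.

    The threshold is a hypothetical cutoff for this sketch.
    """
    flagged = []
    for name, current_ns, previous_ns in rows:
        r = ratio(current_ns, previous_ns)
        if r > threshold:
            flagged.append((name, round(r, 2)))
    return flagged


slowdowns = flag_slowdowns(ROWS)
for name, r in slowdowns:
    print(f"{r:.2f}x  {name}")
```

With these three rows, the two Dense(20 => 20) entries exceed the cutoff and the Conv forward Flux entry, whose ratio is below 1, does not.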

@avik-pal avik-pal merged commit b6171a6 into main Aug 4, 2024
4 of 6 checks passed
@avik-pal avik-pal deleted the ap/faster_bmm branch August 4, 2024 23:02