Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improvements to the Training API #794

Merged
merged 7 commits into from
Jul 26, 2024
Merged

feat: improvements to the Training API #794

merged 7 commits into from
Jul 26, 2024

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Jul 26, 2024

Overview

  • Semvar guarantees for the training functionality 🎉
  • Drops support for ancient ADTypes

Tasks

  • move training functions into Training module
  • deprecate accesses via Experimental
  • update the training overloads in extensions
  • update documentation
  • update the tests do this in Aggregate changes for v1 #744 to avoid accidentally breaking anything

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 9e262df Previous: 6a4453d Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 4056.25 ns 3733.6666666666665 ns 1.09
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7515.5 ns 7376.714285714285 ns 1.02
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20688 ns 20980 ns 0.99
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 10012.6 ns 9760 ns 1.03
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9184.5 ns 8970.7 ns 1.02
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4432 ns 4422.666666666667 ns 1.00
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 4597.375 ns 4676.25 ns 0.98
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1113.6853146853148 ns 1107.1069182389938 ns 1.01
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1189.2720588235295 ns 1164.4285714285713 ns 1.02
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1838.4166666666667 ns 1789.4313725490197 ns 1.03
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.16806722689077 ns 179.6459802538787 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 18795 ns 17112 ns 1.10
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17032 ns 16811 ns 1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37771 ns 37100 ns 1.02
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 30497 ns 28213 ns 1.08
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 21900 ns 19947 ns 1.10
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17473 ns 17192 ns 1.02
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 25537 ns 25488 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3921 ns 3822.25 ns 1.03
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3975 ns 3913.625 ns 1.02
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 5075.214285714286 ns 4784.714285714285 ns 1.06
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1658 ns 1651.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 42715187 ns 39070284 ns 1.09
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 59252224 ns 58211891 ns 1.02
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 76605959 ns 77724823 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 92393949 ns 89555817 ns 1.03
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 76241101.5 ns 88370701 ns 0.86
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12219173 ns 11594550 ns 1.05
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 92551710 ns 91959934 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7755123 ns 7684647 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7684995 ns 7572126.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 12450342 ns 9887620.5 ns 1.26
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6429271 ns 6379262 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 708450329 ns 680001467 ns 1.04
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2571225675 ns 2574834317 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 150479252 ns 133556588.5 ns 1.13
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 847523054 ns 832059382 ns 1.02
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3377075825 ns 2940015627 ns 1.15
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 238610051 ns 219247861 ns 1.09
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 805012617 ns 712943058.5 ns 1.13
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 3019494444 ns 2615778063 ns 1.15
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 151177313.5 ns 129342095 ns 1.17
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 176370364 ns 175664907 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 662307484 ns 664305664.5 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 34934788 ns 45501101 ns 0.77
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 167755377 ns 165819611 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 655269720 ns 651918836 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30550807 ns 30062079 ns 1.02
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 214915437 ns 186124645.5 ns 1.15
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 808581895 ns 724968189.5 ns 1.12
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 36213817 ns 36094692 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1332137811.5 ns 1304162516.5 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1882466726 ns 1868331598 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2590300918 ns 2245902418 ns 1.15
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2667258422 ns 2495872396 ns 1.07
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 2027863164 ns 1950101995.5 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 2060681474 ns 2157665806 ns 0.96
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 339635505 ns 330073872 ns 1.03
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 348360474 ns 327700026 ns 1.06
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 412072993.5 ns 458706730 ns 0.90
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 12069816 ns 11936295 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18294213 ns 17958576 ns 1.02
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19673974 ns 19018839.5 ns 1.03
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 24190035 ns 23737989 ns 1.02
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18112221 ns 17752292.5 ns 1.02
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1173037.5 ns 1158352 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 23393449.5 ns 22921221 ns 1.02
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2435332.5 ns 2438307 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2253434 ns 2207289 ns 1.02
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2080020 ns 2059552 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 205463 ns 193042 ns 1.06
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 295652 ns 290375 ns 1.02
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 267123.5 ns 264375 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 370346 ns 365440 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 410906 ns 409918 ns 1.00
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 277397.5 ns 280316 ns 0.99
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 406659 ns 405440 ns 1.00
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 396489 ns 396112 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81452 ns 81703 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 82483 ns 83426 ns 0.99
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87784 ns 88967 ns 0.99
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104405 ns 104306 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 195005827 ns 199911623 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 330439675 ns 330582216.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 451522213 ns 431440040 ns 1.05
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 447192572.5 ns 483202494 ns 0.93
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 412209783 ns 390600061 ns 1.06
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 335635213 ns 329500710.5 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 450442113 ns 476877423.5 ns 0.94
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 47402593 ns 47470381 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 46935602 ns 46946569 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 58767138.5 ns 60012110 ns 0.98
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28604401 ns 27667373.5 ns 1.03
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19596106.5 ns 19236897 ns 1.02
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19778809 ns 19653638 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23973413.5 ns 23866747 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24534127 ns 24431304 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19865420 ns 19723207.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 21243792 ns 21004501 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6606106 ns 6534005 ns 1.01
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6620508 ns 6514002 ns 1.02
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6728845 ns 6501944.5 ns 1.03

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal avik-pal merged commit 2a55829 into main Jul 26, 2024
60 of 64 checks passed
@avik-pal avik-pal deleted the ap/training branch July 26, 2024 17:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant