vae_mnist increasing runtime and memory usage after each epoch #383

Open
MariusDrulea opened this issue Dec 27, 2022 · 5 comments · May be fixed by #408
MariusDrulea commented Dec 27, 2022

While training the vae_mnist example https://github.com/FluxML/model-zoo/blob/master/vision/vae_mnist/vae_mnist.jl, the runtime per epoch keeps increasing, from about 4 minutes for epoch 1 up to 1:22 hours by epoch 14 (see the log below). Each epoch should take roughly the same time.

Tested with Flux 0.13.10 on a Windows 10 machine.

[ Info: Training on GPU
[ Info: Start Training, total 20 epochs
[ Info: Epoch 1
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:04:35
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_1.png
[ Info: Epoch 2
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:10:44
  loss:  173.78094
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_2.png
[ Info: Epoch 3
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:15:41
  loss:  185.91678
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_3.png
[ Info: Epoch 4
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:21:34
  loss:  186.1274
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_4.png
[ Info: Epoch 5
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:28:25
  loss:  181.45102
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_5.png
[ Info: Epoch 6
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:36:31
  loss:  179.04253
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_6.png
[ Info: Epoch 7
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:43:55
  loss:  180.57823
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_7.png
[ Info: Epoch 8
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:52:03
  loss:  173.47208
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_8.png
[ Info: Epoch 9
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:59:13
  loss:  176.1585
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_9.png
[ Info: Epoch 10
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 1:02:29
  loss:  168.37003
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_10.png
[ Info: Epoch 11
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 1:04:24
  loss:  160.70734
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_11.png
[ Info: Epoch 12
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 1:10:29
  loss:  161.37071
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_12.png
[ Info: Epoch 13
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 1:16:34
  loss:  172.15349
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_13.png
[ Info: Epoch 14
Progress: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 1:22:42
  loss:  175.96423
[ Info: Image saved: D:\z_installed_programs\julia-depot\datadeps\runs\vae_mnist\epoch_14.png
[ Info: Epoch 15
MariusDrulea (Author) commented:
For debugging I made the following modifications:

  1. Use only the first 100 images of the training set: xtrain, ytrain = MLDatasets.MNIST(split=:train)[1:100]
  2. @time the pullback call:
@time loss, back = Flux.pullback(encoder, decoder) do enc, dec
    model_loss(enc, dec, args.λ, x_dev, device)
end

As you can see in the log below, each call to Flux.pullback allocates more memory than the previous one, and the runtime grows accordingly.

[ Info: Training on GPU
[ Info: Start Training, total 20 epochs
[ Info: Epoch 1
 60.692297 seconds (136.49 M allocations: 6.928 GiB, 3.03% gc time, 72.05% compilation time: 0% of which was recompilation)
  0.044060 seconds (13.73 k allocations: 641.859 KiB, 85.29% compilation time)
  0.052623 seconds (19.05 k allocations: 873.812 KiB, 84.56% compilation time)
  0.054776 seconds (24.38 k allocations: 1.074 MiB, 86.36% compilation time)
  0.064174 seconds (29.70 k allocations: 1.296 MiB, 87.24% compilation time)
  0.059060 seconds (35.02 k allocations: 1.520 MiB, 83.75% compilation time)
  0.060225 seconds (40.35 k allocations: 1.743 MiB, 83.00% compilation time)
  0.063138 seconds (45.67 k allocations: 1.966 MiB, 84.38% compilation time)
  0.059852 seconds (51.00 k allocations: 2.196 MiB, 82.92% compilation time)
  0.064041 seconds (56.32 k allocations: 2.419 MiB, 81.38% compilation time)
  0.060938 seconds (61.64 k allocations: 2.637 MiB, 82.00% compilation time)
  0.061926 seconds (66.97 k allocations: 2.860 MiB, 80.17% compilation time)
  0.078901 seconds (72.29 k allocations: 3.085 MiB, 80.99% compilation time)
  0.073115 seconds (77.61 k allocations: 3.307 MiB, 81.06% compilation time)
  0.073749 seconds (82.94 k allocations: 3.530 MiB, 80.76% compilation time)
  0.077940 seconds (88.26 k allocations: 3.754 MiB, 81.19% compilation time)
  0.092180 seconds (93.59 k allocations: 3.982 MiB, 79.64% compilation time)
  0.080315 seconds (98.91 k allocations: 4.201 MiB, 79.01% compilation time)
  0.082732 seconds (104.24 k allocations: 4.427 MiB, 79.66% compilation time)
  0.085479 seconds (109.56 k allocations: 4.647 MiB, 79.95% compilation time)
  0.082228 seconds (114.88 k allocations: 4.871 MiB, 79.05% compilation time)
  0.080974 seconds (120.21 k allocations: 5.094 MiB, 75.57% compilation time)
  0.082802 seconds (125.53 k allocations: 5.318 MiB, 78.26% compilation time)
  0.087348 seconds (130.85 k allocations: 5.541 MiB, 75.32% compilation time)
  0.089511 seconds (136.18 k allocations: 5.765 MiB, 75.46% compilation time)
 82.986604 seconds (190.15 M allocations: 9.660 GiB, 2.92% gc time, 72.29% compilation time: 0% of which was recompilation)
[ Info: Epoch 2
  0.088317 seconds (141.50 k allocations: 5.988 MiB, 74.94% compilation time)
  0.088266 seconds (146.83 k allocations: 6.211 MiB, 75.61% compilation time)
  0.095731 seconds (152.15 k allocations: 6.438 MiB, 76.10% compilation time)
  0.094834 seconds (157.47 k allocations: 6.658 MiB, 75.17% compilation time)
  0.095741 seconds (162.80 k allocations: 6.897 MiB, 74.25% compilation time)
  0.095184 seconds (168.12 k allocations: 7.105 MiB, 75.30% compilation time)
  0.100779 seconds (173.45 k allocations: 7.328 MiB, 73.37% compilation time)
  0.100133 seconds (178.77 k allocations: 7.552 MiB, 74.14% compilation time)
  0.098675 seconds (184.10 k allocations: 7.775 MiB, 72.11% compilation time)
  0.103099 seconds (189.42 k allocations: 7.998 MiB, 73.75% compilation time)
  0.109252 seconds (194.74 k allocations: 8.222 MiB, 74.28% compilation time)
  0.103832 seconds (200.07 k allocations: 8.453 MiB, 72.23% compilation time)
  0.109997 seconds (205.39 k allocations: 8.669 MiB, 70.23% compilation time)
  0.111814 seconds (210.72 k allocations: 8.892 MiB, 73.89% compilation time)
  0.110144 seconds (216.04 k allocations: 9.116 MiB, 71.72% compilation time)
  0.111723 seconds (221.36 k allocations: 9.339 MiB, 73.25% compilation time)
  0.113262 seconds (226.69 k allocations: 9.568 MiB, 72.66% compilation time)
  0.313881 seconds (266.02 k allocations: 11.329 MiB, 19.59% gc time, 33.01% compilation time)
  0.112271 seconds (237.34 k allocations: 10.009 MiB, 72.88% compilation time)
  0.123936 seconds (242.66 k allocations: 10.232 MiB, 71.37% compilation time)
  0.112173 seconds (247.98 k allocations: 10.456 MiB, 71.79% compilation time)
  0.119971 seconds (253.31 k allocations: 10.679 MiB, 71.44% compilation time)
  0.117311 seconds (258.63 k allocations: 10.903 MiB, 71.51% compilation time)
  0.120341 seconds (263.96 k allocations: 11.126 MiB, 72.03% compilation time)
  0.121589 seconds (269.28 k allocations: 11.349 MiB, 70.35% compilation time)
  4.103159 seconds (5.29 M allocations: 226.116 MiB, 1.98% gc time, 63.50% compilation time)
[ Info: Epoch 3
  0.127100 seconds (274.60 k allocations: 11.573 MiB, 71.77% compilation time)
  0.126540 seconds (279.93 k allocations: 11.797 MiB, 70.30% compilation time)
  0.135527 seconds (285.25 k allocations: 12.020 MiB, 69.83% compilation time)
  0.132130 seconds (290.57 k allocations: 12.243 MiB, 70.75% compilation time)
  0.137595 seconds (295.90 k allocations: 12.467 MiB, 71.55% compilation time)
  0.131641 seconds (301.22 k allocations: 12.690 MiB, 71.53% compilation time)
  0.134207 seconds (306.55 k allocations: 12.913 MiB, 70.09% compilation time)
  0.143800 seconds (311.87 k allocations: 13.137 MiB, 70.83% compilation time)
  0.163375 seconds (317.19 k allocations: 13.360 MiB, 69.90% compilation time)
  0.247659 seconds (326.38 k allocations: 13.701 MiB, 21.07% gc time, 39.36% compilation time)
  0.139439 seconds (327.85 k allocations: 13.831 MiB, 68.89% compilation time)
  0.135705 seconds (333.17 k allocations: 14.038 MiB, 68.87% compilation time)
  0.136081 seconds (338.49 k allocations: 14.254 MiB, 70.95% compilation time)
  0.137029 seconds (343.82 k allocations: 14.477 MiB, 69.47% compilation time)
  0.178940 seconds (349.14 k allocations: 14.701 MiB, 67.51% compilation time)
  0.176940 seconds (354.46 k allocations: 14.924 MiB, 67.67% compilation time)
  0.199261 seconds (359.79 k allocations: 15.148 MiB, 68.39% compilation time)
  0.194970 seconds (365.11 k allocations: 15.383 MiB, 67.80% compilation time)
  0.184447 seconds (370.44 k allocations: 15.594 MiB, 66.97% compilation time)
  0.201620 seconds (375.76 k allocations: 15.818 MiB, 69.27% compilation time)
  0.201842 seconds (381.08 k allocations: 16.041 MiB, 66.59% compilation time)
  0.206759 seconds (386.41 k allocations: 16.264 MiB, 67.04% compilation time)
  0.212716 seconds (391.73 k allocations: 16.488 MiB, 66.91% compilation time)
  0.208171 seconds (397.06 k allocations: 16.711 MiB, 67.83% compilation time)
  0.298336 seconds (415.48 k allocations: 17.529 MiB, 18.96% gc time, 63.92% compilation time)
  5.591896 seconds (8.67 M allocations: 368.380 MiB, 2.37% gc time, 64.12% compilation time)

MariusDrulea (Author) commented Dec 27, 2022

After some debugging, it appears that the L2 regularization in the model_loss function causes this issue:
https://github.com/FluxML/model-zoo/blob/master/vision/vae_mnist/vae_mnist.jl#L64

If I comment out the regularization (simply set reg = 0), the issue no longer appears.
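
A minimal sketch of the two variants, with a hypothetical stand-in decoder (the real model_loss in vae_mnist.jl also computes reconstruction and KL terms, elided here):

using Flux

decoder = Chain(Dense(2 => 4, tanh), Dense(4 => 2)) # stand-in, not the example's decoder
λ = 0.01f0

# Original penalty: collects Flux.params(decoder) on every loss evaluation.
reg = λ * sum(p -> sum(p .^ 2), Flux.params(decoder))

# Workaround: disable the penalty entirely.
reg = 0f0

With the penalty disabled, the per-batch timings and allocations stay flat: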

[ Info: Training on GPU
[ Info: Start Training, total 20 epochs
[ Info: Epoch 1
  0.292209 seconds (285.14 k allocations: 13.223 MiB, 98.55% compilation time: 48% of which was recompilation)
  0.003073 seconds (1.27 k allocations: 79.484 KiB)
  0.002721 seconds (1.27 k allocations: 79.484 KiB)
  0.002822 seconds (1.27 k allocations: 79.484 KiB)
  0.002602 seconds (1.27 k allocations: 79.484 KiB)
  0.003065 seconds (1.27 k allocations: 79.484 KiB)
  0.002056 seconds (1.27 k allocations: 79.484 KiB)
  0.002566 seconds (1.27 k allocations: 79.484 KiB)
  0.002887 seconds (1.27 k allocations: 79.484 KiB)
  0.002360 seconds (1.27 k allocations: 79.484 KiB)
  0.002748 seconds (1.27 k allocations: 79.484 KiB)
  0.002835 seconds (1.27 k allocations: 79.484 KiB)
  0.003085 seconds (1.27 k allocations: 79.484 KiB)
  0.002624 seconds (1.27 k allocations: 79.484 KiB)
  0.002892 seconds (1.27 k allocations: 79.484 KiB)
  0.002751 seconds (1.27 k allocations: 79.484 KiB)
  0.002232 seconds (1.27 k allocations: 79.484 KiB)
  0.002663 seconds (1.27 k allocations: 79.484 KiB)
  0.004064 seconds (1.27 k allocations: 79.484 KiB)
  0.002424 seconds (1.27 k allocations: 79.484 KiB)
  0.003199 seconds (1.27 k allocations: 79.484 KiB)
  0.002628 seconds (1.27 k allocations: 79.484 KiB)
  0.002049 seconds (1.27 k allocations: 79.484 KiB)
  0.002289 seconds (1.27 k allocations: 79.484 KiB)
  0.002424 seconds (1.27 k allocations: 79.484 KiB)
  0.722162 seconds (746.61 k allocations: 38.258 MiB, 60.29% compilation time: 46% of which was recompilation)
[ Info: Epoch 2
  0.002194 seconds (1.27 k allocations: 79.484 KiB)
  0.003316 seconds (1.27 k allocations: 79.484 KiB)
  0.003163 seconds (1.27 k allocations: 79.484 KiB)
  0.002398 seconds (1.27 k allocations: 79.484 KiB)
  0.002068 seconds (1.27 k allocations: 79.484 KiB)
  0.002256 seconds (1.27 k allocations: 79.484 KiB)
  0.002365 seconds (1.27 k allocations: 79.484 KiB)
  0.003138 seconds (1.27 k allocations: 79.484 KiB)
  0.002814 seconds (1.27 k allocations: 79.484 KiB)
  0.002353 seconds (1.27 k allocations: 79.484 KiB)
  0.002757 seconds (1.27 k allocations: 79.484 KiB)
  0.002572 seconds (1.27 k allocations: 79.484 KiB)
  0.001917 seconds (1.27 k allocations: 79.484 KiB)
  0.003073 seconds (1.27 k allocations: 79.484 KiB)
  0.002455 seconds (1.27 k allocations: 79.484 KiB)
  0.002445 seconds (1.27 k allocations: 79.484 KiB)
  0.002638 seconds (1.27 k allocations: 79.484 KiB)
  0.003332 seconds (1.27 k allocations: 79.484 KiB)
  0.002055 seconds (1.27 k allocations: 79.484 KiB)
  0.003426 seconds (1.27 k allocations: 79.484 KiB)
  0.001962 seconds (1.27 k allocations: 79.484 KiB)
  0.002117 seconds (1.27 k allocations: 79.484 KiB)
  0.003524 seconds (1.27 k allocations: 79.484 KiB)
  0.001779 seconds (1.27 k allocations: 79.484 KiB)
  0.002332 seconds (1.27 k allocations: 79.484 KiB)
  0.288302 seconds (120.02 k allocations: 8.372 MiB)

ToucheSir (Member) commented:
What concerns me more is that there seems to be compilation happening on every batch? FluxML/Flux.jl#2040 used more or less the same code as the old MNIST VAE model, so I don't know what could be the culprit.

For the regularization part, this is a good opportunity to remove it from the loss function and use the WeightDecay optimization rule instead. That should simultaneously remove the last use of implicit params from this model and make everything a little faster to boot.
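
A minimal sketch of that suggestion in explicit-gradient style, assuming Flux ≥ 0.14 (where Adam, WeightDecay and OptimiserChain are re-exported from Optimisers.jl); the decay constant and the Dense stand-in are illustrative, not the model-zoo values:

using Flux

model = Dense(2 => 2)                 # stand-in for the VAE's decoder
xt = rand(Float32, 2, 4)
yt = rand(Float32, 2, 4)

# WeightDecay(γ) adds γ .* p to each parameter's gradient, matching an L2
# penalty of (γ/2) * sum(abs2, p) in the loss, so the loss needs no reg term.
opt_state = Flux.setup(OptimiserChain(WeightDecay(1f-2), Adam(1f-3)), model)

loss(m, x, y) = sum(abs2, m(x) .- y)  # plain loss, no explicit regularization

∇m = Flux.gradient(m -> loss(m, xt, yt), model)[1]
Flux.update!(opt_state, model, ∇m)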

MariusDrulea (Author) commented Dec 27, 2022

Indeed, the issue is in this line: reg = λ * sum(x->sum(x.^2), Flux.params(decoder))

See below an MWE for the issue. loss_fun_slow calls Flux.params(model) on each iteration and reproduces the problem; loss_fun reuses the same ps on every iteration. The gradient values come out the same, so loss_fun presumably also applies the regularization. Is loss_fun guaranteed to always apply the regularization?

using Flux

model = Dense(2, 2)

xt = rand(Float32, 2, 4) # batch size of 4
yt = rand(Float32, 2, 4)

# Parameters collected once, outside the loss.
ps = Flux.params(model)
loss_fun(m, x, y) = sum((m(x) - y).^2) + sum(p->sum(p.^2), ps)

# Collects Flux.params(model) afresh on every call, inside the differentiated code.
loss_fun_slow(m, x, y) = sum((m(x) - y).^2) + sum(p->sum(p.^2), Flux.params(model))

for i in 1:5
    @time ∇m = gradient(m->loss_fun(m, xt, yt), model)
    @time ∇m_slow = gradient(m->loss_fun_slow(m, xt, yt), model)
end

Here is the log:

 21.642858 seconds (60.17 M allocations: 3.051 GiB, 4.18% gc time, 99.96% compilation time)
  4.606277 seconds (11.97 M allocations: 591.543 MiB, 2.79% gc time, 99.80% compilation time)
  0.000475 seconds (528 allocations: 26.016 KiB)
  0.018889 seconds (4.08 k allocations: 208.984 KiB, 92.57% compilation time)
  0.000291 seconds (528 allocations: 26.016 KiB)
  0.018429 seconds (5.18 k allocations: 256.828 KiB, 92.63% compilation time)
  0.000232 seconds (528 allocations: 26.016 KiB)
  0.017229 seconds (6.28 k allocations: 303.641 KiB, 91.83% compilation time)
  0.000310 seconds (528 allocations: 26.016 KiB)
  0.017961 seconds (7.38 k allocations: 350.672 KiB, 91.94% compilation time)
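
For what it's worth, one way to sidestep the question is to regularize through the explicit model argument, so nothing inside the differentiated code calls Flux.params. A minimal sketch (l2 here is a hypothetical helper, not Flux API):

using Flux

model = Dense(2, 2)
xt = rand(Float32, 2, 4)
yt = rand(Float32, 2, 4)

# Reach the parameters through the argument m itself, so the penalty is
# differentiated with respect to the same object the gradient is taken over.
l2(m::Dense) = sum(abs2, m.weight) + sum(abs2, m.bias)

loss_explicit(m, x, y) = sum(abs2, m(x) .- y) + l2(m)

∇m = gradient(m -> loss_explicit(m, xt, yt), model)[1]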

lucifer1702 commented:
Hello everyone,
on the advice of other mentors I would like to start contributing to FluxML. I want to contribute to this issue; could you tell me where to start and which technologies, besides Julia, I should be familiar with?

mcabbott linked a pull request on Mar 21, 2024 that will close this issue.