[WIP] Move CUDA support to a package extension #2132
Conversation
(force-pushed from c93aa3b to bf3d95f)
Excellent! We should still maintain …
One question: I structured this so that you need to load both … Alternatively …
I prefer the latter since it enables a nicer user experience. Does the added precompilation cost a tonne of time?
Ah, CUDA is a dep of NNlibCUDA, so it would have to be NNlibCUDA if there is only one trigger package. Updated above now.
Is there a way we could have CUDA.jl as the only trigger dep? Up to this point we've been acting on the model that users shouldn't have to know NNlibCUDA even exists, so if possible it would be great to keep that.
@KristofferC would that be possible? NNlibCUDA depends on CUDA, so AFAIK it's not possible.
If needed, any and all changes required on the NNlibCUDA side should be on the table. Would making its CUDA dependency also weak be enough?
Maybe absorb it into NNlib and make CUDA a weak dep of NNlib?
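For concreteness (this sketch is not taken from the PR), making CUDA a weak dependency of NNlib would roughly mean declaring it like this in NNlib's Project.toml, with the extension loaded automatically once both NNlib and CUDA are in the session; the extension name is illustrative and the UUID should be checked against the General registry:

```toml
[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
# NNlibCUDAExt is an illustrative name; the extension loads only when CUDA.jl is also loaded
NNlibCUDAExt = "CUDA"
```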
Note that Registrator is not yet updated to handle weak deps.
Two suggestions, although the first one is likely unrealizable: …
I haven't thought through the implications of this design, but yes, it should be possible via something like …
Wouldn't that incur world-age issues? That is the limitation of the (not identical) approach in LazyModules.jl.
It would be good to have a concrete usage example here, like a toy REPL session showing how a user would use it. Based on that, it might be easier to come up with a good design.
Ideally, the following could happen, which is also non-breaking:

```julia
julia> using Flux  # only cpu code imported, CUDA.jl and NNlibCUDA.jl are not loaded yet

julia> x = [1.0, 2.0, 3.0]
3-element Vector{Float64}:
 1.0
 2.0
 3.0

julia> gpu(x)  # here the gpu code and libraries are loaded if the device has cuda capabilities
3-element CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
 1.0
 2.0
 3.0

# successive calls to `gpu` are fast
```
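As a purely illustrative sketch (not code from this PR), the lazy-load-on-first-use idea could look roughly like this; `_cuda_loaded` is a hypothetical flag, and `Base.invokelatest` is one way to sidestep the world-age concern raised above:

```julia
const _cuda_loaded = Ref(false)  # hypothetical module-level flag

function gpu(x::AbstractArray)
    if !_cuda_loaded[]
        # Load the already-installed CUDA.jl the first time `gpu` is called.
        @eval Main import CUDA
        _cuda_loaded[] = true
    end
    cuda = getfield(Main, :CUDA)
    # invokelatest runs `cu` in the latest world, avoiding the world-age issue.
    return Base.invokelatest(cuda.cu, x)
end
```

Whether something like this is desirable is exactly what the thread is weighing: the runtime load and world-age workarounds are the cost of the "magic" behaviour.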
Even something like the current behaviour of warning and passing the value through when CUDA isn't loaded could work, IMO.
This is theoretically possible and something we always planned to do, but nobody expected weak deps to land this quickly. The main considerations would be whether a) base NNlib and Flux load times are affected (with or without CUDA loaded), and b) TTFX for CUDA functionality in NNlib is impacted. If neither regresses, then it should be an easy sell.
This seems like the obvious goal. Right now, with no graphics card, I don't like the idea of a function call magically loading the CUDA package. Making behaviour depend on your environment, outside of what you've loaded, seems weird.
Tests are passing now, btw.
Locally, it looks like I can have just …
Did you remove NNlibCUDA from [weakdeps]?
Yes, but I don't know if it matters. It must remain in [deps] for now, I think.
Well, that means it's not a weak dep, so it and CUDA will be downloaded, installed and precompiled. You do avoid loading it until the extension loads, but personally I'm motivated by avoiding both installing and loading CUDA etc.
For more context, using Flux on a resource-constrained system like a small embedded SBC without a GPU is painful because of all of the above, not just load time.
So a package listed under both [deps] and [weakdeps] (as in the present state of this PR) is effectively deleted from [deps]. I'm a bit surprised that the design doesn't allow for loading third packages, but maybe this is hard; I didn't follow closely. Not downloading would be nice too. But requiring anyone to know about this obscure thing called NNlibCUDA seems like a step backwards. Moving it to be an extension of NNlib is probably the way to go, then.
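To make the mechanics above concrete, a hedged sketch of the both-lists arrangement being described (placeholder UUID; not copied from the PR):

```toml
[deps]
NNlibCUDA = "<registered UUID>"   # installed and loaded as a normal dep on Julia < 1.9

[weakdeps]
NNlibCUDA = "<registered UUID>"   # on Julia 1.9+ this entry takes precedence, so the dep is effectively removed from [deps]
```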
Having NNlibCUDA as an extension of NNlib seems the way to go. Also, we finally move that code back to the original repo, as we wanted to do for some time. Since …

```julia
# In NNlib.jl
if !isdefined(Base, :get_extension)
    # do nothing
end
```

and continue to import …
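As a point of reference, the common backwards-compatibility pattern pairs that `isdefined` check with Requires.jl so pre-1.9 Julia can still load the same extension code. A minimal sketch, with the file path illustrative rather than taken from this PR (the UUID is CUDA.jl's registered one, worth double-checking):

```julia
# In NNlib.jl
if !isdefined(Base, :get_extension)
    using Requires
    function __init__()
        # On Julia < 1.9, load the extension code by hand once CUDA.jl is imported.
        @require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include("../ext/NNlibCUDAExt.jl")
    end
end
```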
So the proposal is to have NNlibCUDA under both [deps] and [weakdeps] here, load it unconditionally on Julia 1.6, but not try to load it at all on 1.10? Then give NNlib an extension for CUDA, which has another copy of the exact same code as the package NNlibCUDA? If the registered package is moved to the same repository as NNlib, then both the package and NNlib's CUDAext can … One issue is that any non-Flux project which loads both NNlib and NNlibCUDA will, I think, get two copies of all definitions.
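One way the "same code in two places" concern could be avoided, sketched under the assumption that NNlibCUDA's sources end up in the NNlib repository (module name, paths, and file names below are hypothetical):

```julia
# ext/NNlibCUDAExt.jl
module NNlibCUDAExt

using NNlib, CUDA

# Rather than keeping a second copy of the GPU code, the extension could
# include the same source files the standalone NNlibCUDA package uses, so
# their method definitions extend NNlib directly.
for f in ("activations.jl", "batchedmul.jl")   # hypothetical file list
    include(joinpath(@__DIR__, "..", "lib", "NNlibCUDA", "src", f))
end

end # module
```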
FluxML/NNlib.jl#445 is now up for anyone who wants to kick the tires with this. I haven't tested import times or TTFX, so any data would be appreciated.
BTW, timing this (on a slow machine which has a GPU, Julia nightly):

```julia
julia> @time using Flux
 21.362717 seconds (24.70 M allocations: 1.449 GiB, 5.16% gc time, 46.46% compilation time: 61% of which was recompilation) # before
  9.096614 seconds (5.34 M allocations: 373.462 MiB, 4.08% gc time, 81.04% compilation time: 56% of which was recompilation) # after
```

or, loading everything:

```julia
julia> @time using Flux, NNlibCUDA, CUDA
 21.134435 seconds (24.76 M allocations: 1.453 GiB, 3.64% gc time, 47.75% compilation time: 62% of which was recompilation) # after
```

After this, the biggest offenders in …

Before:

```julia
julia> @time_imports using Flux
1.2 ms Statistics
16.7 ms MacroTools
0.2 ms Reexport
6.8 ms ProgressLogging
8.7 ms IrrationalConstants
0.2 ms Compat
138.7 ms ChainRulesCore
8.1 ms DocStringExtensions 80.65% compilation time
1.1 ms ChangesOfVariables
1.4 ms InverseFunctions
1.2 ms LogExpFunctions
0.2 ms OpenLibm_jll
31.0 ms Preferences
0.3 ms JLLWrappers
461.8 ms OpenSpecFun_jll 115.01% compilation time (86% recompilation)
44.9 ms SpecialFunctions
0.3 ms Requires
0.3 ms Adapt
56.2 ms NNlib 62.75% compilation time (12% recompilation)
13.8 ms ShowCases
1.8 ms ConstructionBase
180.8 ms InitialValues
0.2 ms DataValueInterfaces
1.3 ms DataAPI
0.1 ms IteratorInterfaceExtensions
0.1 ms TableTraits
13.4 ms OrderedCollections
25.2 ms Tables
0.2 ms ZygoteRules
4.7 ms StaticArraysCore
19.4 ms Setfield
87.5 ms BangBang 58.56% compilation time
1.6 ms ContextVariablesX
0.1 ms FLoopsBase
1.3 ms PrettyPrint
0.5 ms NameResolution
31.8 ms MLStyle
2.4 ms JuliaVariables
0.5 ms ArgCheck
22.5 ms Baselet
0.1 ms CompositionsBase
0.1 ms DefineSingletons
17.1 ms MicroCollections
16.4 ms SplittablesBase
116.1 ms Transducers 36.62% compilation time
7.6 ms FLoops
41.7 ms Accessors 39.28% compilation time
22.2 ms FunctionWrappers
1017.0 ms FoldsThreads 193.39% compilation time
161.1 ms DataStructures
0.5 ms SortingAlgorithms
73.7 ms Missings
0.4 ms StatsAPI
51.0 ms StatsBase
5.5 ms SimpleTraits
1.0 ms DelimitedFiles
11.3 ms MLUtils
5.0 ms Functors
16.4 ms Optimisers
2.4 ms GPUArraysCore
0.1 ms RealDot
46.6 ms StructArrays
37.5 ms ChainRules
20.1 ms IRTools
0.8 ms DiffRules
0.3 ms NaNMath
420.6 ms FillArrays
20.7 ms AbstractFFTs
4.2 ms DiffResults
1271.6 ms StaticArrays
0.5 ms CommonSubexpressions
198.6 ms ForwardDiff
174.3 ms Zygote 52.62% compilation time
8.9 ms CEnum
474.3 ms LLVMExtra_jll 117.69% compilation time (84% recompilation)
406.2 ms LLVM 35.54% compilation time
0.4 ms ExprTools
180.9 ms TimerOutputs 19.49% compilation time
823.8 ms GPUCompiler 22.62% compilation time (68% recompilation)
1797.9 ms GPUArrays
13.6 ms BFloat16s
47.7 ms RandomNumbers 37.30% compilation time
12.4 ms Random123
6794.8 ms CUDA 2.02% compilation time
84.9 ms OneHotArrays
136.9 ms NNlibCUDA
100.9 ms Flux
```

After:

```julia
julia> @time_imports using Flux
1.2 ms Statistics
41.7 ms MacroTools
0.2 ms Reexport
6.7 ms ProgressLogging
8.5 ms IrrationalConstants
0.1 ms Compat
136.0 ms ChainRulesCore
8.2 ms DocStringExtensions 82.53% compilation time
1.1 ms ChangesOfVariables
1.3 ms InverseFunctions
1.1 ms LogExpFunctions
0.2 ms OpenLibm_jll
29.5 ms Preferences
0.3 ms JLLWrappers
453.6 ms OpenSpecFun_jll 113.71% compilation time (87% recompilation)
31.0 ms SpecialFunctions
0.3 ms Requires
0.3 ms Adapt
54.2 ms NNlib 66.10% compilation time (12% recompilation)
13.3 ms ShowCases
1.8 ms ConstructionBase
31.1 ms InitialValues
0.1 ms DataValueInterfaces
1.3 ms DataAPI
0.1 ms IteratorInterfaceExtensions
0.1 ms TableTraits
14.2 ms OrderedCollections
25.8 ms Tables
0.2 ms ZygoteRules
4.9 ms StaticArraysCore
20.1 ms Setfield
89.9 ms BangBang 57.72% compilation time
1.6 ms ContextVariablesX
0.1 ms FLoopsBase
1.3 ms PrettyPrint
0.4 ms NameResolution
33.0 ms MLStyle
2.5 ms JuliaVariables
0.5 ms ArgCheck
23.0 ms Baselet
0.1 ms CompositionsBase
0.1 ms DefineSingletons
17.6 ms MicroCollections
17.3 ms SplittablesBase
117.3 ms Transducers 36.25% compilation time
7.8 ms FLoops
42.9 ms Accessors 40.04% compilation time
22.9 ms FunctionWrappers
1079.6 ms FoldsThreads 178.99% compilation time
158.7 ms DataStructures
0.5 ms SortingAlgorithms
17.1 ms Missings
0.4 ms StatsAPI
53.0 ms StatsBase
5.7 ms SimpleTraits
1.1 ms DelimitedFiles
11.5 ms MLUtils
4.9 ms Functors
16.3 ms Optimisers
2.4 ms GPUArraysCore
0.1 ms RealDot
48.5 ms StructArrays
44.9 ms ChainRules
29.6 ms IRTools
0.9 ms DiffRules
0.3 ms NaNMath
375.0 ms FillArrays
19.0 ms AbstractFFTs
4.5 ms DiffResults
1331.8 ms StaticArrays
0.6 ms CommonSubexpressions
217.7 ms ForwardDiff
222.6 ms Zygote 42.27% compilation time
78.4 ms OneHotArrays
58.7 ms Flux
```
In the scenario where we're unable to resolve the import-time issues in FluxML/NNlib.jl#445, I feel the incremental solution would be to keep NNlibCUDA as a normal dep rather than a weak dep. Precompilation of CUDA.jl will still be unavoidable, but at least SBC users can still benefit from the reduced import times (alongside the much larger group of non-SBC users on non-GPU machines).
I was just wondering if the plan here has settled? It would be great to not have to install the GPU stack on the SBCs that I use!
Then as now, I don't foresee any path for us to save users from installing the CUDA.jl stack before dropping support for Julia versions < 1.9. That said, NNlibCUDA is likely to become an extension in the near future. The current hold-up there is figuring out what the fallback path would look like on the Flux side. I've seen some people using trigger packages, but word from above suggests those are verboten... For this particular PR, my recommendation would be to remove anything related to NNlibCUDA and convert the Flux-specific CUDA bits to an extension like we have for AMDGPU. It won't stop the GPU stack from being installed, but it'll get us one step closer.
I just tried to rebase this and it has become pretty messy, unfortunately. Is CUDA lazy-loaded now? Also, my main motivation was to avoid the CUDA install and …
Thanks for pushing this, @IanButterworth! I think now that FluxML/NNlib.jl#492 is merged and #2265 has some active effort behind it, we might be able to release something sooner rather than later.
Uses the new package extensions (due to land in Julia 1.9) to make CUDA functionality optional.
On Julia versions that don't have package extensions, Flux will work as before.
I've only done the obvious things here to get loading working and haven't run tests locally.
I took the approach of keeping the file names of where functions came from, to make it easier to cross-reference. It could all be consolidated if preferred.
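For illustration only, a rough sketch of how loading could look on Julia 1.9+ under this approach, assuming NNlibCUDA acts as the trigger package as discussed in the conversation above (not an exact transcript from this PR):

```julia
julia> using Flux        # CPU-only code; the CUDA extension is not loaded

julia> using NNlibCUDA   # loading the trigger package activates Flux's CUDA extension

julia> gpu(rand(Float32, 3))   # on a CUDA-capable machine this now returns a CuArray
```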
updated: …
For reference, Flux master on Julia master: …
Helps with #1961
cc. @mcabbott
[Edit: closes #2155]
PR Checklist