Add Adapt.adapt_structure method for Optimisers.Leaf #180
Conversation
Adapt.adapt_storage method for Optimisers.Leaf
Adapt.adapt_structure method for Optimisers.Leaf
I think this won't preserve identity? That's the big difference between Functors and Adapt.
@mcabbott can you explain? I don't know what identity means here. @CarloLucibello, addressing your comment from #179:
That is true, the MWE works with MLDataDevices. However, we still need Adapt functionality. Consider the case when Leaf is stored as part of a struct:

using Optimisers, CUDA, LuxCUDA, MLDataDevices, Adapt

struct TrainState{Tp, To}
    p::Tp
    opt_st::To
end
Adapt.@adapt_structure TrainState

p = rand(2)
opt_st = Optimisers.setup(Optimisers.Adam(), p)
ts = TrainState(p, opt_st)
device = gpu_device()
device(ts).opt_st.state[1]

2-element Vector{Float64}:
 0.0
 0.0

The optimiser state stays on the CPU, so there is a need to define Adapt.adapt_structure for Leaf.
Functors keeps an IdDict so that if the same array appears twice, this property is preserved by fmap. Optimisers.jl follows that too, and will use (and expect, IIRC) the same Leaf in such cases. So I don't see an easy way to make this work through Adapt.
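A minimal sketch of that identity property, assuming a hypothetical toy model in which the same array is used twice (names are illustrative only):

using Optimisers, Functors

w = rand(2)
model = (a = w, b = w)                         # the same array appears twice ("tied")
state = Optimisers.setup(Optimisers.Adam(), model)
state.a === state.b                            # expected true: one shared Leaf for the tied array

tofloat32(x) = x isa AbstractArray ? Float32.(x) : x
moved = fmap(tofloat32, state)                 # one fmap call, so one IdDict cache for the whole walk
moved.a === moved.b                            # expected true: the sharing survives the traversal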
Would something like this solve the problem?

function Adapt.adapt_storage(to, leaf::Leaf)
    return fmap(x -> Adapt.adapt_storage(to, x), leaf)
end
It would not, because the IdDict needs to be shared between all the calls, i.e. across the whole tree being moved.
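A hedged sketch of that failure mode, assuming the same toy tied-parameter state as above (illustrative only, not this PR's code):

using Optimisers, Functors

w = rand(2)
state = Optimisers.setup(Optimisers.Adam(), (a = w, b = w))   # tied array, one shared Leaf

tofloat32(x) = x isa AbstractArray ? Float32.(x) : x
# Adapting each Leaf on its own runs a separate traversal with a fresh cache,
# so nothing records that the two fields held the very same Leaf:
per_leaf = map(leaf -> fmap(tofloat32, leaf), state)
per_leaf.a === per_leaf.b                                     # expected false: the tie is lost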
Maybe it's possible to grab the IdDict and bind it to the new Leaf object? Where is it defined? BTW the fix in this PR
That's because your model doesn't have any shared/"tied" parameters.
It's created at the top level in |
based on discussion: adapt_structure method does not maintain IdDict handled by functors. So we add a warning referring the user to Flux.gpu or MLDataDevices.gpu_device()
@ToucheSir, thanks for explaining. I added a warning to the Adapt.adapt_structure method. The behavior is as follows:

julia> using Optimisers, CUDA, LuxCUDA
julia> opt_st = Optimisers.setup(Optimisers.Adam(), zeros(2))
Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999)))
julia> cu(opt_st)
┌ Warning: `Optimisers.Leaf` object does not support device transfer via
│ `Adapt.jl`. Avoid this by calling `Flux.gpu/cpu` or
│ `MLDataDevices.cpu_device()/gpu_device()` on the optimiser state object.
│ See below GitHub issue for more details.
│ https://github.com/FluxML/Optimisers.jl/issues/179
└ @ OptimisersAdaptExt ~/.julia/dev/Optimisers.jl/ext/OptimisersAdaptExt.jl:7
Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))
julia> cu(opt_st).state[1] |> typeof
CuArray{Float32, 1, CUDA.DeviceMemory}
julia> using MLDataDevices
julia> gpu_device()(opt_st)
Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))
julia> gpu_device()(opt_st).state[1]
2-element CuArray{Float32, 1, CUDA.DeviceMemory}:
0.0
 0.0
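For context, a minimal sketch of what such an extension method might look like, assuming a Leaf(rule, state) constructor and the OptimisersAdaptExt module name from the log above; the actual code in this PR may differ in details:

module OptimisersAdaptExt

import Adapt
import Optimisers: Leaf

function Adapt.adapt_structure(to, leaf::Leaf)
    @warn """`Optimisers.Leaf` object does not support device transfer via
          `Adapt.jl`. Avoid this by calling `Flux.gpu/cpu` or
          `MLDataDevices.cpu_device()/gpu_device()` on the optimiser state object.
          See below GitHub issue for more details.
          https://github.com/FluxML/Optimisers.jl/issues/179
          """
    # Adapt the rule and the state separately, then rebuild the Leaf. Identity of
    # shared Leaf objects is not preserved here, which is what the warning is about.
    Leaf(Adapt.adapt(to, leaf.rule), Adapt.adapt(to, leaf.state))
end

end # module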
Apologies, I had another look through this after some internal changes around device handling on the Flux side. Made some minor tweaks (removed PR link since we'd want a docs one, and used adapt to be more general), but it should be ready to go.
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
@warn """`Optimisers.Leaf` object does not support device transfer via | ||
`Adapt.jl`. Avoid this by calling `Flux.gpu/cpu` or | ||
`MLDataDevices.cpu_device()/gpu_device()` on the optimiser state object. | ||
""" maxlog=1 |
Do we want this only once? It's a potential correctness bug, not a performance issue.
""" maxlog=1 | |
""" |
Is your worry that the message will get lost in the shuffle somehow? My thought was that people may have a valid use for this, and as long as they know what they're getting into the library doesn't have to remind them on every call.
Another practical concern would be what happens when someone tries to call cu(large state tree). Not setting maxlog would mean other logging is drowned out because this warning would trigger for every Leaf.
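For reference, maxlog caps how many times a given @warn statement emits under the default logger, so a large state tree triggers the message once rather than once per Leaf. A quick illustration:

# Only the first iteration prints; maxlog=1 suppresses the remaining 99 repetitions
# of this particular warning under the default logger.
for _ in 1:100
    @warn "`Optimisers.Leaf` does not support Adapt.jl device transfer" maxlog=1
end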
I agree that printing 100 times on a big model is too much. Ideal would be once, on every invocation, IMO... but that's hard to make happen.
It's not the world's biggest correctness bug to ignore shared parameters, so maybe we should live with it. Maybe the message should say that's what the problem is?
Thanks @ToucheSir, I resolved the merge conflicts.
    `Adapt.jl`. This could lead to incorrect gradient updates. Avoid this by
    calling `Flux.gpu/cpu` or `MLDataDevices.cpu_device()/gpu_device()` on the
    optimiser state object.
""" maxlog=1 |
@mcabbott I edited the warning to say that this is a correctness issue.
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
* Adapt.adapt_structure method for Optimisers.Leaf
* import Adapt.jl
* add Adapt.jl to Project.toml
* adapt compat
* based on discussion: adapt_structure method does not maintain IdDict handled by functors. So we add a warning referring the user to Flux.gpu or MLDataDevices.gpu_device()
* Update ext/OptimisersAdaptExt.jl
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
* edit warning to indicate that this is a correctness issue
* Update ext/OptimisersAdaptExt.jl
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---------
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Fix #179