Utility for walking a tree (e.g. gradients) w.r.t. a model #143
Comments
Ideally, I think the implementation would underlie
Related: #57. We have proposals for
Agreed, with the ability to add more trees to call as described above.
I think the most obvious
julia> fmap(println, (a=1, b=2), (a="!", b="?"))
1!
2?
(a = nothing, b = nothing)
julia> sh = [1.0]; fmap(println, (a=sh, b=sh), (a="!", b="?"))
[1.0]!
(a = nothing, b = nothing)
The tricky bit as usual will be shared parameters. Here The walk done by
julia> Optimisers.setup(Momentum(), (a=sh, b=sh))
(a = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), b = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]))
julia> ans.a === ans.b
true
This all always feels like we have slightly the wrong abstractions.
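To make the shared-parameter behaviour above concrete, here is a minimal, dependency-free sketch of a walk that caches on object identity, so a tied array is visited once however many paths reach it. The name `foreach_unique_leaf` is hypothetical, and this illustrates the idea only; it is not how Functors or Optimisers implement their walks.

```julia
# Visit each distinct numeric array once, even if it is reachable by several paths.
# `seen` is keyed by object identity (===), not by value.
function foreach_unique_leaf(f, tree; seen = IdDict{Any,Nothing}())
    if tree isa AbstractArray{<:Number}
        haskey(seen, tree) && return nothing   # tied array already visited
        seen[tree] = nothing
        f(tree)
    elseif tree isa Union{NamedTuple,Tuple}
        # Recurse into the children, sharing the same cache.
        foreach(t -> foreach_unique_leaf(f, t; seen), tree)
    end
    return nothing
end

sh = [1.0]
visits = Ref(0)
# `a` and `b` are the same object; `c` is equal in value but a distinct array.
foreach_unique_leaf(_ -> visits[] += 1, (a = sh, b = sh, c = [1.0]))
visits[]  # 2
```

The difficulty raised in the comment is that a second tree (for example a gradient) walked alongside the model may hold different values at `a` and `b`, and an identity cache on the first tree silently skips one of them.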
For the norm use-case, another thing that would be handy is if I could Or maybe a more composable thing would be if I could walk the model & gradient simultaneously, and map non-trainable gradients to
A simpler version of this came up in conversation over RL models on Slack today. The current incantation for updating one model's parameters based on the moving average of another model's is:
for (t, p) in zip(Flux.params(target), Flux.params(policy))
    t .= (1 - tau) .* t .+ tau .* p
end
To which I proposed:
Functors.fmap(m_target, m_policy; walk = Optimisers._Trainable_biwalk()) do t, p
    t = (1 - tau) .* t .+ tau .* p
end
It should take no time to package up the latter as a
We now have
for (kp, t) in trainables(target, path=true)
    p = getkeypath(policy, kp)
    t .= (1 - tau) .* t .+ tau .* p
end
Motivation and description
Using trainable, we can walk a model and only apply a function to trainable parameters. But the gradient from Zygote is a named tuple without this information.
Normally, for optimizers this is fine, because our function is applied at every leaf, so we only need a single pass over the model. But it is fairly common to walk the entire tree of gradients to compute something first (e.g. a global norm term). In this case, we need a pass over gradient outside of the update context.
Possible Implementation
We can include a maptrainable(f, model, [gradient]) (or better name) function that maps a function w.r.t. the trainable parameters of model.
If gradient is passed, then f is applied to the leaves of gradient (i.e. approximately fmap(TrainableWalk(f), gradient, model) using the last argument to filter the walk).
Otherwise, f is applied to model (this is a simple walk but maybe it is good for consistency).