Utility for walking a tree (e.g. gradients) w.r.t. a model #143

Closed · opened Apr 17, 2023 · 7 comments
Labels: enhancement (New feature or request)

darsnack (Member) commented on Apr 17, 2023

Motivation and description

Using trainable, we can walk a model and apply a function only to its trainable parameters. But the gradient from Zygote is a plain named tuple that carries none of this information.

Normally this is fine for optimizers, because our function is applied at every leaf, so we only need a single pass over the model. But it is fairly common to first walk the entire tree of gradients to compute something (e.g. a global norm term). In that case, we need a pass over the gradient outside of the update context.
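For example (a minimal illustration; Affine, its fields, and the frozen shift are all made up here):

using Functors, Optimisers, Zygote

struct Affine; W; b; shift; end     # `shift` is meant to stay fixed
Functors.@functor Affine
Optimisers.trainable(a::Affine) = (; W = a.W, b = a.b)

a = Affine(rand(2, 2), rand(2), rand(2))
g = Zygote.gradient(m -> sum(m.W * m.shift .+ m.b), a)[1]
# g == (W = ..., b = ..., shift = ...): `shift` still gets an entry,
# and nothing in the named tuple says it is non-trainable.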

Possible Implementation

We can include a maptrainable(f, model, [gradient]) function (or a better name) that maps f over the trainable parameters of model; a rough sketch follows the two cases below.

  • If another tree like gradient is passed, then f is applied to the leaves of gradient (i.e. approximately fmap(TrainableWalk(f), gradient, model) using the last argument to filter the walk).
  • If no other tree is passed, we just apply f to model (this is a simple walk but maybe it is good for consistency).
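
A hand-rolled sketch of the two-tree case (NamedTuple children only; real code would go through Functors' walk machinery and handle Tuples, shared leaves, and caching):

using Functors, Optimisers

# Sketch: apply `f` to leaves of `grad`, letting `model` decide which
# branches are trainable; non-trainable entries pass through unchanged.
function maptrainable(f, model, grad)
    grad === nothing && return nothing
    Functors.isleaf(model) && return f(grad)
    tr = Optimisers.trainable(model)
    ch = Functors.children(model)
    vals = map(keys(ch)) do k
        k in keys(tr) ? maptrainable(f, ch[k], grad[k]) : grad[k]
    end
    return NamedTuple{keys(ch)}(vals)
end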
@darsnack added the enhancement label on Apr 17, 2023
darsnack (Member, Author) commented

Ideally, I think the implementation would underlie update (i.e. update is maptrainable with f specialized to call apply).

ToucheSir (Member) commented

Related: #57. We have proposals for map and reduce, but does it make sense to try for a mapreduce?

darsnack (Member, Author) commented

Agreed, with the ability to pass additional trees to the call, as described above.

mcabbott (Member) commented

> If another tree like gradient is passed, then f is applied to the leaves of gradient (i.e. approximately fmap(TrainableWalk(f), gradient, model) using the last argument to filter the walk).

I think the most obvious t_mapreduce(f, r, model, grads) would always call f(x, dx), but take trainability from the model. The present fmap(f, xs, ys) always calls f(x, y):

julia> fmap(println, (a=1, b=2), (a="!", b="?"))
1!
2?
(a = nothing, b = nothing)

julia> sh = [1.0]; fmap(println, (a=sh, b=sh), (a="!", b="?"))
[1.0]!
(a = nothing, b = nothing)

The tricky bit, as usual, will be shared parameters. Here fmap simply ignores the y belonging to a shared x. This fmap(f, xs, ys) is a half-baked feature; I think update! was the original target, but it's not actually right for that.

The walk done by Optimisers.update! instead adds the distinct dx belonging to a shared x before calling apply!. I wonder how often that would be correct; for the gradient-norm example it probably would be. To write update! (ignoring its return value) you would need t_mapreduce(f, Returns(nothing), model, grads, state_tree), where we add dx but not state?

julia> Optimisers.setup(Momentum(), (a=sh, b=sh))
(a = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]), b = Leaf(Momentum{Float32}(0.01, 0.9), [0.0]))

julia> ans.a === ans.b
true

This all always feels like we have slightly the wrong abstractions.
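
For concreteness, the gradient-norm case with "add the dx belonging to a shared x" could be hand-rolled like this (a sketch under the same NamedTuple assumptions as above; global_grad_norm is a made-up name, not the proposed t_mapreduce):

using Functors, LinearAlgebra, Optimisers

function global_grad_norm(model, grad)
    acc = IdDict{Any,Any}()  # keyed by identity, so a shared x accumulates its dx pieces
    function walk(m, g)
        g === nothing && return nothing
        if Functors.isleaf(m)
            acc[m] = haskey(acc, m) ? acc[m] .+ g : g
        else
            tr = Optimisers.trainable(m)
            ch = Functors.children(m)
            for k in keys(ch)
                k in keys(tr) && walk(ch[k], g[k])
            end
        end
    end
    walk(model, grad)
    return sqrt(sum(dx -> norm(dx)^2, values(acc)))
end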

ericphanson commented
For the norm use-case, another thing that would be handy is if I could destructure the gradient to flatten it, but only keep the trainable params as governed by the model. Then I can just take a norm directly on the flat vector.

Or maybe a more composable thing would be if I could walk the model & gradient simultaneously, mapping non-trainable gradients to nothing and returning an updated gradient that has non-nothing entries only for trainable params. Then I could do whatever I wanted with that (walk it again with fmap, flatten it with destructure, etc.).
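
That masking pass could be quite small; a sketch under the same assumptions as the earlier ones (NamedTuple children, no shared leaves; mask_nontrainable is a made-up name):

using Functors, LinearAlgebra, Optimisers

function mask_nontrainable(model, grad)
    (grad === nothing || Functors.isleaf(model)) && return grad
    tr = Optimisers.trainable(model)
    ch = Functors.children(model)
    vals = map(keys(ch)) do k
        k in keys(tr) ? mask_nontrainable(ch[k], grad[k]) : nothing
    end
    return NamedTuple{keys(ch)}(vals)
end

# Given some `model` and its Zygote gradient `grad`, the flat-vector
# norm is then a plain fold over the remaining leaves:
masked = mask_nontrainable(model, grad)
flat = reduce(vcat, [vec(dx) for dx in Functors.fleaves(masked) if dx !== nothing])
norm(flat)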

ToucheSir (Member) commented

A simpler version of this came up in conversation over RL models on Slack today. The current incantation for updating one model's parameters based on the moving average of another model's is:

for (t, p) in zip(Flux.params(target), Flux.params(policy))
    t .= (1 - tau) .* t .+ tau .* p
end

To which I proposed:

Functors.fmap(m_target, m_policy; walk = Optimisers._Trainable_biwalk()) do t, p
    t .= (1 - tau) .* t .+ tau .* p
end

It should take no time to package up the latter as a mapparams function on our side. The questions are where it should live (Flux or Optimisers) and what it should be called (e.g. maptrainable instead).
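
The wrapper itself would be a one-liner in either home (a sketch only: _Trainable_biwalk is internal to Optimisers, and the name is exactly what's up for debate):

using Functors, Optimisers

mapparams(f, m, ms...) = fmap(f, m, ms...; walk = Optimisers._Trainable_biwalk())

# Usage, for the Polyak averaging above (mutating m_target in place):
# mapparams((t, p) -> t .= (1 - tau) .* t .+ tau .* p, m_target, m_policy)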

CarloLucibello (Member) commented on Nov 6, 2024

We now have trainables(model, path=true), which can be used for all the use cases mentioned here. The last example becomes:

for (kp, t) in trainables(target, path=true)
    p = getkeypath(policy, kp)  
    t .= (1 - tau) .* t .+ tau .* p
end
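
And the gradient-norm motivation from the top of the thread, assuming the Zygote gradient mirrors the model's structure (so the same key paths apply) and has no nothing entries among the trainable leaves:

using LinearAlgebra, Optimisers
using Functors: getkeypath

gnorm = sqrt(sum(norm(getkeypath(grad, kp))^2
                 for (kp, _) in trainables(model, path = true)))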
