Rules for mutating functions, @adjoint! and its documentation #1228
Comments
To elaborate on my point in the topic, this is the totality of what …:

function ChainRulesCore.rrule(::typeof(addone!), a)
    y = addone!(a)
    # @adjoint! adds:
    addone!_pullback(::NoTangent) = (NoTangent(), NoTangent())
    # note that in ZygoteRules parlance this would be:
    # addone!_pullback(::Nothing) = nothing
    function addone!_pullback(ȳ)
        return NoTangent(), ones(length(a))
    end
    return y, addone!_pullback
end

This is because the return values of mutating functions are often unused, so AD may pass a null gradient to the pullback:

x = [...]
__unused__ = push!(x)
return sum(x)

So you can see that …
The requirement is that the AD support mutation sufficiently well. How that looks in practice is a little fuzzier, since AFAIK we don't have a ChainRules-compatible source-to-source AD which does support it. Note that Zygote does support a limited amount of mutation already, but not of array types.
In general, it is not safe, and Zygote does try to catch it. This is why you get those array mutation errors when using it. However, Zygote's compiler (i.e. the AD transform) has very little information to work with when generating pullback code, so it has to punt the responsibility for checking for mutation to the runtime. I do agree that there should be better docs for this somewhere outside of the mention in the …
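Here is a minimal plain-Julia sketch of the extra pullback method discussed above. It uses neither Zygote nor ChainRulesCore. `nothing` plays the role of NoTangent(), the "null gradient" an AD may pass when the result of a mutating call is unused. The name addone_rrule is hypothetical. Unlike the rule above, this pullback scales by ȳ.

```julia
# Mutating function from the discussion: adds 1 in place, returns the sum.
function addone!(a)
    a .+= 1
    return sum(a)
end

# Hypothetical hand-written rule (not the ChainRulesCore API): returns the
# primal result plus a pullback with two methods.
function addone_rrule(a)
    y = addone!(a)
    n = length(a)
    # Result unused downstream: the AD hands over `nothing`, so no gradient flows.
    pullback(::Nothing) = (nothing, nothing)
    # Otherwise d(sum(a .+ 1))/da[i] == 1, so each entry receives ȳ.
    pullback(ȳ) = (nothing, fill(ȳ, n))
    return y, pullback
end

a = [1.0, 2.0, 3.0]
y, back = addone_rrule(a)   # y == 9.0, a is now [2.0, 3.0, 4.0]
back(nothing)               # (nothing, nothing)
back(2.0)                   # (nothing, [2.0, 2.0, 2.0])
```

Dispatch selects the `::Nothing` method whenever the cotangent is `nothing`, which is the same trick as the `::NoTangent` method in the rule above.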
Sorry, I may have to phrase the issue/question differently. My primary interest is not exactly …

using ChainRulesCore, Zygote

function addone!(array)
    array .+= 1
    return sum(array)
end

function ChainRulesCore.rrule(::typeof(addone!), a)
    y = addone!(a)
    function addone!_pullback(ȳ)
        return NoTangent(), ones(length(a))
    end
    return y, addone!_pullback
end

a = [3.3, 2.1, 2.3]
gradient(addone!, a)

So, which functions like this work, and which don't? It just seems very unclear to me.
As written, this one gives wrong answers in most cases:

julia> Zygote.gradient([1,2,3]) do x
           addone!(x)^2
       end
ȳ = 18
([1.0, 1.0, 1.0],)

julia> ForwardDiff.gradient([1,2,3]) do x
           addone!(x)^2
       end
3-element Vector{Int64}:
 18
 18
 18

That's easy to correct: the pullback has to use ȳ, for example fill(ȳ, size(a)) as in the rule below, rather than ones(length(a)).

Edit: here's an example where mutation causes problems:

using LinearAlgebra  # for dot

function ChainRulesCore.rrule(::typeof(addone!), a)
    y = addone!(a)
    addone!_pullback(ȳ) = NoTangent(), fill(@show(ȳ), size(a))
    y, addone!_pullback
end

gradient([1,2,3]) do x
    y = dot(x, x)   # pullback for dot closes over x
    z = addone!(x)  # pullback for addone! is not run, unless you uncomment + z:
    x[1] + y        # + z
end
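The failure mode in the second example can be reproduced without any AD at all. In this plain-Julia sketch, dot_self_rrule is a hypothetical hand-written rule for dot(x, x). Its pullback closes over the array x itself. Mutating x after the forward call, as addone!(x) does, therefore silently changes what the pullback computes. Copying the input on capture avoids the problem, at the cost of an allocation per call.

```julia
# Hypothetical rule for dot(x, x) == sum(abs2, x). The pullback closes over
# the array object x, not over a snapshot of its values.
dot_self_rrule(x) = (sum(abs2, x), ȳ -> 2ȳ .* x)

x = [1.0, 2.0, 3.0]
y, back = dot_self_rrule(x)
x .+= 1                  # in-place update after the forward call, like addone!(x)
stale = back(1.0)        # [4.0, 6.0, 8.0]: evaluated at the mutated x

x2 = [1.0, 2.0, 3.0]
y2, back2 = dot_self_rrule(copy(x2))   # capture a snapshot instead
x2 .+= 1
fresh = back2(1.0)       # [2.0, 4.0, 6.0]: the true gradient at [1, 2, 3]
```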
I got the intent of the question from the start :). I would recommend re-reading the second half of my comment above.
So, do I get this right: in general it is not recommended to define these rules for mutating functions. It is possible in some cases, but it's not possible to say a priori in which cases it works and in which it doesn't?
Yes, it's not safe in general to define such rules. I think any function that mutates an input array can give you wrong answers.
It's certainly repeatable, and thus determined by the code. But correctness isn't a property of the rule alone. The dangers are that (1) some other rule may depend on … If you know that …
Thanks a lot for the answers.
My typical case is preallocating memory for a function that saves its results in that array, say a Fourier transform, i.e. defining a rule for in-place plans. But I already worry enough about correctness that I'm not sure I want to go down this route. Maybe I'll do some tests, though.
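One way to test such rules is to compare a rule's pullback against finite differences. Mutating functions add a catch: every evaluation must get a fresh copy of the input, otherwise earlier calls leak into later ones. Below is a minimal plain-Julia sketch. It assumes the rule is written as a function returning (y, pullback). The names fd_gradient, rule_gradient, fixed_rule and buggy_rule are hypothetical, and the outer function y -> y^2 is borrowed from the example earlier in the thread.

```julia
# Mutating function under test: adds 1 in place, returns the sum.
addone!(a) = (a .+= 1; sum(a))

# Two candidate rules: one ignores ȳ (like the rule quoted earlier in the
# thread), one scales by it.
buggy_rule(a) = (addone!(a), ȳ -> ones(length(a)))
fixed_rule(a) = (addone!(a), ȳ -> fill(ȳ, size(a)))

# Central finite differences for vectors. float.(x) allocates a new array for
# every evaluation, so mutations by f never leak between calls or into x.
function fd_gradient(f, x; h = 1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        xp = float.(x); xp[i] += h
        xm = float.(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Gradient of x -> addone!(x)^2 through a rule: ȳ = 2y for the outer square.
function rule_gradient(rule, x)
    y, back = rule(float.(x))
    return back(2y)
end

x = [1, 2, 3]
fd_gradient(v -> addone!(v)^2, x)   # ≈ [18.0, 18.0, 18.0]
rule_gradient(fixed_rule, x)        # [18.0, 18.0, 18.0]
rule_gradient(buggy_rule, x)        # [1.0, 1.0, 1.0]
```

A check like this only covers the rule in isolation. As noted above, correctness also depends on whether other pullbacks captured the array before it was mutated.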
Note that this must be done transitively, which means that if you have a nested chain of functions:

f(x, y) = g(x, y) + sum(x)
function g(x, y)
    push!(x, 3)
    return y
end

Both f and g would have to be rewritten to thread the mutated array through:

function f(x, y)
    newx, res = g(x, y)
    return sum(newx) + res
end
function g(x, y)
    newx = push!(x, 3)
    return newx, y
end

Note that Zygote does not have enough information at "compile" time to know which functions transitively call a mutating function (one that itself returns the mutated value) without threading that value through to their own return value. As a result, every function would have to be augmented this way. That could cause performance and possibly even correctness issues. Another problem is that not all mutating functions return the mutated value!

Now the question becomes: given all these caveats, how do …

… "Why not enable the mutable value gradient cache for all arrays and not just …" But it does not elaborate. #75 was an experimental effort to make array mutation work, but I'm not sure whether it ran into any fundamental issues that would have prevented further progress. Perhaps @MikeInnes could provide a historical perspective on this?
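To make the threading pattern concrete, here is a plain-Julia sketch (no AD) of the two versions of f and g from above. They are renamed f_mut/g_mut and f_thr/g_thr so both can coexist. Both versions compute the same value. The only difference is whether the mutated array appears in the data flow that an AD transform can see. The last lines illustrate the final caveat: Base's pop! mutates its argument but returns the removed element, not the array.

```julia
# Original style: f depends on the mutation only through a side effect of g.
function g_mut(x, y)
    push!(x, 3)
    return y
end
f_mut(x, y) = g_mut(x, y) + sum(x)   # g_mut runs first, so sum(x) sees the 3

# Threaded style: g returns the mutated array and f consumes it explicitly.
function g_thr(x, y)
    newx = push!(x, 3)   # push! returns the array it mutated
    return newx, y
end
function f_thr(x, y)
    newx, res = g_thr(x, y)
    return sum(newx) + res
end

f_mut([1.0, 2.0], 10.0)   # 16.0
f_thr([1.0, 2.0], 10.0)   # 16.0

# Not every mutating function returns the mutated value:
v = [1, 2, 3]
popped = pop!(v)          # 3, the removed element; v is now [1, 2]
```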
This can be done, but the preallocation part needs to be hidden in a rule, and there are some caveats around usage. See PumasAI/SimpleChains.jl#59 for a bit more discussion.
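Here is one possible reading of "the preallocation part needs to be hidden in a rule", sketched in plain Julia:

- keep the in-place kernel private;
- expose a non-mutating wrapper that owns its output buffer;
- write the rule for the wrapper only.

scale2!, scale2 and scale2_rrule are hypothetical stand-ins, for example for an in-place FFT plan. With ChainRulesCore, the rule would be a method of ChainRulesCore.rrule(::typeof(scale2), x). Here the buffer is allocated on every call. Reusing a long-lived buffer across calls would presumably need extra care, for instance so that returned arrays never alias it.

```julia
# In-place kernel: writes 2 .* x into a caller-supplied buffer and returns it.
function scale2!(out, x)
    out .= 2 .* x
    return out
end

# Non-mutating wrapper: the only mutation is of memory it just allocated,
# so no caller (and no pullback) can hold a reference to it beforehand.
scale2(x) = scale2!(similar(x), x)

# Hand-written rule for the wrapper, as a plain closure.
function scale2_rrule(x)
    y = scale2(x)
    scale2_pullback(ȳ) = (nothing, 2 .* ȳ)   # d(2x)/dx is 2 elementwise
    return y, scale2_pullback
end

x = [1.0, 2.0, 3.0]
y, back = scale2_rrule(x)   # y == [2.0, 4.0, 6.0]; x is unchanged
back(ones(3))               # (nothing, [2.0, 2.0, 2.0])
```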
Thanks for all the comments. I'll close this issue now; it is clearer to me what can work and what kind of tests I could do.
The simple answer to the original question is: you're free to use … The more technical answer is that mutation works for anything that isn't captured by value in a pullback, which in practice means …

Ok, here's more detail on why this stuff is hard to fix. Buckle up.

Supporting array mutation is indeed hard, but not for the obvious reasons. Arrays introduce two kinks that are actually pretty easy to deal with:

Firstly, the above issue with values being captured. But the structure of AD gives a surprisingly easy solution: just undo all the mutations in the backwards pass. For example, you'll notice that …

Secondly, mutation introduces non-local data flow. If … thread …

You'll notice that Zygote claims the gradient of a …

xs = Zygote.Buffer([1])
xs[1] = 2
Zygote.gradient(xs -> xs[1]^2, xs) # => (nothing,)

The gradient isn't really (…

This is all perfectly workable, but you'll notice that the adjoints for …

In #75, Zygote does a bunch of magic to make value-like adjoints do the right thing for references. For example, adjoints that produce a mutable need to retrieve (and clear) the accumulated gradient, which is then passed to the user's pullback. But adjoints might receive or produce arrays wrapped in things like …

I can imagine ways to clean this situation up, e.g. by formalising this fixup operation, but not without ugliness. For my part, I'm going to go watch some Rich Hickey talks instead :)

(The other downside to having reference-like gradients is that it's more type restrictive: we have to fix the type of the gradient ahead of time for performance, as we do with …)
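The "just undo all the mutations in the backwards pass" idea can be illustrated with a tiny hand-written tape in plain Julia. This is only a sketch of the principle, not how Zygote or #75 implement it. The forward pass logs the old value of every in-place write. The backward pass restores those values in reverse order before running a pullback that reads the array lazily. The function name and the specific computation (p = x[1] * x[2], then x[1] = 10) are made up for illustration.

```julia
# Forward: p = x[1] * x[2], then overwrite x[1] in place. Backward: optionally
# undo the write before running the multiplication's pullback.
function forward_backward!(x::Vector{Float64}; undo::Bool)
    undo_log = Tuple{Int,Float64}[]      # (index, overwritten value) per write

    p = x[1] * x[2]
    back_mul(dp) = [dp * x[2], dp * x[1]]  # reads x when called, not when defined

    push!(undo_log, (1, x[1]))           # record the old value before writing
    x[1] = 10.0                          # p does not depend on this new value

    if undo
        for (i, old) in reverse(undo_log)  # restore writes in reverse order
            x[i] = old
        end
    end
    return p, back_mul(1.0)              # seed the output cotangent with 1
end

p, g = forward_backward!([2.0, 3.0]; undo = true)       # (6.0, [3.0, 2.0]), correct
_, g_bad = forward_backward!([2.0, 3.0]; undo = false)  # [3.0, 10.0], stale x[1]
```

Undoing writes in reverse order matters once several writes hit the same index: restoring them newest-first leaves the array as it was before the first write.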
(This is related to https://discourse.julialang.org/t/zygote-jl-adjoint-mutating-inplace-adjoints/78241)

Inspecting the Zygote code, I can see that aside from @adjoint there is also @adjoint!, which is used to declare the adjoints of some mutating functions (like push! etc.). I can't find any docstrings or documentation on when and how it can be used. I suspect there are some limitations, as Zygote generally forbids mutation. Additionally, ChainRules says in its documentation (https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/which_functions_need_rules.html#Functions-which-mutate-arrays): …

It then goes on to demonstrate how to write these rules for Zygote nonetheless, with the example of a function that adds in place to the input array. This seems to work (and is probably translated to an @adjoint! rule?). … (just copied from the ChainRules doc)

So what are the requirements for these rules, whether defined via ChainRules or via @adjoint!, to work? In the linked issue at ChainRulesCore, they also can't name these requirements exactly. At the very least there should be some documentation on this. Even better, Zygote should emit some kind of warning when such rules are prone to fail.