rrule for fill! #521
I think the difficulty with allowing this is that it will cause any other rule which has captured
Is there some clever way we might avoid this, or at least, make this example where |
Can't we trust users to use this in a safe way? We do it already for |
It's harder for me to picture Seems to be from #252, without discussion. |
I don't see a general solution for mutating functions violating implicit assumptions made by other rrules. PS |
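The concern in this part of the thread, that a mutating call can invalidate values another rule's pullback has captured, can be reproduced without any AD package. The following is only a sketch in plain Julia; `sumlog_with_pullback` is a hypothetical hand-written rule, not something from ChainRules or Zygote.

```julia
# Hypothetical hand-written rule for sum(log, x).
# The pullback closes over x itself (not a copy), so it reads
# whatever x contains at the time the pullback is called.
function sumlog_with_pullback(x::Vector{Float64})
    back(dy) = dy ./ x
    return sum(log, x), back
end

x = [1.0, 2.0, 3.0]
y, back = sumlog_with_pullback(x)

g_before = back(1.0)   # uses the original contents of x

fill!(x, 4.0)          # mutate x after the forward pass

g_after = back(1.0)    # same pullback, now reading the overwritten x
```

Here `g_before` is the correct gradient `1 ./ [1, 2, 3]`, while `g_after` is `1 ./ [4, 4, 4]`: the pullback has no way to notice that `x` changed underneath it.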
We do have x = similar(...)
y = fill!(x, a) would be we call And in the
case |
Oh right, I guess However, my original example is trickier, since the pullback never gets called.
|
Huh, is the difference about if it was assigned to |
I don't know the internals, maybe this could be changed? It calls the |
Here's an update.
julia> using Zygote, ChainRulesCore, ForwardDiff, Diffractor

# New rule, with back(::Zero) method
julia> function ChainRulesCore.rrule(::typeof(fill!), A::Vector, x::Number)
         function back(dB)
           println("pullback for fill! got $dB")
           (NoTangent(), @not_implemented("arg is mutated"), sum(dB))
         end
         function back(dB::AbstractZero)
           println("pullback for fill! got $dB")
           (NoTangent(), @not_implemented("mutated"), @not_implemented("no input"))
         end
         fill!(A,x), back
       end

julia> Zygote.gradient([1,2], 3) do x, s  # easy case, works as desired
         y = fill!(similar(x), s)
         sum(abs2, y)
       end
pullback for fill! got [6, 6]
(nothing, 12.0)

# Example from above
julia> Zygote.gradient([1,2,3], 0) do x, s  # silently wrong, both args!
         y = log.(x)  # this needs x's value
         fill!(x, s)  # pullback is not called
         sum(y .+ x)
       end
([2.0, 1.5, 1.3333333333333333], nothing)

julia> ForwardDiff.gradient([1,2,3]) do x
         y = log.(x)
         fill!(x, 0)
         sum(y .+ x)
       end
3-element Vector{Float64}:
 1.0
 0.5
 0.3333333333333333

julia> Diffractor.gradient([1,2,3], 0) do x, s
         y = log.(x)  # this needs x's value
         fill!(x, s)  # poisons x, and s
         sum(y .+ x)
       end
pullback for fill! got ZeroTangent()
(NotImplemented(Main, #= REPL[22]:6 =#, mutated), NotImplemented(Main, #= REPL[22]:6 =#, no input))

# New example
julia> Zygote.gradient([1 2; 3 4], [5,6], 7) do x, y, z
         xy = x * y
         y2 = fill!(y, z)
         sum(xy .+ y2)
       end
pullback for fill! got [1.0, 1.0]
([7.0 7.0; 7.0 7.0], [4.0, 6.0], 2.0)

julia> Diffractor.gradient([1 2; 3 4], [5,6], 7) do x, y, z  # silently wrong about x
         xy = x * y        # x's gradient needs y's value, etc.
         y2 = fill!(y, z)  # poisons y, but not x
         sum(xy .+ y2)
       end
pullback for fill! got [1.0, 1.0]
([7.0 7.0; 7.0 7.0], NotImplemented(Main, #= REPL[5]:4 =#, nope), 2.0)

julia> ForwardDiff.gradient([1 2; 3 4]) do x
         y, z = [5,6], 7
         xy = x * y
         y2 = fill!(y, z)
         sum(xy .+ y2)
       end
2×2 Matrix{Int64}:
 5  6
 5  6 |
If you overload
function Zygote._pullback(__context__::Zygote.AContext, ::typeof(fill!), x::Array, v)
    old = copy(x)  # could instead just have fill!(x, NaN) on the reverse?
    y = fill!(x, v)
    back(::Nothing) = begin
        copyto!(x, old)  # restore
        (nothing, Zygote.Fill(NaN, size(x)), NaN)  # since we didn't see the return, poison it
    end
    back(dy) = begin
        copyto!(x, old)
        (nothing, Zygote.Fill(NaN, size(x)), sum(dy))  # here we know dv
    end
    return (y, back)
end
Similar for
function Zygote._pullback(__context__::Zygote.AContext, ::typeof(Base.setindex!), x::Array, v, ind::Integer...)
    old = x[ind...]
    y = setindex!(x, v, ind...)
    nots = map(_ -> nothing, ind)
    back(::Nothing) = begin
        x[ind...] = old
        (nothing, Zygote.Fill(NaN, size(x)), NaN, nots...)
    end
    back(dy) = begin
        x[ind...] = old
        (nothing, Zygote.Fill(NaN, size(x)), dy, nots...)  # setindex! returns the value
    end
    return (y, back)
end
Zygote.gradient([1,2,3.0], 4) do x, y
    x[1] = y^2
    sum(x .* y)
end  # should be ([0,4,4], 53), in fact all NaN |
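The save-and-restore idea behind the two overloads above can be illustrated without Zygote. This is only a sketch in plain Julia, with the pullbacks called by hand in reverse order; `sumlog_with_pullback` and `fill_with_restore!` are hypothetical helpers, not part of Zygote or ChainRules. It ignores the gradient with respect to the overwritten array itself, which the overloads above poison with NaN.

```julia
# Hypothetical hand-written rule for sum(log, x); its pullback reads x when called.
function sumlog_with_pullback(x::Vector{Float64})
    back(dy) = dy ./ x
    return sum(log, x), back
end

# Hypothetical rule for fill! that saves the old contents and restores them
# in the pullback, so rules recorded earlier see the values they expect.
function fill_with_restore!(x::Vector{Float64}, v::Real)
    old = copy(x)
    y = fill!(x, v)
    function back(dy)
        copyto!(x, old)   # undo the mutation
        return sum(dy)    # gradient with respect to v
    end
    return y, back
end

x = [1.0, 2.0, 3.0]
a, back_log  = sumlog_with_pullback(x)      # forward pass, step 1
b, back_fill = fill_with_restore!(x, 0.0)   # forward pass, step 2 (mutates x)

dv = back_fill(ones(3))   # reverse pass, step 2: also restores x
dx = back_log(1.0)        # reverse pass, step 1: sees the original x again
```

Because the reverse pass visits the `fill!` step first, `x` is back to `[1.0, 2.0, 3.0]` by the time `back_log` runs. This only helps if the pullback for the mutating call is actually invoked, which is exactly what fails in the Zygote example above.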
Damn it Zygote. |
We had FluxML/Zygote.jl#1227 but it was closed; I've just re-opened it. The problem was that FluxML/Zygote.jl#1204 happened, which led to FluxML/Zygote.jl#1205. As I mentioned in the issue, there doesn't seem to be a more incremental fix here than doing all of FluxML/Zygote.jl#603. Am I missing a better solution? |
Thinking about it a bit more, could we get away with just switching over Zygote's zero types (i.e. |
@mzgubic (with my help) tried to switch over Zygote's types a few years ago. |
In the spirit of baby steps, I've filed FluxML/Zygote.jl#1385 to provide a better base for future attempts at this. |
related to #515