∇getindex mutates, causing issues with higher order AD over getindex. #820
Your example can be reduced even more:

```julia
using Flux
using Zygote

net = Chain(
    Dense(2, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 100, relu),
    Dense(100, 4),
    x -> x[1, :]
)

θ, builder = Flux.destructure(net)
x = randn(Float32, 2, 16)

function predict(θ, x, builder)
    net = builder(θ)
    A, pullback = Zygote.pullback(net, x)
    ∇A = pullback(ones(eltype(A), size(A)))[1]
    a = sum(∇A)
end

Zygote.gradient(θ -> predict(θ, x, builder), θ)
```

The issue is just that you can't nest Zygote. If you make the outer differentiation use ReverseDiff.jl and the inner one use Zygote.jl, you're fine though.
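A minimal sketch of that suggested combination, reusing the `predict`, `x`, `θ`, and `builder` defined above (an illustration of the suggestion, assuming ReverseDiff can trace the operations involved; not code from the thread):

```julia
using ReverseDiff

# Outer derivative via ReverseDiff; the inner pullback inside `predict` still uses Zygote.
grad = ReverseDiff.gradient(θ -> predict(θ, x, builder), θ)
```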
@ChrisRackauckas Thanks for the reply. I can see now that many of the pullbacks defined for common operations, like `getindex`, use mutation internally, which breaks nested Zygote differentiation. Do you know if there is any plan to address this? It seems to me that loss functions that use the gradients of a scalar field could be quite common when modelling physical systems with conserved quantities, as with Hamiltonian NNs.

I have tried combining various ADs as you suggest, but keep running into problems. I give some examples of things I've tried below. My apologies for their length, but I feel it might help to show what I'm actually trying to accomplish.

For context, I am trying to model the power dissipated by a non-equilibrium thermodynamic system undergoing some control protocol. This is modelled with two components. The first is a conservative term that depends on the directional derivative of a scalar field (the free energy). The second is a dissipative term that depends on a positive-definite friction tensor. In the spirit of SciML, I'm trying to model both the free energy and the friction tensor using NNs.
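Paraphrasing that description in symbols (not a formula from the thread):

$$
P(\lambda, \dot\lambda) \;=\; \dot\lambda \cdot \nabla F(\lambda) \;+\; \dot\lambda^{\mathsf{T}}\, \xi(\lambda)\, \dot\lambda ,
$$

where $F$ is the free energy (a scalar field of the control parameters $\lambda$), $\dot\lambda$ is the rate of change of the controls, and $\xi$ is the positive-definite friction tensor; both $F$ and $\xi$ are to be represented by neural networks.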
Here is one approach I tried, using ForwardDiff for the inner directional derivative and Zygote for the outer gradient. This runs, but it does not give me the correct gradient: the result comes back as ForwardDiff dual numbers (the last line extracts the partials), and all of the gradients for the ξnet parameters are zero.

```julia
using Flux
using Zygote
using DiffEqFlux
using ForwardDiff
using NNlib
using LinearAlgebra
using Statistics

export predictpower, create_Fnetwork, create_ξnetwork, combine_models

struct Builder{R1,R2}
    re1::R1
    re2::R2
    n::Int64
end

(builder::Builder)(p) = begin
    p1 = p[1:builder.n]
    p2 = p[(builder.n + 1):end]
    return (builder.re1(p1), builder.re2(p2))
end

function combine_models(m1, m2)
    p1, re1 = Flux.destructure(m1)
    p2, re2 = Flux.destructure(m2)
    n = size(p1)[1]
    p = [p1; p2]
    builder = Builder(re1, re2, n)
    return (p, builder)
end

struct DirectionalDerivative{F, V}
    f::F
    direction::V
end

const DD = DirectionalDerivative

function (dd::DD)(pt)
    let dd = dd
        ForwardDiff.derivative(0) do h
            dd.f(pt + h * dd.direction)
        end
    end
end

function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2
    # Unpack the inputs
    Fnet, ξnet = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]
    # Compute the directional derivative dλ⋅∇F
    Pcons = DD(Fnet, dλ)(λ)
    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])
    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ
    ξ = ξnet(λ)
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)
    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end

function create_Fnetwork(controldim, hiddendim, hiddenlayers)
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, 1)
    return Chain(initial, layers..., final)
end

function create_ξnetwork(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    final = Dense(hiddendim, componentdim)
    posdef = VecToPosDef(componentdim, controldim)
    return Chain(initial, layers..., final, posdef)
end

function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        return (F, ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end

"""
    VecToPosDef(indim, n)

Convert a vector to a positive definite matrix.

Take an `indim`-dimensional batch of vectors and convert it to
a batch of `(n, n)` positive definite matrices. The dimensions
must match such that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as elements of a lower triangular
matrix. The diagonal elements are exponentiated to ensure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64
    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim
    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)
    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end
    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end
    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end

# Test code
Fnet = create_Fnetwork(2, 128, 2)
ξnet = create_ξnetwork(2, 128, 2)
θ, builder = combine_models(Fnet, ξnet)
x = randn(Float32, 4, 128)

function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end

grad = Zygote.gradient(p -> loss(x, p, builder), θ)[1]
grad = getindex.(ForwardDiff.partials.(grad), 1)
```
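For reference, the `DirectionalDerivative` trick above computes the derivative of `f` along `direction` as a scalar forward-mode derivative in the step size; a tiny standalone check of that mechanism (an illustration, not from the thread):

```julia
using ForwardDiff, LinearAlgebra

f(x) = sum(abs2, x)            # ∇f(x) = 2x
pt = [1.0, 2.0]
dir = [0.5, -1.0]

# Same mechanism as DD above: d/dh f(pt + h * dir) at h = 0 equals ∇f(pt) ⋅ dir.
dd = ForwardDiff.derivative(h -> f(pt + h * dir), 0.0)
dd ≈ dot(2 .* pt, dir)         # should be true
```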
This version tries to use a single network, with Zygote for the inner gradient and ReverseDiff for the outer one. It fails (full traceback below):

```julia
using Flux
using Zygote
using ReverseDiff
using NNlib
using LinearAlgebra
using Statistics

export predictpower, create_network

function predictpower(x, θ, builder)
    n, nbatch = size(x)
    @assert n % 2 == 0
    ncontrol = n ÷ 2
    # Unpack the inputs
    net = builder(θ)
    λ = x[1:ncontrol, :]
    dλ = x[ncontrol + 1:end, :]
    # Forward pass
    results, pullback = Zygote.pullback(net, λ)
    F = results.F
    ξ = results.ξ
    ∇F = pullback((F = ones(eltype(F), size(F)), ξ = nothing))[1]
    ∇F = reshape(∇F, ncontrol, 1, nbatch)
    # Reshape to column / row vectors
    dλ = reshape(dλ, ncontrol, 1, nbatch)
    dλT = permutedims(dλ, [2, 1, 3])
    # Compute the conservative part of the power
    # Pcons = dλ^T ⋅ ∇F
    Pcons = batched_mul(dλT, ∇F)
    # Calculate the dissipative part of the power
    # Pdiss = dλ^T ⋅ ξ ⋅ dλ
    Pdiss = batched_mul(batched_mul(dλT, ξ), dλ)
    Pcons = reshape(Pcons, :)
    Pdiss = reshape(Pdiss, :)
    return Pcons + Pdiss
end

function create_network(controldim, hiddendim, hiddenlayers)
    componentdim = controldim * (controldim - 1) ÷ 2 + controldim
    initial = Dense(controldim, hiddendim, relu)
    layers = [Dense(hiddendim, hiddendim, relu) for i = 1:hiddenlayers]
    penultimate = Dense(hiddendim, componentdim + 1)
    posdef = VecToPosDef(componentdim, controldim)
    function output(x)
        F = reshape(x[1, :], 1, :)
        ξ = posdef(x[2:end, :])
        return (F = F, ξ = ξ)
    end
    return Chain(initial, layers..., penultimate, output)
end

"""
    VecToPosDef(indim, n)

Convert a vector to a positive definite matrix.

Take an `indim`-dimensional batch of vectors and convert it to
a batch of `(n, n)` positive definite matrices. The dimensions
must match such that `indim == n*(n-1)/2 + n`. The entries
of the input are treated as elements of a lower triangular
matrix. The diagonal elements are exponentiated to ensure
positivity.
"""
struct VecToPosDef
    indim::Int64
    n::Int64
    function VecToPosDef(indim, n)
        @assert indim == n * (n - 1) ÷ 2 + n
        return new(indim, n)
    end
end

function (lpd::VecToPosDef)(x::AbstractArray)
    indim, n_batch = size(x)
    @assert indim == lpd.indim
    # Zygote does not support mutation of arrays,
    # so we need to use a Buffer object, which does.
    out = Zygote.Buffer(x, lpd.n, lpd.n, n_batch)
    # Set the upper triangle to zero.
    for i = 1:lpd.n
        for j = i + 1:lpd.n
            for k = 1:n_batch
                out[i, j, k] = 0.0
            end
        end
    end
    i = 1
    # Compute the diagonal.
    # Exponentiate to ensure > 0.
    for j = 1:lpd.n
        out[j, j, :] = exp.(x[i, :])
        i += 1
    end
    # Compute the lower triangle.
    for j = 1:lpd.n
        for k = 1:(j - 1)
            out[j, k, :] = x[i, :]
            i += 1
        end
    end
    # Turn the buffer back into an array
    out = copy(out)
    return batched_mul(out, permutedims(out, [2, 1, 3]))
end

# Test it
net = create_network(2, 128, 2)
θ, builder = Flux.destructure(net)
x = randn(Float32, 4, 128)

function loss(x, θ, builder)
    power = predictpower(x, θ, builder)
    return mean(power.^2)
end

grad = ReverseDiff.gradient(θ -> loss(x, θ, builder), θ)
```

Here is the traceback:
@ChrisRackauckas the MWE can actually be reduced to only involve `getindex`:

```julia
julia> gradient(rand(3, 3)) do p
           gradient(p) do p
               sum(p[1, :])
           end[1] |> sum
       end
```
I'm stuck with the same problem; how did you solve yours? I tried integrating the approach from https://discourse.julialang.org/t/flux-higher-order-derivatives-and-forward-mode/38805/3 to write an …
So what is the actual problem preventing Zygote from computing higher-order derivatives?
So the "issue" is that Zygote uses mutation in the adjoint of `getindex`. Hmm, let me think about whether we can handle it better.
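As a small illustration of why that matters for nesting (an illustrative example, not from the thread): the outer gradient has to differentiate the inner pullback, and the inner pullback of `getindex` builds its result by writing into a zero array, which Zygote refuses to differentiate through:

```julia
using Zygote

# Stand-in for what the getindex adjoint does internally:
# scatter the incoming cotangent dy into a freshly allocated zero array.
function scatter_like_adjoint(dy, n, i)
    dx = zeros(typeof(dy), n)
    dx[i] = dy          # the setindex! that the outer pass trips over
    return dx
end

# Differentiating through the scatter reproduces the familiar
# "Mutating arrays is not supported" error.
Zygote.gradient(dy -> sum(scatter_like_adjoint(dy, 3, 1)), 1.0)
```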
I actually thought the problem lay elsewhere, but using:

```julia
x = rand(2)
m = Chain(Dense(2, 1))

Flux.gradient(params(m)) do
    gradient(m, x) |> sum |> sum
end
```

Edit:
Yeah, I think that should be fine, but it is less generally correct to do, I think.
A cleaner way is to extract the gradient by tuple destructuring (is it called that?).
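A sketch of what that destructuring looks like (an illustration; whether it avoids the indexing adjoint in a given model still depends on what else the loss does):

```julia
using Zygote

f(x) = sum(abs2, x)
x = rand(3)

# Unpack the returned tuple instead of indexing into it with [1].
g, = Zygote.gradient(f, x)

# Equivalent to, but without the explicit getindex:
# g = Zygote.gradient(f, x)[1]
```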
The issue is that:

```julia
∇getindex(x::AbstractArray, inds) = dy -> begin
    if inds isa NTuple{<:Any, Integer}
        dx = _zero(x, typeof(dy))
        dx[inds...] = dy
    else
        dx = _zero(x, eltype(dy))
        dxv = view(dx, inds...)
        dxv .= accum.(dxv, _droplike(dy, dxv))
    end
    return (dx, map(_->nothing, inds)...)
end
```

These mutate. I would suggest splitting that into two separate dispatches and trying to come up with schemes that are just broadcasts or filters. If that's not easy to do, then I think a dispatch on just arrays (to avoid CuArrays) that just loops would be nice and would fix the problem for most non-GPU users.
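For the scalar-index branch, a broadcast-only scheme along the lines suggested above might look like this (a sketch with made-up names, not a proposed Zygote implementation; the GPU and performance questions raised later in the thread still apply):

```julia
# Hypothetical non-mutating adjoint for scalar indexing x[i, j, ...]:
# place dy at `inds` and zeros everywhere else, built purely by broadcasting.
function ∇getindex_scalar_sketch(x::AbstractArray, inds::NTuple{N,Integer}, dy) where {N}
    target = CartesianIndex(inds)
    return ifelse.(CartesianIndices(x) .== Ref(target), dy, zero(dy))
end

# Example: the cotangent of x[1, 2] scattered back into a 3×3 array.
∇getindex_scalar_sketch(rand(3, 3), (1, 2), 1.0)
```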
I believe I'm facing a similar issue, where I need to use the Jacobian of my prediction function with respect to an array of random variables in the loss function. Here is my code:

```julia
# Assumed imports for the aliases `diffeq` and `des` and for the helpers used
# below; the original snippet did not show them.
import DifferentialEquations as diffeq
import DiffEqSensitivity as des
using Flux, Zygote
using Statistics: mean
using LinearAlgebra: isposdef

# `input`, `x`, `y`, `p`, `times`, and `callbacks` are defined elsewhere.

# produces θ for pred function
ann = Flux.Chain(
    Flux.Dense(input, 32, Flux.tanh),
    Flux.Dense(32, 32),
    Flux.Dense(32, 3),
);

function dAdt(dA, A, p, t)
    a, b, c = p
    dA[1] = -a * b * A[1]
    dA[2] = c * b * A[1] - a * A[2]
end

function pred(θ, η, t, callback)
    p = θ .* exp.(η)
    prob = diffeq.ODEProblem(dAdt, [0., 0.], (-.1, maximum(t)), p)
    sol = diffeq.solve(prob, diffeq.Tsit5(), saveat=t, tstops=[0.], callback=callback,
                       sensealg=des.ForwardDiffSensitivity())
    return sol[2, :]  # A[2] corresponds to y measurements
end

∂pred_∂η(θ, η, time, callback) = Zygote.jacobian(eta -> pred(θ, eta, time, callback), η)

# p == 3x3 correlation matrix
function Obj(x, y, p, times, callbacks)
    if !isposdef(p)
        return Inf
    end
    N = length(times)  # equal to the number of observations in dataset
    θ = ann(x')
    η = zeros(size(p, 1))  # test
    loss = 0.
    for i in 1:N
        ŷ = pred(θ[:, i], η, times[i], callbacks[i])
        residuals = y[i] - ŷ
        jac_eta = ∂pred_∂η(θ[:, i], η, times[i], callbacks[i])  # line 1
        loss = mean(residuals) + mean(jac_eta * p * jac_eta')   # line 2
    end
    return loss
end

grad = Zygote.gradient(() -> Obj(x, y, p, times, callbacks), Flux.params(ann))  # error: mutating arrays
```

Removing line 1 and changing line 2 to …
#77 is a solution that could be used.
So out of #77 we would just need the …
Yes
This is my take on Keno's approach:

```julia
∇getindex(x::AbstractArray, inds) = dy -> (_zerosetindex(x, inds, dy), map(_->nothing, inds)...)

function _zerosetindex(x, inds::NTuple{<:Any, Integer}, dy)
    dx = _zero(x, typeof(dy))
    dx[inds...] = dy
    dx
end

function _zerosetindex(x, inds, dy)
    dx = _zero(x, eltype(dy))
    dxv = view(dx, inds...)
    dxv .= accum.(dxv, _droplike(dy, dxv))
    dx
end

@adjoint function _zerosetindex(x, inds, dy)
    _zerosetindex(x, inds, dy), ddx -> (nothing, nothing, ddx[inds...])
end
```

Keno's tests seem to pass as well. Should I put up a PR?
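The reduced MWE from earlier in the thread is a natural smoke test for this kind of change; assuming the methods above are defined inside Zygote (they rely on the internal helpers `_zero`, `accum`, `_droplike`, and `@adjoint`), this nested gradient over indexing is what should now go through:

```julia
using Zygote

# Second-order differentiation through getindex, as in the reduced MWE above.
Zygote.gradient(rand(3, 3)) do p
    Zygote.gradient(p -> sum(p[1, :]), p)[1] |> sum
end
```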
I think that would be great!
We'll want to test this with GPUs, and check for performance.
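A rough way to check the performance side (a hypothetical benchmark, not something from the thread) is to time a first-order gradient through indexing before and after the change:

```julia
using Zygote, BenchmarkTools

x = rand(1000, 1000)
f(p) = sum(p[1, :])

# Time and allocations here are dominated by the getindex adjoint building dx.
@btime Zygote.gradient($f, $x);
```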
I'm new to Flux/Zygote/Julia. I'm trying to develop a model that looks something like the one below. I have a network that produces two outputs, `A` and `B`. The gradient of `A` with respect to the inputs is part of my loss function, along with other terms that depend on `B`. I've just summed things here for simplicity, but my actual model produces the same error.

The error is `ERROR: LoadError: Mutating arrays is not supported`, which comes from the `x -> (A=x[1, :], B=x[2:end, :])` line in the network, but I don't understand where the mutation is coming from.

I gather from a number of issues here and threads on Discourse that higher-order gradients are not well supported, but there isn't much documentation around this. As a new user, it would be extremely helpful if there were some kind of documentation / guidance about how to work around this.

On a related Discourse thread, @ChrisRackauckas suggested using another AD, like ReverseDiff, but I can't figure out how to get the gradient that I want. Any guidance would be appreciated.
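A minimal sketch of the setup being described here, with assumed layer sizes and names rather than the original code (compare the reduced example at the top of the thread):

```julia
using Flux, Zygote

# Hypothetical reconstruction: the last layer splits the output into
# A (first row) and B (remaining rows) as a named tuple.
net = Chain(
    Dense(2, 32, relu),
    Dense(32, 5),
    x -> (A = x[1, :], B = x[2:end, :]),
)

x = randn(Float32, 2, 8)

function loss(net, x)
    # Inner gradient: the gradient of sum(A) with respect to the inputs x.
    out, back = Zygote.pullback(net, x)
    ∇A = back((A = ones(Float32, size(out.A)), B = nothing))[1]
    # Toy loss that just sums everything, as described in the text.
    return sum(∇A) + sum(out.B)
end

# Outer gradient over the network parameters; this is where
# "Mutating arrays is not supported" is reported.
Zygote.gradient(() -> loss(net, x), Flux.params(net))
```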