Problems with variable indirect use #946
For Zygote, it would be better to use …

Thanks for the great hint. Can you point me to an example or documentation on …

You can have a look at the ChainRules documentation, and for examples see the ChainRules package.

We've discussed this over email, but I thought the short version might as well be recorded here. The reason for the error is that a matrix, …

Another problem here seems to be that you define the …

Thanks for picking this up! Yet the error message remains the same...
This is my attempt at using the suggestion of @MikeInnes:

```julia
function ChainRulesCore.rrule(::Type{Example{Float64,2,F}},
                              sz::NTuple{N, Int},
                              gen::F,
                              ) where {F,N}
    val_grad(x) = Zygote._pullback(gen, x)[2](1.0)
    gradgen(x) = val_grad(x)[1][:a]
    function IFA_pullback(ΔΩ)
        inner = Example{Float64,2,typeof(gradgen)}(sz, gradgen)
        ∂gen = ΔΩ .* inner
        @show ∂gen
        return (NO_FIELDS, NO_FIELDS, ∂gen)
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end
```

As you see by the …

```
julia> gradient(c, 2.0)
∂gen = [4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0]
ERROR: Need an adjoint for constructor var"#g#14"{Float64}. Gradient is of type Matrix{Float64}
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:33
 [2] (::Zygote.Jnew{var"#g#14"{Float64}, Nothing, false})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:314
 [3] (::Zygote.var"#1723#back#196"{Zygote.Jnew{var"#g#14"{Float64}, Nothing, false}})(Δ::Matrix{Float64})
   @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [4] Pullback
   @ ~\Documents\Programming\Julia\Development\TestingZygote.jl:63 [inlined]
 [5] (::typeof(∂(c)))(Δ::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(c))})(Δ::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:41
 [7] gradient(f::Function, args::Float64)
   @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:59
 [8] top-level scope
```
Again, though, you're giving a matrix … I suspect the right gradient here would be … Instead, you want to do something like broadcast the pullback of …
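A minimal plain-Julia sketch of the contraction this hint seems to point at, without Zygote. It assumes that the closure `g(idx) = idx[1] + idx[2]*a*a` from this thread should receive a cotangent for its captured field `a` that is a scalar like `a` itself. The helpers `∂a_elem` and `closure_cotangent` are hypothetical, and the per-element derivative is hand-coded rather than produced by Zygote.

```julia
# Hypothetical hand-coded pullback of one element g(idx) = idx[1] + idx[2]*a*a
# with respect to the captured variable `a`, scaled by that element's cotangent Δ.
∂a_elem(idx, a, Δ) = Δ * 2 * idx[2] * a

# Hypothetical helper: broadcast the per-element pullback over every index of the
# array, then sum, so the closure field `a` gets a scalar cotangent.
function closure_cotangent(ΔΩ::AbstractMatrix, a::Real)
    idxs = Tuple.(CartesianIndices(ΔΩ))   # matrix of (i, j) index tuples
    contribs = ∂a_elem.(idxs, a, ΔΩ)      # per-element contributions to ∂a
    return (a = sum(contribs),), contribs
end

ΔΩ = ones(3, 3)  # cotangent of sum(myarr) with respect to each element
∂g, contribs = closure_cotangent(ΔΩ, 2.0)
```

Here `contribs` equals the `∂gen` matrix `[4.0 8.0 12.0; ...]` printed in the stack-trace comment, and `∂g` is `(a = 72.0,)`. Under this reading, the matrix is the un-contracted per-element contribution. Handing it to Zygote as the closure's cotangent is consistent with the message "Gradient is of type Matrix{Float64}".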
Thanks @MikeInnes for this hint. It took me ages to understand that not the returned …

```julia
using ChainRulesCore
using Zygote

# Lazy array whose elements are computed by the function `f` from the index tuple.
struct Example{T,N,F} <: AbstractArray{T,N}
    sz::NTuple{N, Int}
    f::F
end

function Base.getindex(a::Example{T,N,F}, idx::Vararg{B,N}) where {T,N,F,B}
    a.f(idx)
end

Base.size(e::Example) = e.sz

function ChainRulesCore.rrule(::Type{Example{Float64,2,F}},
                              sz::NTuple{N, Int},
                              gen::F,
                              ) where {F,N}
    val_grad(idx) = Zygote._pullback(gen, idx)[2](1.0) # 1.0 is only the seed
    mySymbols = keys(val_grad(sz)[1])                  # field names of the closure `gen`
    gradgen(idx) = val_grad(idx)[1]
    function IFA_pullback(ΔΩ)
        # one lazy array per closure field, holding the derivative of each element
        Fcts = ((idx) -> val_grad(idx)[1][aSymbol] for aSymbol in mySymbols)
        TupleVals = (ΔΩ .* Example{Float64,2,typeof(Fun)}(sz, Fun) for Fun in Fcts)
        ∂gen = NamedTuple{mySymbols}(TupleVals)
        return (NO_FIELDS, NO_FIELDS, ∂gen)
    end
    Ω = Example{Float64,2,typeof(gen)}(sz, gen)
    return (Ω, IFA_pullback)
end

c(a) = begin
    g(idx) = idx[1] + idx[2] * a * a
    myarr = Example{Float64,2,typeof(g)}((3, 3), g) # (3, 3) is the size
    sum(myarr)
end
```

The output looks like this:

```
julia> gradient(c, 2.0)
([4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0],)
```

Phew. That took longer than planned ;-)
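A quick numerical cross-check, sketched in plain Julia without Zygote. `c_ref` and `fd_derivative` are hypothetical helpers, and `c_ref` assumes the lazy `Example` array holds the same values as a 3×3 comprehension over `idx = (i, j)`. Analytically, c(a) = 18 + 18a², so dc/da = 36a, which is 72 at a = 2.

```julia
# Hypothetical plain-Julia reference for c(a): the lazy Example array is replaced
# by a comprehension over the same 3×3 index range (assumption: identical values).
c_ref(a) = sum(i + j * a * a for i in 1:3, j in 1:3)

# Hypothetical helper: central finite difference as an independent estimate of df/dx.
fd_derivative(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

dc = fd_derivative(c_ref, 2.0)  # close to 36a = 72 at a = 2

# Matrix copied from the gradient output above.
returned = [4.0 8.0 12.0; 4.0 8.0 12.0; 4.0 8.0 12.0]
isapprox(sum(returned), dc; rtol = 1e-6)
```

The entries of the returned matrix sum to the scalar derivative, so the matrix can be read as the per-element contributions `2 * a * idx[2]` rather than as dc/da itself.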
Does anyone know if …

Only if you know the differential explicitly or if there exists an …

The plan is to allow this by JuliaCon, so watch the issue @devmotion posted in case you find this useful.
We are trying to write an `rrule` for a custom array class such that `Zygote` can differentiate through it, but we are stuck on an error about a missing adjoint for a constructor. This may well be a user error, but it could also be a problem of `Zygote`. Any help is appreciated!

This code seems to generally work fine for using the … error, but the point is that we need to be able to differentiate with respect to a variable used in the innermost function. The code using these definitions, which then causes the error:

The error looks like this:

Something similar happens if you place the variable `a` right behind `sum(` in function `c`.