AD error due to interaction of missing gradient of `besselk` and struct with vector #1204
As mentioned in JuliaGaussianProcesses/KernelFunctions.jl#452, the problem seems to be line 30 in 95e1021.

Minimal example:

```julia
julia> using ChainRulesCore
julia> f(x) = x;
julia> @scalar_rule f(x) @not_implemented(":(")
julia> using Zygote
julia> Zygote.gradient(f, 0.1)
(NotImplemented(Main, #= REPL[5]:1 =#, :(),)
julia> Zygote.gradient(x -> f(only(x)), 0.1)
(NotImplemented(Main, #= REPL[5]:1 =#, :(),)
julia> Zygote.gradient(x -> f(only(x)), (0.1,))
((NotImplemented(Main, #= REPL[5]:1 =#, :(),),)
julia> Zygote.gradient(x -> f(only(x)), [0.1])
ERROR: MethodError: no method matching Zygote.OneElement(::ChainRulesCore.NotImplemented, ::Tuple{Int64}, ::Tuple{Base.OneTo{Int64}})
Closest candidates are:
  Zygote.OneElement(::T, ::I, ::A) where {N, T<:Number, I<:Tuple{Vararg{Int64, N}}, A<:Tuple{Vararg{AbstractUnitRange, N}}} at ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:53
Stacktrace:
  [1] (::Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}})(dy::ChainRulesCore.NotImplemented)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:32
  [2] (::Zygote.var"#2300#back#429"{Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}}})(Δ::ChainRulesCore.NotImplemented)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [3] Pullback (repeats 2 times)
    @ ./array.jl:835 [inlined]
  [4] (::typeof(∂(iterate)))(Δ::Tuple{ChainRulesCore.NotImplemented, Nothing})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [5] Pullback
    @ ./iterators.jl:1352 [inlined]
  [6] Pullback
    @ ./REPL[10]:1 [inlined]
  [7] (::typeof(∂(#8)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#56#57"{typeof(∂(#8))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
 [10] top-level scope
    @ REPL[10]:1
```

This is also highlighted by the fact that the error is not specific to `only`:

```julia
julia> Zygote.gradient(x -> f(x[1]), [0.1])
ERROR: MethodError: no method matching Zygote.OneElement(::ChainRulesCore.NotImplemented, ::Tuple{Int64}, ::Tuple{Base.OneTo{Int64}})
Closest candidates are:
  Zygote.OneElement(::T, ::I, ::A) where {N, T<:Number, I<:Tuple{Vararg{Int64, N}}, A<:Tuple{Vararg{AbstractUnitRange, N}}} at ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:53
Stacktrace:
  [1] (::Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}})(dy::ChainRulesCore.NotImplemented)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:32
  [2] (::Zygote.var"#2300#back#429"{Zygote.var"#433#435"{1, Float64, Vector{Float64}, Tuple{Int64}}})(Δ::ChainRulesCore.NotImplemented)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [3] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/tools/builtins.jl:15 [inlined]
  [4] (::typeof(∂(literal_getindex)))(Δ::ChainRulesCore.NotImplemented)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [5] Pullback
    @ ./REPL[11]:1 [inlined]
  [6] (::typeof(∂(#10)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#56#57"{typeof(∂(#10))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
  [8] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
  [9] top-level scope
    @ REPL[11]:1
```
This definition

```julia
julia> @inline Zygote.wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing
```

fixes the issue:

```julia
julia> Zygote.gradient(f, 0.1)
(nothing,)
julia> Zygote.gradient(x -> f(only(x)), 0.1)
(nothing,)
julia> Zygote.gradient(x -> f(only(x)), (0.1,))
(nothing,)
julia> Zygote.gradient(x -> f(only(x)), [0.1])
(nothing,)
julia> Zygote.gradient(x -> f(x[1]), [0.1])
(nothing,)
```

Clearly, the information about not-implemented derivatives is lost, but I guess this is the Zygote way of dealing with special tangent types, as in Zygote.jl/src/compiler/chainrules.jl, line 108 in 1eb80c5. Since the `nothing` design is so fundamental to Zygote (at least currently), this also seems to be the easiest way to fix other similar errors without having to handle `NotImplemented` in every pullback.
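To make the failure mode concrete without depending on Zygote internals, here is a package-free Julia sketch. All names in it (`MockNotImplemented`, `MockOneElement`, `mock_getindex_pullback`, `mock_wrap`, `mock_safe_pullback`) are hypothetical stand-ins, not Zygote or ChainRulesCore API. It only assumes what the error message above shows: the `Zygote.OneElement` constructor accepts values with `T<:Number`, so a non-numeric tangent such as `NotImplemented` has no matching method. The `mock_wrap` step mirrors the proposed `wrap_chainrules_output` override by mapping the special tangent to `nothing` before it reaches the array pullback.

```julia
# Hypothetical stand-ins for illustration only: these are NOT the real
# ChainRulesCore.NotImplemented or Zygote.OneElement types.
struct MockNotImplemented end  # a special, non-numeric tangent

# Mimics the constraint seen in the error message: the stored value must be a Number.
struct MockOneElement{T<:Number}
    val::T
    ind::Int  # position of the single nonzero entry
    len::Int  # length of the (virtual) array
end

# Pullback of `x[i]` for a vector `x`: place the incoming tangent `dy` at index `i`.
mock_getindex_pullback(dy, i, n) = MockOneElement(dy, i, n)

# Analogue of the proposed fix: translate the special tangent to `nothing`
# (Zygote's "no gradient") before it reaches the array-specific pullback.
mock_wrap(dy) = dy
mock_wrap(::MockNotImplemented) = nothing

function mock_safe_pullback(dy, i, n)
    dy = mock_wrap(dy)
    # `nothing` is passed through, so no MockOneElement is ever built from it
    dy === nothing && return nothing
    return mock_getindex_pullback(dy, i, n)
end

# Without the wrapper, the non-numeric tangent hits the Number-only constructor.
err = try
    mock_getindex_pullback(MockNotImplemented(), 1, 3)
catch e
    e
end

println(err isa MethodError)                             # true
println(mock_safe_pullback(MockNotImplemented(), 1, 3))  # nothing
println(mock_safe_pullback(2.0, 1, 3).val)               # 2.0
```

As with the real workaround, the cost is that `nothing` keeps no record that the derivative was missing, so downstream code cannot tell "no gradient" apart from "not implemented".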
Treating
This is far from the only example of the lossy Zygote -> ChainRules (-> Zygote) conversion potentially causing issues for subsequent operations. Unfortunately, without doing at least part of #603, the status quo will probably stay the same for the foreseeable future.
This one is not just lossy, it's actively doing something that is normally incorrect.
Minimal reproducible example: note that `Foo` errors, whereas `Bar` works.