Type inference fails for custom type broadcasting #1359
Cthulhu reports a …

```
julia> pb(d_result)
ERROR: MethodError: no method matching conj(::MyFloat64)
Closest candidates are:
  conj(::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S}) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/symmetric.jl:368
  conj(::SparseArrays.SparseVector{<:Complex}) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/SparseArrays/src/sparsevector.jl:1215
  conj(::ChainRulesCore.Tangent) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_types/tangent.jl:167
  ...
Stacktrace:
  [1] exp_pullback
    @ ~/.julia/packages/ChainRules/RZYEu/src/rulesets/Base/fastmath_able.jl:56 [inlined]
  [2] ZBack
    @ ~/.julia/packages/Zygote/AS0Go/src/compiler/chainrules.jl:206 [inlined]
  [3] (::Zygote.var"#938#943")(::Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1313"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}, ȳ₁::MyFloat64)
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/lib/broadcast.jl:205
  [4] (::Base.var"#4#5"{Zygote.var"#938#943"})(a::Tuple{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1313"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}, MyFloat64})
    @ Base ./generator.jl:36
  [5] iterate
    @ ./generator.jl:47 [inlined]
  [6] collect
    @ ./array.jl:787 [inlined]
  [7] map
    @ ./abstractarray.jl:3055 [inlined]
  [8] (::Zygote.var"#∇broadcasted#942"{Tuple{Vector{MyFloat64}}, Vector{Tuple{MyFloat64, Zygote.ZBack{ChainRules.var"#exp_pullback#1313"{MyFloat64, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}}}}}, Val{2}})(ȳ::Vector{MyFloat64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/lib/broadcast.jl:205
  [9] #3885#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [10] #208
    @ ~/.julia/packages/Zygote/AS0Go/src/lib/lib.jl:206 [inlined]
 [11] #2066#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [12] Pullback
    @ ./broadcast.jl:1298 [inlined]
 [13] (::Zygote.var"#3909#back#952"{typeof(∂(broadcasted))})(Δ::Vector{MyFloat64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [14] top-level scope
    @ REPL[11]:1
```
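The `MethodError` comes from frame `[1]`: the scalar pullback for `exp` calls `conj` on the custom type. A minimal plain-Julia sketch of why, assuming a hand-written rule (`exp_with_pullback` is a hypothetical helper, not ChainRules' code) whose cotangent is `ΔΩ * conj(Ω)`:

```julia
# Same wrapper type as in the MWE, initially without a `conj` method
struct MyFloat64 <: Number
    n::Float64
end
Base.:*(a::MyFloat64, b::MyFloat64) = MyFloat64(a.n * b.n)
Base.exp(f::MyFloat64) = MyFloat64(exp(f.n))

# Hypothetical scalar reverse rule for Ω = exp(x): the cotangent is ΔΩ * conj(Ω),
# so the pullback needs a `conj` method for the element type
function exp_with_pullback(x)
    Ω = exp(x)
    exp_pullback(ΔΩ) = ΔΩ * conj(Ω)
    return Ω, exp_pullback
end

Ω, back = exp_with_pullback(MyFloat64(0.0))
# Calling back(MyFloat64(1.0)) at this point would throw
# MethodError: no method matching conj(::MyFloat64)

# For a real-valued wrapper, conjugation is the identity
Base.conj(f::MyFloat64) = f
xbar = back(MyFloat64(2.0))
```

Defining `conj` as the identity is only correct because `MyFloat64` wraps a real number; a complex-valued custom type would need a genuine conjugate.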
I was over-simplifying the MWE a little bit 😅. Thanks for introducing Cthulhu, great tool! I wasn't aware of it. The actual MWE that makes …

```julia
using Zygote, InteractiveUtils
import Base: exp, conj, +, -, *, /

# Minimal custom number type wrapping a Float64
struct MyFloat64 <: Number
    n::Float64
end

*(f1::MyFloat64, f2::MyFloat64) = MyFloat64(f1.n * f2.n)
exp(f::MyFloat64) = MyFloat64(exp(f.n))
# Needed by the exp pullback; for a real-valued wrapper, conj is the identity
conj(f::MyFloat64) = f

my_vector = MyFloat64[1.0, 2.0, 3.0]
# Forward pass of exp.(my_vector): returns the result and its pullback
result, pb = Zygote._pullback(broadcast, exp, my_vector)
d_result = MyFloat64[1.0, 1.0, 1.0]
# Backward pass: the call whose return type cannot be inferred
pb(d_result)
```

And now, throwing Cthulhu on it, I get:
The … I mean, in theory I believe …
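So, let's do this instead:

```julia
# Callable that picks the i-th element; `i` is a type parameter, known at compile time
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing

function test_type_inference()
    dxs_zip = Tuple{Nothing, Float64}[(nothing, 1.)]
    # Unzip: dxs[i] collects the i-th component of every tuple in dxs_zip
    dxs = ntuple(Val(2)) do i
        map(StaticGetter{i}(), dxs_zip)
    end
    dxs
end
```

… type is corrupted: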
I eventually found the correct methods to implement as well. The issue appears to be that some compiler heuristic is violated, which prevents constant propagation from `ntuple` and …
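In the `StaticGetter`/`ntuple` example, each output's element type is only known if the compiler constant-propagates `i` into the closure. One way to avoid depending on that is to take the component types from the element type's static parameters. This is a sketch with a hypothetical `unzip2` helper, not necessarily what #1360 does:

```julia
# Hypothetical helper: unzip a vector of 2-tuples into two concretely typed vectors.
# A and B are static parameters of the method, so the output element types are
# known without constant-propagating an index through `ntuple`.
function unzip2(xs::AbstractVector{Tuple{A,B}}) where {A,B}
    firsts = A[x[1] for x in xs]
    seconds = B[x[2] for x in xs]
    return (firsts, seconds)
end

dxs_zip = Tuple{Nothing, Float64}[(nothing, 1.0)]
dxs = unzip2(dxs_zip)
```

The trade-off is that it is specialized to 2-tuples; a general N-ary version would need the same trick applied per position.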
Thank you for the quick fix! I verified your corrections locally by … In case you are interested in why I keep throwing corner cases of Zygote at you: I'm building a fast higher-order forward-mode AD (essentially a rewrite of TaylorSeries.jl with statically inferred polynomial types), and I want it to be downstream-compatible with Zygote so that it can be used in cases like NeuralPDE.jl.
Closed by #1360
Let's say I have a broadcast over an array of a custom type. The following MWE warns that the return type of `pb(d_result)` cannot be determined: …

The output is …
To the best of my knowledge of Zygote, the magic happens here, but I couldn't tell what in the pullback function makes it type-unstable.
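The stack trace in this thread shows the general shape of that pullback: `∇broadcasted` maps over a `Vector` of `(value, pullback)` tuples. A plain-Julia sketch of that pattern for a single array argument, assuming hypothetical `exp_rrule` and `broadcast_with_pullback` helpers (this is not Zygote's implementation):

```julia
# Hypothetical scalar rule: returns exp(x) and a closure computing the cotangent
function exp_rrule(x)
    Ω = exp(x)
    return Ω, ΔΩ -> ΔΩ * conj(Ω)
end

# Hypothetical single-argument broadcast with a pullback.
# Forward pass: keep one (value, pullback) pair per element.
function broadcast_with_pullback(rrule, xs::AbstractVector)
    pairs = map(rrule, xs)
    ys = map(first, pairs)
    backs = map(last, pairs)
    # Backward pass: apply each stored pullback to the matching cotangent.
    # The element type of the result depends on what those pullbacks return,
    # which is where inference can lose track for custom types.
    bc_back(Δys) = map((back, Δy) -> back(Δy), backs, Δys)
    return ys, bc_back
end

ys, bc_back = broadcast_with_pullback(exp_rrule, [0.0, 1.0])
grads = bc_back([1.0, 1.0])
```

With several broadcast arguments, each pullback returns a tuple of cotangents, which then has to be unzipped per argument; that is the step the `StaticGetter`/`ntuple` example in this thread reproduces.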
`Real` and `Complex` are rescued by ForwardDiff, but custom types are not.

P.S. I found a similar issue, #885, from two years ago, but it didn't really get solved.
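The forward-mode route avoids storing per-element pullbacks: one broadcast over dual numbers yields values and derivatives together, with a concrete element type throughout. A minimal sketch assuming a hypothetical hand-written `Dual` type (not `ForwardDiff.Dual`, and not Zygote's actual code path):

```julia
# Hypothetical minimal dual number: a value plus its derivative
struct Dual
    val::Float64
    der::Float64
end

# Chain rule for exp: d/dx exp(x) = exp(x)
Base.exp(d::Dual) = (e = exp(d.val); Dual(e, e * d.der))

# Seed every element with derivative 1 and broadcast once; the intermediate
# array is a concrete Vector{Dual}, so no pullback closures are stored
function exp_broadcast_forward(x::Vector{Float64})
    duals = exp.(Dual.(x, 1.0))
    return map(d -> d.val, duals), map(d -> d.der, duals)
end

y, dy = exp_broadcast_forward([0.0, 1.0])
```

This works for elementwise functions because each output depends only on its own input; the seeding trick is what makes it cheap compared with a general Jacobian.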