diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 509e3a48a..bfcd7d720 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -33,25 +33,18 @@ chainrules_blacklist(f, args...) = false # ChainRules does higher-order functions badly # see https://github.com/JuliaDiff/ChainRules.jl/issues/122 -chainrules_blacklist(::typeof(map), args...) = true -chainrules_blacklist(::typeof(broadcast), args...) = true -chainrules_blacklist(::typeof(mapreduce), args...) = true -chainrules_blacklist(::typeof(mapfoldl), args...) = true -chainrules_blacklist(::typeof(mapfoldr), args...) = true +for f in (map, broadcast, mapreduce, mapfoldl, mapfoldr) + @eval chainrules_blacklist(::typeof($f), args...) = true +end chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true # Except for sum(abs2, xs), that is fine chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false # ChainRules current Wirtinger deriviative is not compatible # reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29 -chainrules_blacklist(::typeof(abs), ::Complex) = true -chainrules_blacklist(::typeof(abs2), ::Complex) = true -chainrules_blacklist(::typeof(conj), ::Complex) = true -chainrules_blacklist(::typeof(adjoint), ::Complex) = true -chainrules_blacklist(::typeof(hypot), ::Complex) = true -chainrules_blacklist(::typeof(angle), ::Complex) = true -chainrules_blacklist(::typeof(imag), ::Complex) = true -chainrules_blacklist(::typeof(real), ::Complex) = true +for f in (abs, abs2, conj, adjoint, hypot, angle, imag, real) + @eval chainrules_blacklist(::typeof($f), ::Complex) = true +end # Sum of nonarrays doesn't really work # Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124 @@ -66,8 +59,8 @@ Converts a ChainRules pullback into a Zygote pullback. `pb` should be a ChainRules pullback, as returned from the second return value of `rrule` """==# function _pullback_via_chainrules(pb) - # The less optimized version of this code is - # cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...)) + # This is the optimized version of + # _pullback_via_chainrules(pb) = (Δs...) -> zextern.(pb(Δs...)) function zback(Δs...) ∂s = pb(Δs...) ntuple(length(∂s)) do ii