Skip to content

Commit

Permalink
use metaprogramming in blacklist
Browse files Browse the repository at this point in the history
Update src/compiler/interface2.jl

fix missing eval
  • Loading branch information
oxinabox committed Nov 21, 2019
1 parent 2dc2437 commit 70a8aff
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,18 @@ chainrules_blacklist(f, args...) = false

# ChainRules does higher-order functions badly
# see https://github.com/JuliaDiff/ChainRules.jl/issues/122
chainrules_blacklist(::typeof(map), args...) = true
chainrules_blacklist(::typeof(broadcast), args...) = true
chainrules_blacklist(::typeof(mapreduce), args...) = true
chainrules_blacklist(::typeof(mapfoldl), args...) = true
chainrules_blacklist(::typeof(mapfoldr), args...) = true
for f in (map, broadcast, mapreduce, mapfoldl, mapfoldr)
@eval chainrules_blacklist(::typeof($f), args...) = true
end
chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true
# Except for sum(abs2, xs), that is fine
chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false

# ChainRules current Wirtinger deriviative is not compatible
# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29
chainrules_blacklist(::typeof(abs), ::Complex) = true
chainrules_blacklist(::typeof(abs2), ::Complex) = true
chainrules_blacklist(::typeof(conj), ::Complex) = true
chainrules_blacklist(::typeof(adjoint), ::Complex) = true
chainrules_blacklist(::typeof(hypot), ::Complex) = true
chainrules_blacklist(::typeof(angle), ::Complex) = true
chainrules_blacklist(::typeof(imag), ::Complex) = true
chainrules_blacklist(::typeof(real), ::Complex) = true
for f in (abs, abs2, conj, adjoint, hypot, angle, imag, real)
@eval chainrules_blacklist(::typeof($f), ::Complex) = true
end

# Sum of nonarrays doesn't really work
# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124
Expand All @@ -66,8 +59,8 @@ Converts a ChainRules pullback into a Zygote pullback.
`pb` should be a ChainRules pullback, as returned from the second return value of `rrule`
"""==#
function _pullback_via_chainrules(pb)
# The less optimized version of this code is
# cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...))
# This is the optimized version of
# _pullback_via_chainrules(pb) = (Δs...) -> zextern.(pb(Δs...))
function zback(Δs...)
∂s = pb(Δs...)
ntuple(length(∂s)) do ii
Expand Down

0 comments on commit 70a8aff

Please sign in to comment.