Skip to content

Commit

Permalink
compact choosing pullback mechanism code
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Nov 21, 2019
1 parent 70a8aff commit 951dac6
Showing 1 changed file with 35 additions and 40 deletions.
75 changes: 35 additions & 40 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@ using IRTools.Inner: argnames!, update!

ignore(T) = all(T -> T <: Type, T.parameters)


function _pullback(ctx::AContext, f, args...)
if chainrules_blacklist(f, args...)
# then don't even consider using ChainRules
return _pullback_via_source2source(ctx, f, args...)
end

res = ChainRules.rrule(f, args...)
if res === nothing
# No ChainRule defined, time to do the source tranform
if chainrules_blacklist(f, args...) || (res = ChainRules.rrule(f, args...)) === nothing
# Blacklisted or no ChainRule defined, time to do the source tranform
return _pullback_via_source2source(ctx, f, args...)
else
# Can just use ChainRule answer
Expand All @@ -21,6 +14,39 @@ function _pullback(ctx::AContext, f, args...)
end
end

@generated function _pullback_via_source2source(ctx::AContext, f, args...)
T = Tuple{f,args...}
ignore(T) && return :(f(args...), Pullback{$T}(()))
g = try _lookup_grad(T) catch e e end
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
forw = varargs!(meta, forw, 3)
# IRTools.verify(forw)
forw = slots!(pis!(inlineable!(forw)))
return update!(meta.code, forw)
end

@generated function (j::Pullback{T})(Δ) where T
ignore(T) && return :nothing
g = try _lookup_grad(T)
catch e
rethrow(CompileError(T,e))
end
if g == nothing
Δ == Nothing && return :nothing
return :(error("Non-differentiable function $(repr(j.t[1]))"))
end
meta, _, back = g
argnames!(meta, Symbol("#self#"), )
# IRTools.verify(back)
back = slots!(inlineable!(back))
return update!(meta.code, back)
end




#=="""
chainrules_blacklist(f, args...,)
Expand Down Expand Up @@ -73,34 +99,3 @@ end
zextern(x) = ChainRules.extern(x)
zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing
zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing


@generated function _pullback_via_source2source(ctx::AContext, f, args...)
T = Tuple{f,args...}
ignore(T) && return :(f(args...), Pullback{$T}(()))
g = try _lookup_grad(T) catch e e end
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
forw = varargs!(meta, forw, 3)
# IRTools.verify(forw)
forw = slots!(pis!(inlineable!(forw)))
return update!(meta.code, forw)
end

@generated function (j::Pullback{T})(Δ) where T
ignore(T) && return :nothing
g = try _lookup_grad(T)
catch e
rethrow(CompileError(T,e))
end
if g == nothing
Δ == Nothing && return :nothing
return :(error("Non-differentiable function $(repr(j.t[1]))"))
end
meta, _, back = g
argnames!(meta, Symbol("#self#"), )
# IRTools.verify(back)
back = slots!(inlineable!(back))
return update!(meta.code, back)
end

0 comments on commit 951dac6

Please sign in to comment.