Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix methodinstance usage and backedges #2199

Merged
merged 17 commits into from
Dec 15, 2024
24 changes: 12 additions & 12 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
FTy = Core.Typeof(f.val)

rt = if A isa UnionAll
Compiler.primal_return_type(mode, FTy, tt)
Compiler.primal_return_type(Reverse, FTy, tt)
else
eltype(A)
end
Expand Down Expand Up @@ -410,7 +410,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
end

opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Reverse, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -536,7 +536,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value.
) where {FA<:Annotation,CMode<:Mode,Nargs}
tt = vaEltypeof(args...)
rt = Compiler.primal_return_type(
mode,
mode isa ForwardMode ? Forward : Reverse,
eltype(FA),
tt,
)
Expand Down Expand Up @@ -632,7 +632,7 @@ f(x) = x*x
tt = vaEltypeof(args...)

opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Forward, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -687,7 +687,7 @@ code, as well as high-order differentiation.
A2 = A

if A isa UnionAll
rt = Compiler.primal_return_type(mode, FTy, tt)
rt = Compiler.primal_return_type(Reverse, FTy, tt)
A2 = A{rt}
if rt == Union{}
rt = Nothing
Expand Down Expand Up @@ -840,7 +840,7 @@ code, as well as high-order differentiation.
FT = Core.Typeof(f.val)

if RT isa UnionAll
rt = Compiler.primal_return_type(mode, FT, tt)
rt = Compiler.primal_return_type(Forward, FT, tt)
if rt == Union{}
rt = Nothing
end
Expand Down Expand Up @@ -968,7 +968,7 @@ result, ∂v, ∂A

tt′ = Tuple{args...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Reverse, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -1098,7 +1098,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float

tt′ = Tuple{args...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Forward, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -1166,7 +1166,7 @@ end

primal_tt = Tuple{map(eltype, args)...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), TT)
my_methodinstance(Forward, eltype(FA), primal_tt)
else
Val(0)
end
Expand Down Expand Up @@ -1196,7 +1196,7 @@ const tape_cache = Dict{UInt,Type}()

const tape_cache_lock = ReentrantLock()

import .Compiler: fspec, remove_innerty, UnknownTapeType
import .Compiler: remove_innerty, UnknownTapeType

@inline function tape_type(
parent_job::Union{GPUCompiler.CompilerJob,Nothing},
Expand Down Expand Up @@ -1246,7 +1246,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType

primal_tt = Tuple{map(eltype, args)...}

mi = Compiler.fspec(eltype(FA), TT)
mi = my_methodinstance(parent_job === nothing ? Reverse : GPUCompiler.get_interpreter(parent_job), eltype(FA), primal_tt)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(
Expand Down Expand Up @@ -1382,7 +1382,7 @@ result, ∂v, ∂A
TT = Tuple{args...}

primal_tt = Tuple{map(eltype, args)...}
rt0 = Compiler.primal_return_type(mode, eltype(FA), primal_tt)
rt0 = Compiler.primal_return_type(Reverse, eltype(FA), primal_tt)

rt = Compiler.remove_innerty(A2){rt0}

Expand Down
1 change: 1 addition & 0 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ end
EnzymeCore.EnzymeRules.inactive_type(T)
else
inmi = my_methodinstance(
nothing,
typeof(EnzymeCore.EnzymeRules.inactive_type),
Tuple{Type{T}},
world,
Expand Down
Loading
Loading