Skip to content

Commit

Permalink
simplify mixed activity use
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 13, 2024
1 parent 8f43fa3 commit dc7492f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 173 deletions.
4 changes: 2 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4479,7 +4479,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
elseif arg.cc != GPUCompiler.BITS_REF
if TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
push!(boxedArgs, arg.arg_i)
push!(raisedArgs, arg.arg_i)
push!(wrapper_types, LLVM.PointerType(typ, Derived))
Expand All @@ -4490,7 +4490,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
end
else
# bits ref, and not boxed
if TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated
if TT != nothing && (TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
Expand Down
204 changes: 33 additions & 171 deletions src/rules/jitrules.jl
Original file line number Diff line number Diff line change
@@ -1,86 +1,3 @@
function func_mixed_call(N)
allargs = Expr[]
typeargs = Union{Symbol,Expr}[]
exprs2 = Union{Symbol,Expr}[]
for i in 1:N
arg = Symbol("arg_$i")
targ = Symbol("T$i")
e = :($arg::$targ)
push!(allargs, e)
push!(typeargs, targ)

inarg = quote
if RefTypes[1+$i]
$arg[]
else
$arg
end
end
push!(exprs2, inarg)
end

quote
@generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)}
fexpr = :f
if RefTypes[1]
fexpr = :(($fexpr)[])
end
exprs2 = Union{Symbol,Expr}[]
for i in 1:$N
arg = Symbol("arg_$i")
inarg = if RefTypes[1+i]
:($arg[])
else
:($arg)
end
push!(exprs2, inarg)
end
@static if VERSION v"1.8-"
return quote
Base.@_inline_meta
@inline $fexpr($(exprs2...))
end
else
return quote
Base.@_inline_meta
$fexpr($(exprs2...))
end
end
end
end
end

@generated function runtime_mixed_call(::Val{RefTypes}, f::F, allargs::Vararg{Any, N}) where {RefTypes, F, N}
fexpr = :f
if RefTypes[1]
fexpr = :(($fexpr)[])
end
exprs2 = Union{Symbol,Expr}[]
for i in 1:N
inarg = if RefTypes[1+i]
:(allargs[$i][])
else
:(allargs[$i])
end
push!(exprs2, inarg)
end
@static if VERSION v"1.8-"
return quote
Base.@_inline_meta
@inline $fexpr($(exprs2...))
end
else
return quote
Base.@_inline_meta
$fexpr($(exprs2...))
end
end
end

for N in 0:10
eval(func_mixed_call(N))
end

function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false)
primargs = Union{Symbol,Expr}[]
shadowargs = Union{Symbol,Expr}[]
Expand Down Expand Up @@ -192,7 +109,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing,
if $aref == ActiveState
Active($(primargs[i]))
elseif $aref == MixedState
$((Width == 1) ? :Duplicated : :BatchDuplicated)(Ref($(primargs[i])), $(shadowargs[i]))
$((Width == 1) ? :MixedDuplicated : :BatchMixedDuplicated)($(primargs[i]), $(shadowargs[i]))
else
$((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i]))
end
Expand Down Expand Up @@ -361,45 +278,23 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs)
false
end

tt = Tuple{$(ElTypes...)}
rt = Core.Compiler.return_type(f, tt)
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

internal_tape, origRet, initShadow, annotation = if any_mixed
ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)}
rtM = Core.Compiler.return_type(runtime_mixed_call, ttM)
annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal)

annotationM = if $Width != 1 && annotation0M <: Duplicated
BatchDuplicated{rt, $Width}
else
annotation0M
end
worldM = codegen_world_age(typeof(runtime_mixed_call), ttM)
ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...))

forward, adjoint = thunk(Val(worldM),
Const{typeof(runtime_mixed_call)},
annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)

forward(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationM

annotationA = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt, $Width}
else
tt = Tuple{$(ElTypes...)}
rt = Core.Compiler.return_type(f, tt)
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

annotationA = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt, $Width}
else
annotation0
end
world = codegen_world_age(FT, tt)
annotation0
end
world = codegen_world_age(FT, tt)

forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT},
annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)
forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT},
annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)

forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationA
end
internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...)
annotation = annotationA

resT = typeof(origRet)
if annotation <: Const
Expand Down Expand Up @@ -523,64 +418,31 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act
rt = Core.Compiler.return_type(f, tt)
annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal)

if any_mixed
ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)}
rtM = Core.Compiler.return_type(runtime_mixed_call, ttM)
annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal)

annotationM = if $Width != 1 && annotation0M <: Duplicated
BatchDuplicated{rt, $Width}
else
annotation0M
end
worldM = codegen_world_age(typeof(runtime_mixed_call), ttM)
ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...))

_, adjoint = thunk(Val(worldM),
Const{typeof(runtime_mixed_call)},
annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)

if tape.shadow_return !== nothing
if !(annotation0M <: Active) && nonzero_active_data(($shadowret,))
ET = ($(ElTypes...),)
throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information"))
end
end
if annotation0M <: Active
adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)
else
adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)
end
nothing
annotation = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt, $Width}
else
annotation0
end

annotation = if $Width != 1 && annotation0 <: Duplicated
BatchDuplicated{rt, $Width}
else
annotation0
end

world = codegen_world_age(FT, tt)
world = codegen_world_age(FT, tt)

_, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT},
annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)
_, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT},
annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width,
ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI)

if tape.shadow_return !== nothing
if !(annotation0 <: Active) && nonzero_active_data(($shadowret,))
ET = ($(ElTypes...),)
throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information"))
end
end
tup = if annotation0 <: Active
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1]
else
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1]
if tape.shadow_return !== nothing
if !(annotation0 <: Active) && nonzero_active_data(($shadowret,))
ET = ($(ElTypes...),)
throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information"))
end

$(outs...)
end
tup = if annotation0 <: Active
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1]
else
adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1]
end

$(outs...)
return nothing
end
end
Expand Down

0 comments on commit dc7492f

Please sign in to comment.