diff --git a/src/sugar.jl b/src/sugar.jl index 3e68830100..b93b7fb0eb 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -254,19 +254,15 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) ``` """ +# TODO eventually add an invalidation edge here from inactive_type @generated function gradient( rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::ty_0, args::Vararg{Any,N}, ) where {F,ty_0,ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten,N} - toemit = Expr[quote - act_0 = - !(x isa Enzyme.Const) && - Compiler.active_reg_inner(Core.Typeof(x), (), nothing, Val(true)) == - Compiler.ActiveState #=justActive=# - end] rargs = Union{Symbol,Expr}[:x] + gentys = Type[x] acts = Symbol[Symbol("act_0")] for i = 1:N @@ -276,55 +272,69 @@ grad = gradient(ReverseWithPrimal, mul, [2.0], Const([3.0])) push!(rargs, argidx) sym = Symbol("act_$i") push!(acts, sym) - push!( - toemit, - quote - $sym = - !($argidx isa Enzyme.Const) && - Compiler.active_reg_inner( - Core.Typeof($argidx), - (), - nothing, - Val(true), - ) == Compiler.ActiveState #=justActive=# - end, - ) + push!(gentys, args[i]) + end + + toemit = Expr[] + states = Compiler.ActivityState[] + + for (argidx, act, genty) in zip(rargs, acts, gentys) + if genty <: Enzyme.Const + push!( + toemit, + quote + $act = false + end + ) + push!(states, Compiler.AnyState) + else + state = Compiler.active_reg_inner(genty, (), nothing) + push!(states, state) + end end idx = 0 - shadows = Symbol[] - enz_args = Expr[] - resargs = Expr[] - for (arg, act) in zip(rargs, acts) - shad = Symbol("shad_$idx") - push!(shadows, shad) - push!(toemit, quote - $shad = if $arg isa Enzyme.Const - nothing - elseif $act - Ref(make_zero($arg)) - else - make_zero($arg) - end - end) - push!(enz_args, quote - if $arg isa Enzyme.Const - $arg - elseif $act + enz_args = Union{Expr,Symbol}[] + resargs = Union{Expr,Symbol}[] + for (i, (arg, act, state, genty)) in enumerate(zip(rargs, acts, states, gentys)) + shad = Symbol("shad_$i") + if genty <: Enzyme.Const + push!(enz_args, arg) + push!(resargs, :nothing) + elseif state == Compiler.MixedState + push!(toemit, quote + $shad = Ref(make_zero($arg)) + end) + push!(enz_args, quote MixedDuplicated($arg, $shad) - else - Duplicated($arg, $shad) - end - end) - push!(resargs, quote - if $arg isa Enzyme.Const - nothing - elseif $act + end) + push!(resargs, quote $shad[] - else + end) + elseif state == Compiler.DupState + push!(toemit, quote + $shad = make_zero($arg) + end) + push!(enz_args, quote + Duplicated($arg, $shad) + end) + push!(resargs, quote $shad - end - end) + end) + elseif state == Compiler.ActiveState + push!(enz_args, quote + Active($arg) + end) + push!(resargs, quote + res[1][$i] + end) + else + @assert state == Compiler.AnyState + push!(enz_args, quote + Const($arg) + end) + push!(resargs, :nothing) + end idx += 1 end push!(toemit, quote