From fd57be1379b8ee48b07a0327d323a6fb858c638f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 12 Jun 2024 21:21:44 -0700 Subject: [PATCH] simplify mixed activity use --- src/rules/jitrules.jl | 204 +++++++----------------------------------- 1 file changed, 33 insertions(+), 171 deletions(-) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index f2f9d27407..af8f83b80e 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,86 +1,3 @@ -function func_mixed_call(N) - allargs = Expr[] - typeargs = Union{Symbol,Expr}[] - exprs2 = Union{Symbol,Expr}[] - for i in 1:N - arg = Symbol("arg_$i") - targ = Symbol("T$i") - e = :($arg::$targ) - push!(allargs, e) - push!(typeargs, targ) - - inarg = quote - if RefTypes[1+$i] - $arg[] - else - $arg - end - end - push!(exprs2, inarg) - end - - quote - @generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)} - fexpr = :f - if RefTypes[1] - fexpr = :(($fexpr)[]) - end - exprs2 = Union{Symbol,Expr}[] - for i in 1:$N - arg = Symbol("arg_$i") - inarg = if RefTypes[1+i] - :($arg[]) - else - :($arg) - end - push!(exprs2, inarg) - end - @static if VERSION ≥ v"1.8-" - return quote - Base.@_inline_meta - @inline $fexpr($(exprs2...)) - end - else - return quote - Base.@_inline_meta - $fexpr($(exprs2...)) - end - end - end - end -end - -@generated function runtime_mixed_call(::Val{RefTypes}, f::F, allargs::Vararg{Any, N}) where {RefTypes, F, N} - fexpr = :f - if RefTypes[1] - fexpr = :(($fexpr)[]) - end - exprs2 = Union{Symbol,Expr}[] - for i in 1:N - inarg = if RefTypes[1+i] - :(allargs[$i][]) - else - :(allargs[$i]) - end - push!(exprs2, inarg) - end - @static if VERSION ≥ v"1.8-" - return quote - Base.@_inline_meta - @inline $fexpr($(exprs2...)) - end - else - return quote - Base.@_inline_meta - $fexpr($(exprs2...)) - end - end -end - -for N in 0:10 - eval(func_mixed_call(N)) -end - function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false; func=true, mixed_or_active = false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] @@ -192,7 +109,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, if $aref == ActiveState Active($(primargs[i])) elseif $aref == MixedState - $((Width == 1) ? :Duplicated : :BatchDuplicated)(Ref($(primargs[i])), $(shadowargs[i])) + $((Width == 1) ? :MixedDuplicated : :BatchMixedDuplicated)($(primargs[i]), $(shadowargs[i])) else $((Width == 1) ? :Duplicated : :BatchDuplicated)($(primargs[i]), $(shadowargs[i])) end @@ -361,45 +278,23 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes, active_refs) false end + tt = Tuple{$(ElTypes...)} + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - internal_tape, origRet, initShadow, annotation = if any_mixed - ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} - rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) - annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - - annotationM = if $Width != 1 && annotation0M <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0M - end - worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) - ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) - - forward, adjoint = thunk(Val(worldM), - Const{typeof(runtime_mixed_call)}, - annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - - forward(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationM - + annotationA = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} else - tt = Tuple{$(ElTypes...)} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - - annotationA = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0 - end - world = codegen_world_age(FT, tt) + annotation0 + end + world = codegen_world_age(FT, tt) - forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, - annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + annotationA, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...)..., annotationA - end + internal_tape, origRet, initShadow = forward(dupClosure0 ? Duplicated(f, df) : Const(f), args...) + annotation = annotationA resT = typeof(origRet) if annotation <: Const @@ -523,64 +418,31 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs, act rt = Core.Compiler.return_type(f, tt) annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - if any_mixed - ttM = Tuple{Val{active_refs}, FT, $(ElTypes...)} - rtM = Core.Compiler.return_type(runtime_mixed_call, ttM) - annotation0M = guess_activity(rtM, API.DEM_ReverseModePrimal) - - annotationM = if $Width != 1 && annotation0M <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0M - end - worldM = codegen_world_age(typeof(runtime_mixed_call), ttM) - ModifiedBetweenM = Val((false, false, element(ModifiedBetween)...)) - - _, adjoint = thunk(Val(worldM), - Const{typeof(runtime_mixed_call)}, - annotationM, Tuple{Const{Val{active_refs}}, dupClosure0 ? Duplicated{FT} : Const{FT}, $(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetweenM, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - - if tape.shadow_return !== nothing - if !(annotation0M <: Active) && nonzero_active_data(($shadowret,)) - ET = ($(ElTypes...),) - throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) - end - end - if annotation0M <: Active - adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape) - else - adjoint(Const(runtime_mixed_call), Const(Val(active_refs)), dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape) - end - nothing + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} else + annotation0 + end - annotation = if $Width != 1 && annotation0 <: Duplicated - BatchDuplicated{rt, $Width} - else - annotation0 - end - - world = codegen_world_age(FT, tt) + world = codegen_world_age(FT, tt) - _, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, - annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + _, adjoint = thunk(Val(world), dupClosure0 ? Duplicated{FT} : Const{FT}, + annotation, Tuple{$(Types...)}, Val(API.DEM_ReverseModePrimal), width, + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - if tape.shadow_return !== nothing - if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) - ET = ($(ElTypes...),) - throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) - end - end - tup = if annotation0 <: Active - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] - else - adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + if tape.shadow_return !== nothing + if !(annotation0 <: Active) && nonzero_active_data(($shadowret,)) + ET = ($(ElTypes...),) + throw(AssertionError("Shadow value "*string(($shadowret,))*" returned from type unstable call to $f($(ET...)) has mixed internal activity types. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information")) end - - $(outs...) end + tup = if annotation0 <: Active + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., $shadowret, tape.internal_tape)[1] + else + adjoint(dupClosure0 ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] + end + + $(outs...) return nothing end end