diff --git a/src/compiler.jl b/src/compiler.jl index ec04a35e36..ddeb36b750 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -262,25 +262,6 @@ const activefns = Set{String}(( MixedState = 3 end -Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where T = guess_activity(T, convert(API.CDerivativeMode, mode)) - -@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} - ActReg = active_reg_nothrow(T) - if ActReg == AnyState - return Const{T} - end - if Mode == API.DEM_ForwardMode - return DuplicatedNoNeed{T} - else - if ActReg == ActiveState - return Active{T} - else - return Duplicated{T} - end - end -end - - @inline function Base.:|(a1::ActivityState, a2::ActivityState) ActivityState(Int(a1) | Int(a2)) end @@ -360,8 +341,7 @@ end end @inline @generated function active_reg_nothrow(::Type{T}) where {T} - seen = Dict{DataType, ActivityState}() - return active_reg_inner(T, seen) + return active_reg_inner(T, IdDict()) end @inline function active_reg(::Type{T}) where {T} @@ -371,6 +351,25 @@ end return state == ActiveState end +@inline guaranteed_const(::Type{T}) where T = active_reg_nothrow(T) == AnyState + +Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where T = guess_activity(T, convert(API.CDerivativeMode, mode)) + +@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T} + ActReg = active_reg_nothrow(T) + if ActReg == AnyState + return Const{T} + end + if Mode == API.DEM_ForwardMode + return DuplicatedNoNeed{T} + else + if ActReg == ActiveState + return Active{T} + else + return Duplicated{T} + end + end +end # User facing interface abstract type AbstractThunk{FA, RT, TT, Width} end @@ -1102,8 +1101,7 @@ end function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width} FT = Core.Typeof(fn) - ghos = active_reg_nothrow(FT) == AnyState - forward = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + forward = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) ft = ghos ? Const(fn) : Duplicated(fn, dfn) function fclosure() res = forward(ft) @@ -1116,8 +1114,7 @@ end function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween} # TODO make this AD subcall type stable FT = Core.Typeof(fn) - ghos = active_reg_nothrow(FT) == AnyState - forward, adjoint = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) + forward, adjoint = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI) ft = ghos ? Const(fn) : Duplicated(fn, dfn) taperef = Ref{Any}() @@ -1218,7 +1215,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) wrapped = Expr[] for i in 1:N expr = :( - if ActivityTup[$i+1] && active_reg_nothrow($(primtypes[i])) != AnyState + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) @assert $(primtypes[i]) !== DataType if !$forwardMode && active_reg($(primtypes[i])) Active($(primargs[i])) @@ -1262,7 +1259,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) dupClosure = ActivityTup[1] FT = Core.Typeof(f) - if dupClosure && active_reg_nothrow(FT) == AnyState + if dupClosure && guaranteed_const(FT) dupClosure = false end @@ -1323,7 +1320,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) dupClosure = ActivityTup[1] FT = Core.Typeof(f) - if dupClosure && active_reg_nothrow(rrt) == AnyState + if dupClosure && guaranteed_const(FT) dupClosure = false end @@ -1430,7 +1427,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes) dupClosure = ActivityTup[1] FT = Core.Typeof(f) - if dupClosure && active_reg_nothrow(FT) == AnyState + if dupClosure && guaranteed_const(FT) dupClosure = false end world = codegen_world_age(FT, tt) @@ -2072,7 +2069,7 @@ end end @inline function make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT} - if active_reg_nothrow(RT) == AnyState + if guaranteed_const(RT) return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev end if haskey(seen, prev) @@ -3258,7 +3255,7 @@ end width = get_width(gutils) ops = collect(operands(orig))[1:end-1] - dupClosure = !(active_reg_nothrow(funcT) == AnyState) && !is_constant_value(gutils, ops[1]) + dupClosure = !guaranteed_const(funcT) && !is_constant_value(gutils, ops[1]) pdupClosure = dupClosure subfunc = nothing @@ -3280,28 +3277,9 @@ end subfunc = functions(mod)[fwdmodenm] elseif mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient - - # TODO can optimize to only do if could contain a float if dupClosure - has_active = false - todo = Type[funcT] - while length(todo) != 0 - T = pop!(todo) - if !allocatedinline(T) - continue - end - if fieldcount(T) == 0 - if T <: Integer - continue - end - has_active = true - break - end - for f in 1:fieldcount(T) - push!(todo, fieldtype(T, f)) - end - end - + ty = active_reg_nothrow(funcTy) + has_active = ty == MixedState || ty == ActiveState if has_active refed = true e_tt = Tuple{Duplicated{Base.RefValue{funcT}}, e_tt.parameters...} @@ -3897,7 +3875,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall) for arg in jlargs @assert arg.cc != RemovedParam if arg.cc == GPUCompiler.GHOST - @assert active_reg_nothrow(arg.typ) == AnyState + @assert guaranteed_const(arg.typ) if isKWCall && arg.arg_i == 2 Ty = arg.typ kwtup = Ty @@ -5883,7 +5861,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce)) typ = Base.unsafe_pointer_to_objref(ptr) TT = Core.Typeof(typ) - if active_reg_nothrow(RT) == AnyState + if guaranteed_const(TT) continue end badval = string(typ)*" of type"*" "*string(TT) @@ -5905,7 +5883,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) typ = Base.unsafe_pointer_to_objref(ptr) TT = Core.Typeof(typ) - if active_reg_nothrow(RT) == AnyStates + if guaranteed_const(TT) continue end badval = string(typ)*" of type"*" "*string(TT) @@ -7444,7 +7422,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr for (i, T) in enumerate(TT.parameters) source_typ = eltype(T) - if active_reg_nothrow(source_typ) == AnyState + if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) if !(T <: Const) error("Type of ghost or constant type "*string(T)*" is marked as differentiable.") end @@ -7696,7 +7674,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, ActiveRetTypes = Type[] for (i, T) in enumerate(TT.parameters) source_typ = eltype(T) - if active_reg_nothrow(source_typ) == AnyState + if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) @assert T <: Const if is_adjoint && i != 1 push!(ActiveRetTypes, Nothing) @@ -9817,7 +9795,7 @@ end error("Function to differentiate `$mi` is guaranteed to return an error and doesn't make sense to autodiff. Giving up") end - if !(A <: Const) && (active_reg_nothrow(rrt) == AnyState) + if !(A <: Const) && guaranteed_const(rrt) error("Return type `$rrt` not marked Const, but type is guaranteed to be constant") end diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 22af1c06ed..eb2f720d70 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -184,7 +184,7 @@ end @inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} - if Enzyme.Compiler.active_reg_nothrow(RT) == AnyState + if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end if !haskey(seen, into) @@ -206,7 +206,7 @@ end end @inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT} - if Enzyme.Compiler.active_reg_nothrow(RT) == AnyState + if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end if !haskey(seen, into)