Skip to content

Commit

Permalink
Further simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent 60b5d33 commit 25b52d2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 60 deletions.
94 changes: 36 additions & 58 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,25 +262,6 @@ const activefns = Set{String}((
MixedState = 3
end

Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where T = guess_activity(T, convert(API.CDerivativeMode, mode))

@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
ActReg = active_reg_nothrow(T)
if ActReg == AnyState
return Const{T}
end
if Mode == API.DEM_ForwardMode
return DuplicatedNoNeed{T}
else
if ActReg == ActiveState
return Active{T}
else
return Duplicated{T}
end
end
end


@inline function Base.:|(a1::ActivityState, a2::ActivityState)
ActivityState(Int(a1) | Int(a2))
end
Expand Down Expand Up @@ -360,8 +341,7 @@ end
end

@inline @generated function active_reg_nothrow(::Type{T}) where {T}
seen = Dict{DataType, ActivityState}()
return active_reg_inner(T, seen)
return active_reg_inner(T, IdDict())
end

@inline function active_reg(::Type{T}) where {T}
Expand All @@ -371,6 +351,25 @@ end
return state == ActiveState
end

@inline guaranteed_const(::Type{T}) where T = active_reg_nothrow(T) == AnyState

Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where T = guess_activity(T, convert(API.CDerivativeMode, mode))

@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
ActReg = active_reg_nothrow(T)
if ActReg == AnyState
return Const{T}
end
if Mode == API.DEM_ForwardMode
return DuplicatedNoNeed{T}
else
if ActReg == ActiveState
return Active{T}
else
return Duplicated{T}
end
end
end

# User facing interface
abstract type AbstractThunk{FA, RT, TT, Width} end
Expand Down Expand Up @@ -1102,8 +1101,7 @@ end

function runtime_newtask_fwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}) where {FT1, FT2, World, width}
FT = Core.Typeof(fn)
ghos = active_reg_nothrow(FT) == AnyState
forward = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
forward = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ForwardMode), Val(width), Val((false,)), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
ft = ghos ? Const(fn) : Duplicated(fn, dfn)
function fclosure()
res = forward(ft)
Expand All @@ -1116,8 +1114,7 @@ end
function runtime_newtask_augfwd(world::Val{World}, fn::FT1, dfn::FT2, post::Any, ssize::Int, ::Val{width}, ::Val{ModifiedBetween}) where {FT1, FT2, World, width, ModifiedBetween}
# TODO make this AD subcall type stable
FT = Core.Typeof(fn)
ghos = active_reg_nothrow(FT) == AnyState
forward, adjoint = thunk(world, (ghos ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
forward, adjoint = thunk(world, (guaranteed_const(FT) ? Const : Duplicated){FT}, Const, Tuple{}, Val(API.DEM_ReverseModePrimal), Val(width), Val(ModifiedBetween), #=returnPrimal=#Val(true), #=shadowinit=#Val(false), FFIABI)
ft = ghos ? Const(fn) : Duplicated(fn, dfn)
taperef = Ref{Any}()

Expand Down Expand Up @@ -1218,7 +1215,7 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing)
wrapped = Expr[]
for i in 1:N
expr = :(
if ActivityTup[$i+1] && active_reg_nothrow($(primtypes[i])) != AnyState
if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i]))
@assert $(primtypes[i]) !== DataType
if !$forwardMode && active_reg($(primtypes[i]))
Active($(primargs[i]))
Expand Down Expand Up @@ -1262,7 +1259,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes)

dupClosure = ActivityTup[1]
FT = Core.Typeof(f)
if dupClosure && active_reg_nothrow(FT) == AnyState
if dupClosure && guaranteed_const(FT)
dupClosure = false
end

Expand Down Expand Up @@ -1323,7 +1320,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes)

dupClosure = ActivityTup[1]
FT = Core.Typeof(f)
if dupClosure && active_reg_nothrow(rrt) == AnyState
if dupClosure && guaranteed_const(FT)
dupClosure = false
end

Expand Down Expand Up @@ -1430,7 +1427,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes)

dupClosure = ActivityTup[1]
FT = Core.Typeof(f)
if dupClosure && active_reg_nothrow(FT) == AnyState
if dupClosure && guaranteed_const(FT)
dupClosure = false
end
world = codegen_world_age(FT, tt)
Expand Down Expand Up @@ -2072,7 +2069,7 @@ end
end

@inline function make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT}
if active_reg_nothrow(RT) == AnyState
if guaranteed_const(RT)
return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev
end
if haskey(seen, prev)
Expand Down Expand Up @@ -3258,7 +3255,7 @@ end
width = get_width(gutils)

ops = collect(operands(orig))[1:end-1]
dupClosure = !(active_reg_nothrow(funcT) == AnyState) && !is_constant_value(gutils, ops[1])
dupClosure = !guaranteed_const(funcT) && !is_constant_value(gutils, ops[1])
pdupClosure = dupClosure

subfunc = nothing
Expand All @@ -3280,28 +3277,9 @@ end
subfunc = functions(mod)[fwdmodenm]

elseif mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient

# TODO can optimize to only do if could contain a float
if dupClosure
has_active = false
todo = Type[funcT]
while length(todo) != 0
T = pop!(todo)
if !allocatedinline(T)
continue
end
if fieldcount(T) == 0
if T <: Integer
continue
end
has_active = true
break
end
for f in 1:fieldcount(T)
push!(todo, fieldtype(T, f))
end
end

ty = active_reg_nothrow(funcTy)
has_active = ty == MixedState || ty == ActiveState
if has_active
refed = true
e_tt = Tuple{Duplicated{Base.RefValue{funcT}}, e_tt.parameters...}
Expand Down Expand Up @@ -3897,7 +3875,7 @@ function enzyme_custom_setup_args(B, orig, gutils, mi, RT, reverse, isKWCall)
for arg in jlargs
@assert arg.cc != RemovedParam
if arg.cc == GPUCompiler.GHOST
@assert active_reg_nothrow(arg.typ) == AnyState
@assert guaranteed_const(arg.typ)
if isKWCall && arg.arg_i == 2
Ty = arg.typ
kwtup = Ty
Expand Down Expand Up @@ -5883,7 +5861,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err
ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce))
typ = Base.unsafe_pointer_to_objref(ptr)
TT = Core.Typeof(typ)
if active_reg_nothrow(RT) == AnyState
if guaranteed_const(TT)
continue
end
badval = string(typ)*" of type"*" "*string(TT)
Expand All @@ -5905,7 +5883,7 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err
ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce)))
typ = Base.unsafe_pointer_to_objref(ptr)
TT = Core.Typeof(typ)
if active_reg_nothrow(RT) == AnyStates
if guaranteed_const(TT)
continue
end
badval = string(typ)*" of type"*" "*string(TT)
Expand Down Expand Up @@ -7444,7 +7422,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr

for (i, T) in enumerate(TT.parameters)
source_typ = eltype(T)
if active_reg_nothrow(source_typ) == AnyState
if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
if !(T <: Const)
error("Type of ghost or constant type "*string(T)*" is marked as differentiable.")
end
Expand Down Expand Up @@ -7696,7 +7674,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
ActiveRetTypes = Type[]
for (i, T) in enumerate(TT.parameters)
source_typ = eltype(T)
if active_reg_nothrow(source_typ) == AnyState
if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
@assert T <: Const
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
Expand Down Expand Up @@ -9817,7 +9795,7 @@ end
error("Function to differentiate `$mi` is guaranteed to return an error and doesn't make sense to autodiff. Giving up")
end

if !(A <: Const) && (active_reg_nothrow(rrt) == AnyState)
if !(A <: Const) && guaranteed_const(rrt)
error("Return type `$rrt` not marked Const, but type is guaranteed to be constant")
end

Expand Down
4 changes: 2 additions & 2 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ end


@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array}
if Enzyme.Compiler.active_reg_nothrow(RT) == AnyState
if Enzyme.Compiler.guaranteed_const(RT)
return (into, from)
end
if !haskey(seen, into)
Expand All @@ -206,7 +206,7 @@ end
end

@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT}
if Enzyme.Compiler.active_reg_nothrow(RT) == AnyState
if Enzyme.Compiler.guaranteed_const(RT)
return (into, from)
end
if !haskey(seen, into)
Expand Down

0 comments on commit 25b52d2

Please sign in to comment.