Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 24, 2023
1 parent 08becad commit 60b5d33
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
29 changes: 13 additions & 16 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where T = guess_activity(T,

@inline function Enzyme.guess_activity(::Type{T}, Mode::API.CDerivativeMode) where {T}
ActReg = active_reg_nothrow(T)
if c == AnyState
if ActReg == AnyState
return Const{T}
end
if Mode == API.DEM_ForwardMode
Expand Down Expand Up @@ -306,22 +306,19 @@ end
return DupState
end

const TypeActivityCache = Dict{DataType, ActivityState}()
const ConstantTypes = Type[]

@inline function active_reg_inner(::Type{T}, seen) where T
if T keys(seen)
return seen[T]
end
if EnzymeRules.inactive_type(T)
return seen[T] = AnyState
end
if T isa UnionAll
return AnyState
end
if isghostty(T) || Core.Compiler.isconstType(T)
return AnyState
end
if T isa UnionAll
return DupState
end
# if abstract it must be by reference
if Base.isabstracttype(T)
return DupState
Expand Down Expand Up @@ -362,13 +359,19 @@ const ConstantTypes = Type[]
return ty
end

@inline @generated function active_reg(::Type{T}) where {T}
state = active_reg_inner(T, TypeActivityCache)
@inline @generated function active_reg_nothrow(::Type{T}) where {T}
seen = Dict{DataType, ActivityState}()
return active_reg_inner(T, seen)
end

@inline function active_reg(::Type{T}) where {T}
state = active_reg_nothrow(T)
str = string(T)*" has mixed internal activity types"
@assert state != MixedState str
return state == ActiveState
end


# User facing interface
abstract type AbstractThunk{FA, RT, TT, Width} end

Expand Down Expand Up @@ -9815,7 +9818,7 @@ end
end

if !(A <: Const) && (active_reg_nothrow(rrt) == AnyState)
error("Return type `$rrt` not marked Const, but is ghost or const type.")
error("Return type `$rrt` not marked Const, but type is guaranteed to be constant")
end

if A isa UnionAll
Expand All @@ -9827,12 +9830,6 @@ end
rt = A
end

if rrt == Nothing && !(A <: Const)
error("Return of nothing must be marked Const")
end

# @assert isa(rrt, DataType)

# We need to use primal as the key, to lookup the right method
# but need to mixin the hash of the adjoint to avoid cache collisions
# This is counter-intuitive since we would expect the cache to be split
Expand Down
11 changes: 6 additions & 5 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ function EnzymeRules.inactive_noinl(::typeof(Base.size), args...)
return nothing
end

EnzymeRules.inactive_type(::Type{Nothing}) = true
EnzymeRules.inactive_type(::Type{Union{}}) = true
EnzymeRules.inactive_type(::Type{T}, seen) where {T<:Integer} = true
EnzymeRules.inactive_type(::Type{T}, seen) where {T<:Function} = true
EnzymeRules.inactive_type(::Type{T}, seen) where {T<:DataType} = true
EnzymeRules.inactive_type(::Type{T}, seen) where {T<:Module} = true
EnzymeRules.inactive_type(::Type{T}, seen) where {T<:AbstractString} = true
EnzymeRules.inactive_type(::Type{T}) where {T<:Integer} = true
EnzymeRules.inactive_type(::Type{T}) where {T<:Function} = true
EnzymeRules.inactive_type(::Type{T}) where {T<:DataType} = true
EnzymeRules.inactive_type(::Type{T}) where {T<:Module} = true
EnzymeRules.inactive_type(::Type{T}) where {T<:AbstractString} = true

# Note all of these forward mode definitions do not support runtime activity as
# the do not keep the primal if shadow(x.y) == primal(x.y)
Expand Down

0 comments on commit 60b5d33

Please sign in to comment.