From f5e771eb1acf566251b9588f04baecb65fd6d90a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 23 Sep 2023 19:20:42 -0500 Subject: [PATCH] fixup --- src/compiler.jl | 27 ++++++++++++--------------- src/internal_rules.jl | 1 + 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0ec85223f1..d6dbe5a23c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -306,9 +306,6 @@ end return DupState end -const TypeActivityCache = Dict{DataType, ActivityState}() -const ConstantTypes = Type[] - @inline function active_reg_inner(::Type{T}, seen) where T if T ∈ keys(seen) return seen[T] @@ -316,12 +313,12 @@ const ConstantTypes = Type[] if EnzymeRules.inactive_type(T) return seen[T] = AnyState end - if T isa UnionAll - return AnyState - end if isghostty(T) || Core.Compiler.isconstType(T) return AnyState end + if T isa UnionAll + return DupState + end # if abstract it must be by reference if Base.isabstracttype(T) return DupState @@ -362,13 +359,19 @@ const ConstantTypes = Type[] return ty end -@inline @generated function active_reg(::Type{T}) where {T} - state = active_reg_inner(T, TypeActivityCache) +@inline @generated function active_reg_nothrow(::Type{T}) where {T} + seen = Dict{DataType, ActivityState}() + return active_reg_inner(T, seen) +end + +@inline function active_reg(::Type{T}) where {T} + state = active_reg_nothrow(T) str = string(T)*" has mixed internal activity types" @assert state != MixedState str return state == ActiveState end + # User facing interface abstract type AbstractThunk{FA, RT, TT, Width} end @@ -9815,7 +9818,7 @@ end end if !(A <: Const) && (active_reg_nothrow(rrt) == AnyState) - error("Return type `$rrt` not marked Const, but is ghost or const type.") + error("Return type `$rrt` not marked Const, but type is guaranteed to be constant") end if A isa UnionAll @@ -9827,12 +9830,6 @@ end rt = A end - if rrt == Nothing && !(A <: Const) - error("Return of nothing must be marked Const") - end - - # @assert isa(rrt, DataType) - # We need to use primal as the key, to lookup the right method # but need to mixin the hash of the adjoint to avoid cache collisions # This is counter-intuitive since we would expect the cache to be split diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 1cb1a451bf..dbc1b5a9da 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -100,6 +100,7 @@ function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end +EnzymeRules.inactive_type(::Type{Nothing}) = true EnzymeRules.inactive_type(::Type{Union{}}) = true EnzymeRules.inactive_type(::Type{T}, seen) where {T<:Integer} = true EnzymeRules.inactive_type(::Type{T}, seen) where {T<:Function} = true