Skip to content

Commit

Permalink
Abstract is mixed
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 14, 2024
1 parent a889bb6 commit 99906ff
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 21 deletions.
65 changes: 48 additions & 17 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,13 @@ end
ActivityState(Int(a1) | Int(a2))
end

struct Merger{seen,worldT,justActive,UnionSret}
struct Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed}
world::worldT
end

@inline element(::Val{T}) where T = T

@inline function (c::Merger{seen,worldT,justActive,UnionSret})(f::Int) where {seen,worldT,justActive,UnionSret}
@inline function (c::Merger{seen,worldT,justActive,UnionSret,AbstractIsMixed})(f::Int) where {seen,worldT,justActive,UnionSret,AbstractIsMixed}
T = element(first(seen))

reftype = ismutabletype(T) || T isa UnionAll
Expand All @@ -273,7 +273,7 @@ end
return Val(AnyState)
end

sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret))
sub = active_reg_inner(subT, seen, c.world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))

if sub == AnyState
Val(AnyState)
Expand Down Expand Up @@ -372,24 +372,31 @@ end
end)
end

@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}) where {ST, Seen, justActive, UnionSret}
@inline function active_reg_recur(::Type{ST}, seen::Seen, world, ::Val{justActive}, ::Val{UnionSret}, ::Val{AbstractIsMixed}) where {ST, Seen, justActive, UnionSret, AbstractIsMixed}
if ST isa Union
return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret))))
return forcefold(Val(active_reg_recur(ST.a, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))), Val(active_reg_recur(ST.b, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))))
end
return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret))
return active_reg_inner(ST, seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
end

@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret}
@inline is_vararg_tup(x) = false
@inline is_vararg_tup(::Type{Tuple{Vararg{T2}}}) where T2 = true

@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false), ::Val{AbstractIsMixed}=Val(false))::ActivityState where {ST,T, justActive, UnionSret, AbstractIsMixed}
if T === Any
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end

if T === Union{}
return AnyState
end

if T <: Complex && !(T isa UnionAll)
return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret))
return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed))
end

if T <: AbstractFloat
Expand All @@ -401,10 +408,14 @@ end
return AnyState
end

if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret)) == AnyState
if is_arrayorvararg_ty(T) && active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret), Val(AbstractIsMixed)) == AnyState
return AnyState
else
return DupState
if AbstractIsMixed && is_vararg_tup(T)
return MixedState
else
return DupState
end
end
end

Expand Down Expand Up @@ -434,10 +445,18 @@ end
if T isa UnionAll
aT = Base.argument_datatype(T)
if aT === nothing
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
if datatype_fieldcount(aT) === nothing
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
end

Expand All @@ -451,18 +470,30 @@ end
return AnyState
end
if active_reg_inner(T.a, seen, world, Val(justActive), Val(UnionSret)) != AnyState
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
if active_reg_inner(T.b, seen, world, Val(justActive), Val(UnionSret)) != AnyState
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end
end
return AnyState
end

# if abstract it must be by reference
if Base.isabstracttype(T)
return DupState
if AbstractIsMixed
return MixedState
else
return DupState
end
end

if ismutabletype(T)
Expand Down Expand Up @@ -504,7 +535,7 @@ end

seen2 = (Val(nT), seen...)

fty = Merger{seen2,typeof(world),justActive, UnionSret}(world)
fty = Merger{seen2,typeof(world),justActive, UnionSret, AbstractIsMixed}(world)

ty = forcefold(Val(AnyState), ntuple(fty, Val(fieldcount(nT)))...)

Expand Down
6 changes: 3 additions & 3 deletions src/rules/typeunstablerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function body_construct_augfwd(N, Width, primtypes, active_refs, primargs, batch
shadow_rets_i = Expr[]
aref = Symbol("active_ref_$i")
for w in 1:Width
sref = Symbol("shadow_"*string(i)*"_"*string(w))
sref = Symbol("sub_shadow_"*string(i)*"_"*string(w))
push!(shadow_rets_i, quote
$sref = if $aref == AnyState
$(primargs[i]);
Expand Down Expand Up @@ -248,10 +248,10 @@ function newstruct_common(fwd, run, offset, B, orig, gutils, normalR, shadowR)
# if any active [e.g. ActiveState / MixedState] data could exist
# err
if !fwd
if !found
if !found_partial
return false
end
act = active_reg_inner(typ, (), world)
act = active_reg_inner(typ_partial, (), world, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true))
if act == MixedState || act == ActiveState
return false
end
Expand Down
25 changes: 24 additions & 1 deletion test/mixedapplyiter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,27 @@ end
@test out[] 5562.9996
@test tupapprox(dx, [[(4.0, [5.4]), (6.0, [6.28])], [(15.8, [94.0]), (22.4, [112.0])]])
@test tupapprox(dx2, [[(3*4.0, [3*5.4]), (3*6.0, [3*6.28])], [(3*15.8, [3*94.0]), (3*22.4, [3*112.0])]])
end
end

struct MyRectilinearGrid5{FT,FZ}
x :: FT
z :: FZ
end


@inline flatten_tuple(a::Tuple) = @inbounds a[2:end]
@inline flatten_tuple(a::Tuple{<:Any}) = tuple() #inner_flatten_tuple(a[1])...)

function myupdate_state!(model)
tupled = Base.inferencebarrier((model,model))
flatten_tuple(tupled)
return nothing
end

@testset "Abstract type allocation" begin
model = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0])
dmodel = MyRectilinearGrid5{Float64, Vector{Float64}}(0.0, [0.0])
autodiff(Enzyme.Reverse,
myupdate_state!,
MixedDuplicated(model, Ref(dmodel)))
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ end
@assert Enzyme.Compiler.active_reg_inner(Tuple{S,Int64} where S, (), Base.get_world_counter()) == Enzyme.Compiler.DupState
@assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing) == Enzyme.Compiler.DupState
@assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing, #=justActive=#Val(false), #=unionSret=#Val(true)) == Enzyme.Compiler.ActiveState
@test active_reg_inner(Tuple, (), nothing) == Enzyme.Compiler.DupState
@test active_reg_inner(Tuple, (), nothing, #=justactive=#Val(false), #=unionsret=#Val(false), #=abstractismixed=#Val(true)) == Enzyme.Compiler.MixedState
world = codegen_world_age(typeof(f0), Tuple{Float64})
thunk_a = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI)
thunk_b = Enzyme.Compiler.thunk(Val(world), Const{typeof(f0)}, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1), Val((false, false)), Val(false), Val(false), DefaultABI)
Expand Down

0 comments on commit 99906ff

Please sign in to comment.