Skip to content

Commit

Permalink
more fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 8, 2024
1 parent 6981b8c commit 965466d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 31 deletions.
65 changes: 34 additions & 31 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1335,15 +1335,15 @@ end
function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S}
ntuple(Val(length(T.parameters))) do i
Base.@_inline_meta
make_zero_immutable(prev[i], seen)
make_zero_immutable!(prev[i], seen)
end
end

function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S}
NamedTuple{a, b}(
ntuple(Val(length(T.parameters))) do i
Base.@_inline_meta
make_zero_immutable(prev[a[i]], seen)
make_zero_immutable!(prev[a[i]], seen)
end
)
end
Expand All @@ -1353,7 +1353,7 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S}
if guaranteed_const_nongen(T, nothing)
return prev
end
@assert !mutable_register(T)
@assert !ismutable(T)

@assert !Base.isabstracttype(RT)
@assert Base.isconcretetype(RT)
Expand All @@ -1364,11 +1364,11 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S}
if isdefined(prev, i)
xi = getfield(prev, i)
ST = Core.Typeof(xi)
flds[i] = if mutable_register(ST)
flds[i] = if active_reg_inner(ST, (), nothing, #=justActive=#Val(true)) == ActiveState
make_zero_immutable!(xi, seen)
else
EnzymeCore.make_zero!(xi, seen)
xi
else
make_zero_immutable!(xi, seen)
end
else
nf = i - 1 # rest of tail must be undefined values
Expand Down Expand Up @@ -1422,20 +1422,20 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
if in(seen, prev)
return
end
insert!(seen, prev)
push!(seen, prev)

for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
SBT = Core.Typeof(pv)
if mutable_register(SBT)
EnzymeCore.make_zero!(pv, seen)
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
@inbounds prev[I] = make_zero_immutable!(pv, seen)
nothing
else
@inbounds prev[I] = EnzymeCore.make_zero_immutable!(pv, seen)
EnzymeCore.make_zero!(pv, seen)
nothing
end
end
Expand All @@ -1447,18 +1447,18 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
if in(seen, prev)
return
end
insert!(seen, prev)
push!(seen, prev)

pv = prev[]
SBT = Core.Typeof(pv)
if mutable_register(SBT)
EnzymeCore.make_zero!(pv, seen)
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
prev[] = make_zero_immutable!(pv, seen)
nothing
else
prev[] = EnzymeCore.make_zero_immutable!(pv, seen)
EnzymeCore.make_zero!(pv, seen)
nothing
end
nothing
Expand All @@ -1470,48 +1470,51 @@ end
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
if in(seen, prev)
return
end
insert!(seen, prev)
push!(seen, prev)
SBT = Core.Typeof(pv)
if mutable_register(SBT)
EnzymeCore.make_zero!(pv, seen)
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
prev.contents = EnzymeCore.make_zero_immutable!(pv, seen)
nothing
else
prev.contents = EnzymeCore.make_zero_immutable!(pv, seen)
EnzymeCore.make_zero!(pv, seen)
nothing
end
nothing
end

@inline function EnzymeCore.make_zero!(prev::T, seen::S=IdSet{Any}())::Nothing where {T, S}
@inline function EnzymeCore.make_zero!(prev::T, seen::S=Base.IdSet{Any}())::Nothing where {T, S}
if guaranteed_const_nongen(T, nothing)
return
end
if haskey(seen, prev)
if in(seen, prev)
return
end
@assert !Base.isabstracttype(RT)
@assert Base.isconcretetype(RT)
nf = fieldcount(RT)
@assert !Base.isabstracttype(T)
@assert Base.isconcretetype(T)
nf = fieldcount(T)


if nf == 0
return
end

insert!(seen, prev)
push!(seen, prev)

for i in 1:nf
if isdefined(prev, i)
xi = getfield(prev, i)
SBT = Core.Typeof(pv)
if mutable_register(SBT)
EnzymeCore.make_zero!(xi, seen)
SBT = Core.Typeof(xi)
if guaranteed_const_nongen(SBT, nothing)
continue
end
if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState
setfield!(prev, i, make_zero_immutable!(xi, seen))
nothing
else
setfield!(prev, i, make_zero_immutable!(xi, seen))
EnzymeCore.make_zero!(xi, seen)
nothing
end
end
Expand Down
22 changes: 22 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,28 @@ end
# @test thunk_split.primal !== C_NULL
# @test thunk_split.primal !== thunk_split.adjoint
# @test thunk_a.adjoint !== thunk_split.adjoint
#
z = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9])
Enzyme.make_zero!(z)
@test z[1] [0.0, 0.0, 0.0]
@test z[2][1] == 0
@test z[2][2] == 1
@test z[3] [0.0, 0.0]

z2 = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9])
Enzyme.make_zero!(z2)
@test z2[1] [0.0, 0.0, 0.0]
@test z2[2][1] == 0
@test z2[2][2] == 1
@test z2[3] [0.0, 0.0]

z3 = [3.4, "foo"]
Enzyme.make_zero!(z3)
@test z3[1] 0.0
@test z3[2] == "foo"

z4 = sin
Enzyme.make_zero!(z4)
end

@testset "Reflection" begin
Expand Down

0 comments on commit 965466d

Please sign in to comment.