diff --git a/src/compiler.jl b/src/compiler.jl index 83ff363dbd..57086a2ec6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1335,7 +1335,7 @@ end function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S} ntuple(Val(length(T.parameters))) do i Base.@_inline_meta - make_zero_immutable(prev[i], seen) + make_zero_immutable!(prev[i], seen) end end @@ -1343,7 +1343,7 @@ function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} NamedTuple{a, b}( ntuple(Val(length(T.parameters))) do i Base.@_inline_meta - make_zero_immutable(prev[a[i]], seen) + make_zero_immutable!(prev[a[i]], seen) end ) end @@ -1353,7 +1353,7 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} if guaranteed_const_nongen(T, nothing) return prev end - @assert !mutable_register(T) + @assert !ismutable(T) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) @@ -1364,11 +1364,11 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if mutable_register(ST) + flds[i] = if active_reg_inner(ST, (), nothing, #=justActive=#Val(true)) == ActiveState + make_zero_immutable!(xi, seen) + else EnzymeCore.make_zero!(xi, seen) xi - else - make_zero_immutable!(xi, seen) end else nf = i - 1 # rest of tail must be undefined values @@ -1422,20 +1422,20 @@ end if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - insert!(seen, prev) + push!(seen, prev) for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(pv, seen) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + @inbounds prev[I] = EnzymeCore.make_zero_immutable!(pv, seen) nothing else - @inbounds prev[I] = EnzymeCore.make_zero_immutable!(pv, seen) + EnzymeCore.make_zero!(pv, seen) nothing end end @@ -1447,18 +1447,18 @@ end if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - insert!(seen, prev) + push!(seen, prev) pv = prev[] SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(pv, seen) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + prev[] = EnzymeCore.make_zero_immutable!(pv, seen) nothing else - prev[] = EnzymeCore.make_zero_immutable!(pv, seen) + EnzymeCore.make_zero!(pv, seen) nothing end nothing @@ -1470,48 +1470,51 @@ end if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - insert!(seen, prev) + push!(seen, prev) SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(pv, seen) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) nothing else - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) + EnzymeCore.make_zero!(pv, seen) nothing end nothing end -@inline function EnzymeCore.make_zero!(prev::T, seen::S=IdSet{Any}())::Nothing where {T, S} +@inline function EnzymeCore.make_zero!(prev::T, seen::S=Base.IdSet{Any}())::Nothing where {T, S} if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) if nf == 0 return end - insert!(seen, prev) + push!(seen, prev) for i in 1:nf if isdefined(prev, i) xi = getfield(prev, i) - SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(xi, seen) + SBT = Core.Typeof(xi) + if guaranteed_const_nongen(SBT, nothing) + continue + end + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + setfield!(prev, i, make_zero_immutable!(xi, seen)) nothing else - setfield!(prev, i, make_zero_immutable!(xi, seen)) + EnzymeCore.make_zero!(xi, seen) nothing end end diff --git a/test/runtests.jl b/test/runtests.jl index 225ddf435f..0212ec0d83 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -181,6 +181,28 @@ end # @test thunk_split.primal !== C_NULL # @test thunk_split.primal !== thunk_split.adjoint # @test thunk_a.adjoint !== thunk_split.adjoint + # + z = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9]) + Enzyme.make_zero!(z) + @test z[1] ≈ [0.0, 0.0, 0.0] + @test z[2][1] == 0 + @test z[2][2] == 1 + @test z[3] ≈ [0.0, 0.0] + + z2 = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9]) + Enzyme.make_zero!(z2) + @test z2[1] ≈ [0.0, 0.0, 0.0] + @test z2[2][1] == 0 + @test z2[2][2] == 1 + @test z2[3] ≈ [0.0, 0.0] + + z3 = [3.4, "foo"] + Enzyme.make_zero!(z3) + @test z3[1] ≈ 0.0 + @test z3[2] == "foo" + + z4 = sin + Enzyme.make_zero!(z4) end @testset "Reflection" begin