From c0a83f2d9b181893a27155373fb68f2880694062 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 6 Jun 2024 19:42:07 -0400 Subject: [PATCH 1/3] Make zero in place --- lib/EnzymeCore/src/EnzymeCore.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 30577a38e8..482561607b 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -226,7 +226,14 @@ function autodiff_deferred_thunk end Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. """ -function make_zero end +function make_zero + +""" + make_zero!(prev::T)::T + + Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. +""" +function make_zero! end """ make_zero(prev::T) From 6981b8c3af6da56e23545eef4adb4b91231eef33 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 7 Jun 2024 17:31:28 +0200 Subject: [PATCH 2/3] add make_zero! --- Project.toml | 2 +- examples/custom_rule.jl | 8 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 4 +- src/Enzyme.jl | 6 +- src/api.jl | 2 +- src/compiler.jl | 199 ++++++++++++++++++++++++++++++- 7 files changed, 209 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 8e31225cd9..87c6d55dcc 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.7.3" +EnzymeCore = "0.7.4" Enzyme_jll = "0.0.119" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" diff --git a/examples/custom_rule.jl b/examples/custom_rule.jl index 836d299c1e..c2098006c2 100644 --- a/examples/custom_rule.jl +++ b/examples/custom_rule.jl @@ -134,7 +134,7 @@ function forward(func::Const{typeof(f)}, RT::Type{<:Union{Const, DuplicatedNoNee if !(x isa Const) && !(y isa Const) y.dval .= 2 .* x.val .* x.dval elseif !(y isa Const) - y.dval .= 0 + make_zero!(y.dval) end dret = !(y isa Const) ? sum(y.dval) : zero(eltype(y.val)) if RT <: Const @@ -211,7 +211,7 @@ function reverse(config::ConfigWidth{1}, func::Const{typeof(f)}, dret::Active, t x.dval .+= 2 .* xval .* dret.val ## also accumulate any derivative in y's shadow into x's shadow. x.dval .+= 2 .* xval .* y.dval - y.dval .= 0 + make_zero!(y.dval) return (nothing, nothing) end @@ -251,8 +251,8 @@ end x = [3.0, 1.0] y = [0.0, 0.0] -dx .= 0 -dy .= 0 +make_zero!(dx) +make_zero!(dy) autodiff(Reverse, h, Duplicated(y, dy), Duplicated(x, dx)) @show dx # derivative of h w.r.t. x diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 670e1f3014..20a89b9a05 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.7.3" +version = "0.7.4" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 482561607b..fb788fd5a6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -226,10 +226,10 @@ function autodiff_deferred_thunk end Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. """ -function make_zero +function make_zero end """ - make_zero!(prev::T)::T + make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. """ diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 911d1801ad..7626304944 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -14,8 +14,8 @@ export BatchDuplicatedFunc import EnzymeCore: batch_size, get_func export batch_size, get_func -import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero -export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero +import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! +export autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero, make_zero! export jacobian, gradient, gradient! export markType, batch_size, onehot, chunkedonehot @@ -1007,7 +1007,7 @@ gradient!(Reverse, dx, f, [2.0, 3.0]) ``` """ @inline function gradient!(::ReverseMode, dx::X, f::F, x::X) where {X<:Array, F} - dx .= 0 + make_zero!(dx) autodiff(Reverse, f, Active, Duplicated(x, dx)) dx end diff --git a/src/api.jl b/src/api.jl index 3c626635b0..d68d904d5a 100644 --- a/src/api.jl +++ b/src/api.jl @@ -104,7 +104,7 @@ struct CFnTypeInfo end -@static if isdefined(LLVM, :InstructionMetadataDict) +@static if !isdefined(LLVM, :ValueMetadataDict) Base.haskey(md::LLVM.InstructionMetadataDict, kind::String) = ccall((:EnzymeGetStringMD, libEnzyme), Cvoid, (LLVM.API.LLVMValueRef, Cstring), md.inst, kind) != C_NULL diff --git a/src/compiler.jl b/src/compiler.jl index 30bf6f0d9c..83ff363dbd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1298,7 +1298,7 @@ end xi = getfield(prev, i) T = Core.Typeof(xi) xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) + setfield!(y, i, xi) end end return y @@ -1324,6 +1324,201 @@ end return y end +function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S} + zero(T) +end + +function make_zero_immutable!(prev::Complex{T}, seen::S)::Complex{T} where {T <: AbstractFloat, S} + zero(T) +end + +function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S} + ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable(prev[i], seen) + end +end + +function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} where {a,b, S} + NamedTuple{a, b}( + ntuple(Val(length(T.parameters))) do i + Base.@_inline_meta + make_zero_immutable(prev[a[i]], seen) + end + ) +end + + +function make_zero_immutable!(prev::T, seen::S)::T where {T, S} + if guaranteed_const_nongen(T, nothing) + return prev + end + @assert !mutable_register(T) + + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + flds = Vector{Any}(undef, nf) + for i in 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + ST = Core.Typeof(xi) + flds[i] = if mutable_register(ST) + EnzymeCore.make_zero!(xi, seen) + xi + else + make_zero_immutable!(xi, seen) + end + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T <: AbstractFloat, ST} + T[] = zero(T) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}}, seen::ST)::Nothing where {T <: AbstractFloat, ST} + T[] = zero(Complex{T}) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} + fill!(prev, zero(T)) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N}, seen::ST)::Nothing where {T <: AbstractFloat, N, ST} + fill!(prev, zero(Complex{T})) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T})::Nothing where {T <: AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{Complex{T}})::Nothing where {T <: AbstractFloat} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N})::Nothing where {T <: AbstractFloat, N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{Complex{T}, N})::Nothing where {T <: AbstractFloat, N} + EnzymeCore.make_zero!(prev, nothing) + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Array{T, N}, seen::ST)::Nothing where {T, N, ST} + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + insert!(seen, prev) + + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(pv, seen) + nothing + else + @inbounds prev[I] = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + end + end + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Base.RefValue{T}, seen::ST)::Nothing where {T, ST} + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + insert!(seen, prev) + + pv = prev[] + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(pv, seen) + nothing + else + prev[] = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} + pv = prev.contents + T = Core.Typeof(pv) + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + insert!(seen, prev) + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(pv, seen) + nothing + else + prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) + nothing + end + nothing +end + +@inline function EnzymeCore.make_zero!(prev::T, seen::S=IdSet{Any}())::Nothing where {T, S} + if guaranteed_const_nongen(T, nothing) + return + end + if haskey(seen, prev) + return + end + @assert !Base.isabstracttype(RT) + @assert Base.isconcretetype(RT) + nf = fieldcount(RT) + + + if nf == 0 + return + end + + insert!(seen, prev) + + for i in 1:nf + if isdefined(prev, i) + xi = getfield(prev, i) + SBT = Core.Typeof(pv) + if mutable_register(SBT) + EnzymeCore.make_zero!(xi, seen) + nothing + else + setfield!(prev, i, make_zero_immutable!(xi, seen)) + nothing + end + end + end + return +end + struct EnzymeRuntimeException <: Base.Exception msg::Cstring end @@ -5536,7 +5731,7 @@ end @assert ismutable(x) yi = getfield(y, i) nexti = recursive_add(xi, yi, f, mutable_register) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, nexti) + setfield!(x, i, nexti) end end end From 965466d8618103273ca09a1a646cf769f2228155 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 8 Jun 2024 11:12:07 -0400 Subject: [PATCH 3/3] more fixes and tests --- src/compiler.jl | 65 +++++++++++++++++++++++++----------------------- test/runtests.jl | 22 ++++++++++++++++ 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 83ff363dbd..fac6907b59 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1335,7 +1335,7 @@ end function make_zero_immutable!(prev::T, seen::S)::T where {T <: Tuple, S} ntuple(Val(length(T.parameters))) do i Base.@_inline_meta - make_zero_immutable(prev[i], seen) + make_zero_immutable!(prev[i], seen) end end @@ -1343,7 +1343,7 @@ function make_zero_immutable!(prev::NamedTuple{a, b}, seen::S)::NamedTuple{a, b} NamedTuple{a, b}( ntuple(Val(length(T.parameters))) do i Base.@_inline_meta - make_zero_immutable(prev[a[i]], seen) + make_zero_immutable!(prev[a[i]], seen) end ) end @@ -1353,7 +1353,7 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} if guaranteed_const_nongen(T, nothing) return prev end - @assert !mutable_register(T) + @assert !ismutable(T) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) @@ -1364,11 +1364,11 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T, S} if isdefined(prev, i) xi = getfield(prev, i) ST = Core.Typeof(xi) - flds[i] = if mutable_register(ST) + flds[i] = if active_reg_inner(ST, (), nothing, #=justActive=#Val(true)) == ActiveState + make_zero_immutable!(xi, seen) + else EnzymeCore.make_zero!(xi, seen) xi - else - make_zero_immutable!(xi, seen) end else nf = i - 1 # rest of tail must be undefined values @@ -1422,20 +1422,20 @@ end if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - insert!(seen, prev) + push!(seen, prev) for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(pv, seen) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + @inbounds prev[I] = make_zero_immutable!(pv, seen) nothing else - @inbounds prev[I] = EnzymeCore.make_zero_immutable!(pv, seen) + EnzymeCore.make_zero!(pv, seen) nothing end end @@ -1447,18 +1447,18 @@ end if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - insert!(seen, prev) + push!(seen, prev) pv = prev[] SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(pv, seen) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + prev[] = make_zero_immutable!(pv, seen) nothing else - prev[] = EnzymeCore.make_zero_immutable!(pv, seen) + EnzymeCore.make_zero!(pv, seen) nothing end nothing @@ -1470,48 +1470,51 @@ end if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - insert!(seen, prev) + push!(seen, prev) SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(pv, seen) + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) nothing else - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) + EnzymeCore.make_zero!(pv, seen) nothing end nothing end -@inline function EnzymeCore.make_zero!(prev::T, seen::S=IdSet{Any}())::Nothing where {T, S} +@inline function EnzymeCore.make_zero!(prev::T, seen::S=Base.IdSet{Any}())::Nothing where {T, S} if guaranteed_const_nongen(T, nothing) return end - if haskey(seen, prev) + if in(seen, prev) return end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) if nf == 0 return end - insert!(seen, prev) + push!(seen, prev) for i in 1:nf if isdefined(prev, i) xi = getfield(prev, i) - SBT = Core.Typeof(pv) - if mutable_register(SBT) - EnzymeCore.make_zero!(xi, seen) + SBT = Core.Typeof(xi) + if guaranteed_const_nongen(SBT, nothing) + continue + end + if active_reg_inner(SBT, (), nothing, #=justActive=#Val(true)) == ActiveState + setfield!(prev, i, make_zero_immutable!(xi, seen)) nothing else - setfield!(prev, i, make_zero_immutable!(xi, seen)) + EnzymeCore.make_zero!(xi, seen) nothing end end diff --git a/test/runtests.jl b/test/runtests.jl index 225ddf435f..0212ec0d83 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -181,6 +181,28 @@ end # @test thunk_split.primal !== C_NULL # @test thunk_split.primal !== thunk_split.adjoint # @test thunk_a.adjoint !== thunk_split.adjoint + # + z = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9]) + Enzyme.make_zero!(z) + @test z[1] ≈ [0.0, 0.0, 0.0] + @test z[2][1] == 0 + @test z[2][2] == 1 + @test z[3] ≈ [0.0, 0.0] + + z2 = ([3.14, 21.5, 16.7], [0,1], [5.6, 8.9]) + Enzyme.make_zero!(z2) + @test z2[1] ≈ [0.0, 0.0, 0.0] + @test z2[2][1] == 0 + @test z2[2][2] == 1 + @test z2[3] ≈ [0.0, 0.0] + + z3 = [3.4, "foo"] + Enzyme.make_zero!(z3) + @test z3[1] ≈ 0.0 + @test z3[2] == "foo" + + z4 = sin + Enzyme.make_zero!(z4) end @testset "Reflection" begin