From 878064724fe8934dfd13d755c613a41043ea8ec4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 12 Jun 2024 08:38:00 -0700 Subject: [PATCH] Handle mixed custom rule arg --- src/compiler.jl | 10 +++---- src/rules/customrules.jl | 2 -- test/mixedrrule.jl | 63 ++++++++++++++++++++++++++++++++-------- 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 10ff8a976a..daef61e583 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2450,7 +2450,7 @@ else end end -function store_nonjl_types!(B, p, startval) +function store_nonjl_types!(B, startval, p) T_jlvalue = LLVM.StructType(LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) vals = LLVM.Value[] @@ -2476,8 +2476,8 @@ function store_nonjl_types!(B, p, startval) end end if isa(ty, LLVM.StructType) - for (i, t) in enumerate(LLVM.elements(ty)) - if any_jltypes(t) + if any_jltypes(ty) + for (i, t) in enumerate(LLVM.elements(ty)) ev = extract_value!(B, cur, i-1) push!(todo, ((path..., i-1), ev)) end @@ -2488,8 +2488,8 @@ function store_nonjl_types!(B, p, startval) for v in path push!(parray, LLVM.ConstantInt(LLVM.IntType(32), v)) end - gptr = gep!(B, p, parray) - store!(B, cur, gptr) + gptr = gep!(B, value_type(startval), p, parray) + st = store!(B, cur, gptr) end return end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 0103e70121..1eab07c3f2 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -1017,7 +1017,6 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx+=1 end - # @show mixeds for (ptr_val, argTyp, refal) in mixeds RefTy = argTyp if width != 1 @@ -1030,7 +1029,6 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, evcur = (width == 1) ? curs : extract_value!(B, curs, idx-1) store_nonjl_types!(B, evcur, evp) end - @show curs, ptr_val, argTyp, refal end end diff --git a/test/mixedrrule.jl b/test/mixedrrule.jl index 4e85ac4f02..5697641c64 100644 --- a/test/mixedrrule.jl +++ b/test/mixedrrule.jl @@ -4,8 +4,6 @@ using Enzyme using Enzyme: EnzymeRules using Test -Enzyme.API.printall!(true) - import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig using .EnzymeRules @@ -20,8 +18,47 @@ function mixouter(x, y) end function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(mixfnc)}, + ::Type{<:Active}, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) + pval = func.val(tup.val) + vec = copy(tup.val[2]) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(mixfnc)}, + dret::Active, tape, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}}) + prev = tup.dval[] + tup.dval[] = (7 * tape[1] * dret.val, prev[2]) + prev[2][1] = 1000 * dret.val * tup.val[1] + return (nothing,) +end + +@testset "Mixed activity rule" begin + x = [3.14] + dx = [0.0] + res = autodiff(Reverse, mixouter, Active, Active(2.7), Duplicated(x, dx))[1][1] + @test res ≈ 7 * 3.14 + @test dx[1] ≈ 1000 * 2.7 + @test x[1] ≈ 0.0 +end + + +function recmixfnc(tup) + return sum(tup[1]) * tup[2][1] +end + +function recmixouter(x, y, z) + res = mixfnc(((x, z), y)) + fill!(y, 0.0) + return res +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, ::Type{<:Active}, tup) - @show tup pval = func.val(tup.val) vec = copy(tup.val[2]) primal = if EnzymeRules.needs_primal(config) @@ -32,37 +69,39 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof return AugmentedReturn(primal, nothing, vec) end +# check if a value is guaranteed to be not contain active[register] data +# (aka not either mixed or active) @inline function guaranteed_nonactive(::Type{T}) where T rt = Enzyme.Compiler.active_reg_inner(T, (), nothing) return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(mixfnc)}, +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)}, dret::Active, tape, tup) prev = tup.dval[] dRT = typeof(prev) - @show "rev", tup - @show dRT, fieldcount(dRT) + tup.dval[] = Enzyme.Compiler.splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i Base.@_inline_meta pv = getfield(prev, i) if i == 1 - next = 7 * tape[1] * dret.val + next = (7 * tape[1] * dret.val, 31 * tape[1] * dret.val) Enzyme.Compiler.recursive_add(pv, next, identity, guaranteed_nonactive) else pv end end) - prev[2][1] = 1000 * dret.val * prev[1] + prev[2][1] = 1000 * dret.val * tup.val[1] + .0001 * dret.val * tup.val[3] return (nothing,) end -@testset "Mixed activity rule" begin +@testset "Recursive Mixed activity rule" begin x = [3.14] dx = [0.0] - res = autodiff(Reverse, mixouter, Active, Active(2.7), Duplicated(x, dx))[1][1] - @test res ≈ 7 * 3.14 - @test dx[1] ≈ 1000 * 2.7 + res = autodiff(Reverse, recmixouter, Active, Active(2.7), Duplicated(x, dx), Active(56.47))[1] + @test res[1] ≈ 7 * 3.14 + @test res[3] ≈ 31 * 3.14 + @test dx[1] ≈ 1000 * 2.7 + .0001 * 56.47 @test x[1] ≈ 0.0 end