Skip to content

Commit

Permalink
Handle mixed custom rule arg
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 12, 2024
1 parent 971194f commit fdb235f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 19 deletions.
10 changes: 5 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2450,7 +2450,7 @@ else
end
end

function store_nonjl_types!(B, p, startval)
function store_nonjl_types!(B, startval, p)
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
vals = LLVM.Value[]
Expand All @@ -2476,8 +2476,8 @@ function store_nonjl_types!(B, p, startval)
end
end
if isa(ty, LLVM.StructType)
for (i, t) in enumerate(LLVM.elements(ty))
if any_jltypes(t)
if any_jltypes(ty)
for (i, t) in enumerate(LLVM.elements(ty))
ev = extract_value!(B, cur, i-1)
push!(todo, ((path..., i-1), ev))
end
Expand All @@ -2488,8 +2488,8 @@ function store_nonjl_types!(B, p, startval)
for v in path
push!(parray, LLVM.ConstantInt(LLVM.IntType(32), v))
end
gptr = gep!(B, p, parray)
store!(B, cur, gptr)
gptr = gep!(B, value_type(startval), p, parray)
st = store!(B, cur, gptr)
end
return
end
Expand Down
2 changes: 0 additions & 2 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,6 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
idx+=1
end

# @show mixeds
for (ptr_val, argTyp, refal) in mixeds
RefTy = argTyp
if width != 1
Expand All @@ -1030,7 +1029,6 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
evcur = (width == 1) ? curs : extract_value!(B, curs, idx-1)
store_nonjl_types!(B, evcur, evp)
end
@show curs, ptr_val, argTyp, refal
end
end

Expand Down
63 changes: 51 additions & 12 deletions test/mixedrrule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using Enzyme
using Enzyme: EnzymeRules
using Test

Enzyme.API.printall!(true)

import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule_from_sig
using .EnzymeRules

Expand All @@ -20,8 +18,47 @@ function mixouter(x, y)
end

function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(mixfnc)},
::Type{<:Active}, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}})
pval = func.val(tup.val)
vec = copy(tup.val[2])
primal = if EnzymeRules.needs_primal(config)
pval
else
nothing
end
return AugmentedReturn(primal, nothing, vec)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(mixfnc)},
dret::Active, tape, tup::MixedDuplicated{Tuple{Float64, Vector{Float64}}})
prev = tup.dval[]
tup.dval[] = (7 * tape[1] * dret.val, prev[2])
prev[2][1] = 1000 * dret.val * tup.val[1]
return (nothing,)
end

@testset "Mixed activity rule" begin
x = [3.14]
dx = [0.0]
res = autodiff(Reverse, mixouter, Active, Active(2.7), Duplicated(x, dx))[1][1]
@test res 7 * 3.14
@test dx[1] 1000 * 2.7
@test x[1] 0.0
end


function recmixfnc(tup)
return sum(tup[1]) * tup[2][1]
end

function recmixouter(x, y, z)
res = recmixfnc(((x, z), y))
fill!(y, 0.0)
return res
end

function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)},
::Type{<:Active}, tup)
@show tup
pval = func.val(tup.val)
vec = copy(tup.val[2])
primal = if EnzymeRules.needs_primal(config)
Expand All @@ -32,37 +69,39 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof
return AugmentedReturn(primal, nothing, vec)
end

# check if a value is guaranteed to be not contain active[register] data
# (aka not either mixed or active)
@inline function guaranteed_nonactive(::Type{T}) where T
rt = Enzyme.Compiler.active_reg_inner(T, (), nothing)
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(mixfnc)},
function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(recmixfnc)},
dret::Active, tape, tup)
prev = tup.dval[]
dRT = typeof(prev)
@show "rev", tup
@show dRT, fieldcount(dRT)

tup.dval[] = Enzyme.Compiler.splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i
Base.@_inline_meta
pv = getfield(prev, i)
if i == 1
next = 7 * tape[1] * dret.val
next = (7 * tape[1] * dret.val, 31 * tape[1] * dret.val)
Enzyme.Compiler.recursive_add(pv, next, identity, guaranteed_nonactive)
else
pv
end
end)
prev[2][1] = 1000 * dret.val * prev[1]
prev[2][1] = 1000 * dret.val * tup.val[1][1] + .0001 * dret.val * tup.val[1][2]
return (nothing,)
end

@testset "Mixed activity rule" begin
@testset "Recursive Mixed activity rule" begin
x = [3.14]
dx = [0.0]
res = autodiff(Reverse, mixouter, Active, Active(2.7), Duplicated(x, dx))[1][1]
@test res 7 * 3.14
@test dx[1] 1000 * 2.7
res = autodiff(Reverse, recmixouter, Active, Active(2.7), Duplicated(x, dx), Active(56.47))[1]
@test res[1] 7 * 3.14
@test res[3] 31 * 3.14
@test dx[1] 1000 * 2.7 + .0001 * 56.47
@test x[1] 0.0
end

Expand Down

0 comments on commit fdb235f

Please sign in to comment.