Skip to content

Commit

Permalink
Fix tape by reference (#743)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Apr 19, 2023
1 parent 8cdbe18 commit 4c30ed8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do ctx
LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue])
end
end
function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround)
function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool)
curent_bb = position(B)
fn = LLVM.parent(curent_bb)
mod = LLVM.parent(fn)
Expand Down Expand Up @@ -3955,7 +3955,19 @@ function enzyme_custom_common_rev(forward::Bool, B::LLVM.API.LLVMBuilderRef, Ori
if !forward
if needsTape
@assert tape != C_NULL
insert!(args, 1+(kwtup!==nothing), LLVM.Value(tape))
tape = LLVM.Value(tape)
innerTy = value_type(parameters(llvmf)[1+(kwtup!==nothing)])
if innerTy != value_type(tape)
llty = convert(LLVMType, TapeT; ctx)
al0 = al = emit_allocobj!(B, TapeT)
al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(B, tape, al)
if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, tape))
end
tape = addrspacecast!(B, al, LLVM.PointerType(llty, 11))
end
insert!(args, 1+(kwtup!==nothing), tape)
end
if RT <: Active

Expand Down
24 changes: 24 additions & 0 deletions test/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,28 @@ end
@test Enzyme.autodiff(Reverse, h2, Active(3.0)) == ((1080.0,),)
end

q(x) = x^2
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(q)}, ::Type{<:Active}, x::Active)
tape = (Ref(2.0), Ref(3.4))
if needs_primal(config)
return AugmentedReturn(func.val(x.val), nothing, tape)
else
return AugmentedReturn(nothing, nothing, tape)
end
end

function reverse(config::ConfigWidth{1}, ::Const{typeof(q)}, dret::Active, tape, x::Active)
@test tape[1][] == 2.0
@test tape[2][] == 3.4
if needs_primal(config)
return (10+2*x.val*dret.val,)
else
return (100+2*x.val*dret.val,)
end
end

@testset "Byref Tape" begin
@test Enzyme.autodiff(Enzyme.Reverse, q, Active(2.0))[1][1] 104.0
end

end # ReverseRules

0 comments on commit 4c30ed8

Please sign in to comment.