diff --git a/src/compiler.jl b/src/compiler.jl index 3c83de8848..c98f50de68 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -334,7 +334,7 @@ declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do ctx LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) end end -function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround) +function emit_allocobj!(B, tag::LLVM.Value, Size::LLVM.Value, needs_workaround::Bool) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) @@ -3955,7 +3955,19 @@ function enzyme_custom_common_rev(forward::Bool, B::LLVM.API.LLVMBuilderRef, Ori if !forward if needsTape @assert tape != C_NULL - insert!(args, 1+(kwtup!==nothing), LLVM.Value(tape)) + tape = LLVM.Value(tape) + innerTy = value_type(parameters(llvmf)[1+(kwtup!==nothing)]) + if innerTy != value_type(tape) + llty = convert(LLVMType, TapeT; ctx) + al0 = al = emit_allocobj!(B, TapeT) + al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + store!(B, tape, al) + if any_jltypes(llty) + emit_writebarrier!(B, get_julia_inner_types(B, al0, tape)) + end + tape = addrspacecast!(B, al, LLVM.PointerType(llty, 11)) + end + insert!(args, 1+(kwtup!==nothing), tape) end if RT <: Active diff --git a/test/rrules.jl b/test/rrules.jl index 16cffac0d9..f37c8420e7 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -105,4 +105,28 @@ end @test Enzyme.autodiff(Reverse, h2, Active(3.0)) == ((1080.0,),) end +q(x) = x^2 +function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(q)}, ::Type{<:Active}, x::Active) + tape = (Ref(2.0), Ref(3.4)) + if needs_primal(config) + return AugmentedReturn(func.val(x.val), nothing, tape) + else + return AugmentedReturn(nothing, nothing, tape) + end +end + +function reverse(config::ConfigWidth{1}, ::Const{typeof(q)}, dret::Active, tape, x::Active) + @test tape[1][] == 2.0 + @test tape[2][] == 3.4 + if needs_primal(config) + return (10+2*x.val*dret.val,) + else + return (100+2*x.val*dret.val,) + end +end + +@testset "Byref Tape" begin + @test Enzyme.autodiff(Enzyme.Reverse, q, Active(2.0))[1][1] ≈ 104.0 +end + end # ReverseRules