diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index c658850c2e..de24d01053 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -687,7 +687,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, end end end - push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) + + # push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) needsTape = !isghostty(TapeT) && !Core.Compiler.isconstType(TapeT) @@ -711,22 +712,37 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf)))) - _, sret, returnRoots = get_return_info(enzyme_custom_extract_mi(llvmf)[2]) + miRT = enzyme_custom_extract_mi(llvmf)[2] + _, sret, returnRoots = get_return_info(miRT) if !forward + funcTy = rev_TT.parameters[isKWCall ? 4 : 2] if needsTape @assert tape != C_NULL - tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) - innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)]) + tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) + !isghostty(funcTy) + trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself+(RT <: Active) + innerTy = value_type(parameters(llvmf)[trueidx]) if innerTy != value_type(tape) - if isabstracttype(TapeT) + if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL msg = sprint() do io println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") println(io, "tape_idx=", tape_idx) + println(io, "true_idx=", trueidx) + println(io, "isKWCall=", isKWCall) + println(io, "kwtup=", kwtup) + println(io, "funcTy=", funcTy) + println(io, "isghostty(funcTy)=", isghostty(funcTy)) + println(io, "miRT=", miRT) println(io, "sret=", sret) + println(io, "returnRoots=", returnRoots) + println(io, "swiftself=", swiftself) println(io, "RT=", RT) println(io, "tape=", tape) - println(io, "llvmf=", string(llvmf)) + println(io, "llvmf=", string(LLVM.function_type(llvmf))) + println(io, "TapeT=", TapeT) + println(io, "mi=", mi) + println(io, "ami=", ami) + println(io, "rev_TT =", rev_TT) end throw(AssertionError(msg)) end @@ -749,7 +765,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B)) else llety = convert(LLVMType, eltype(RT)) - ptr_val = invert_pointer(gutils, operands(orig)[1], B) + ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B) val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety))) for idx in 1:width ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1) @@ -769,8 +785,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if any_jltypes(llty) emit_writebarrier!(B, get_julia_inner_types(B, al0, val)) end - - insert!(args, 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al) + insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al) end end diff --git a/test/rrules.jl b/test/rrules.jl index 1322895924..171c160b0f 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -305,4 +305,44 @@ end @test dU[1] ≈ 7 * ( 3.0 + 4.0im ) end end + + +struct Closure + v::Vector{Float64} +end + +function (cl::Closure)(x) + val = cl.v[1] * x + cl.v[1] = 0.0 + return val +end + + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure}, + ::Type{<:Active}, args::Vararg{Active,N}) where {N} + vec = copy(func.val.v) + pval = func.val(args[1].val) + primal = if EnzymeRules.needs_primal(config) + pval + else + nothing + end + return AugmentedReturn(primal, nothing, vec) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure}, + dret::Active, tape, args::Vararg{Active,N}) where {N} + dargs = ntuple(Val(N)) do i + 7 * args[1].val * dret.val + tape[1] * 1000 + end + return dargs +end + +@testset "Closure rule" begin + cl = Closure([3.14]) + res = autodiff(Reverse, cl, Active, Active(2.7))[1][1] + @test res ≈ 7 * 2.7 + 3.14 * 1000 + @test cl.v[1] ≈ 0.0 +end + end # ReverseRules