Skip to content

Commit

Permalink
Fix reverse mode closure issues (#1533)
Browse files Browse the repository at this point in the history
* Fix custom reverse on closure

* fix closure
  • Loading branch information
wsmoses authored Jun 11, 2024
1 parent 6c2b0d9 commit bd60907
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 9 deletions.
33 changes: 24 additions & 9 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
end
end
end
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))

# push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))

needsTape = !isghostty(TapeT) && !Core.Compiler.isconstType(TapeT)

Expand All @@ -711,22 +712,37 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,

swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(llvmf, i)))) for i in 1:length(collect(parameters(llvmf))))

_, sret, returnRoots = get_return_info(enzyme_custom_extract_mi(llvmf)[2])
miRT = enzyme_custom_extract_mi(llvmf)[2]
_, sret, returnRoots = get_return_info(miRT)

if !forward
funcTy = rev_TT.parameters[isKWCall ? 4 : 2]
if needsTape
@assert tape != C_NULL
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4]))
innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)])
tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) + !isghostty(funcTy)
trueidx = tape_idx+(sret !== nothing)+(returnRoots !== nothing)+swiftself+(RT <: Active)
innerTy = value_type(parameters(llvmf)[trueidx])
if innerTy != value_type(tape)
if isabstracttype(TapeT)
if isabstracttype(TapeT) || TapeT == Tuple || TapeT.layout == C_NULL
msg = sprint() do io
println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))")
println(io, "tape_idx=", tape_idx)
println(io, "true_idx=", trueidx)
println(io, "isKWCall=", isKWCall)
println(io, "kwtup=", kwtup)
println(io, "funcTy=", funcTy)
println(io, "isghostty(funcTy)=", isghostty(funcTy))
println(io, "miRT=", miRT)
println(io, "sret=", sret)
println(io, "returnRoots=", returnRoots)
println(io, "swiftself=", swiftself)
println(io, "RT=", RT)
println(io, "tape=", tape)
println(io, "llvmf=", string(llvmf))
println(io, "llvmf=", string(LLVM.function_type(llvmf)))
println(io, "TapeT=", TapeT)
println(io, "mi=", mi)
println(io, "ami=", ami)
println(io, "rev_TT =", rev_TT)
end
throw(AssertionError(msg))
end
Expand All @@ -749,7 +765,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B))
else
llety = convert(LLVMType, eltype(RT))
ptr_val = invert_pointer(gutils, operands(orig)[1], B)
ptr_val = invert_pointer(gutils, operands(orig)[1 + !isghostty(funcTy)], B)
val = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llety)))
for idx in 1:width
ev = (width == 1) ? ptr_val : extract_value!(B, ptr_val, idx-1)
Expand All @@ -769,8 +785,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
if any_jltypes(llty)
emit_writebarrier!(B, get_julia_inner_types(B, al0, val))
end

insert!(args, 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al)
insert!(args, 1+(!isghostty(funcTy))+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])), al)
end
end

Expand Down
40 changes: 40 additions & 0 deletions test/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,44 @@ end
@test dU[1] 7 * ( 3.0 + 4.0im )
end
end


struct Closure
v::Vector{Float64}
end

function (cl::Closure)(x)
val = cl.v[1] * x
cl.v[1] = 0.0
return val
end


function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{Closure},
::Type{<:Active}, args::Vararg{Active,N}) where {N}
vec = copy(func.val.v)
pval = func.val(args[1].val)
primal = if EnzymeRules.needs_primal(config)
pval
else
nothing
end
return AugmentedReturn(primal, nothing, vec)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{Closure},
dret::Active, tape, args::Vararg{Active,N}) where {N}
dargs = ntuple(Val(N)) do i
7 * args[1].val * dret.val + tape[1] * 1000
end
return dargs
end

@testset "Closure rule" begin
cl = Closure([3.14])
res = autodiff(Reverse, cl, Active, Active(2.7))[1][1]
@test res 7 * 2.7 + 3.14 * 1000
@test cl.v[1] 0.0
end

end # ReverseRules

0 comments on commit bd60907

Please sign in to comment.