Skip to content

Commit

Permalink
Fix calling convention print and mismatch (#1125)
Browse files Browse the repository at this point in the history
* Fix calling convention print and mismatch

* Use make zero for reverse allocation
  • Loading branch information
jlk9 authored Nov 2, 2023
1 parent af9c3d7 commit d471a8d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
33 changes: 29 additions & 4 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes)
return ReturnType(($(nres...), tape))
elseif annotation <: Active
if $Width == 1
shadow_return = Ref(zero(resT))
shadow_return = Ref(make_zero(resT, IdDict(), origRet))
else
shadow_return = ($(nzeros...),)
end
Expand Down Expand Up @@ -4601,15 +4601,40 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
return tapeV
end


T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)

for i in 1:length(args)
party = value_type(parameters(llvmf)[i])
if value_type(args[i]) != party
if party == T_prjlvalue
while true
if isa(args[i], LLVM.BitCastInst)
args[i] = operands(args[i])[1]
continue
end
if isa(args[i], LLVM.AddrSpaceCastInst)
args[i] = operands(args[i])[1]
continue
end
break
end
end
end

if value_type(args[i]) == party
continue
end
# Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float}
args[i] = calling_conv_fixup(B, args[i], party)
# GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf, augprimal_TT, rev_TT, fn, args, sret, returnRoots
# return tapeV
function msg(io)
println(io, string(llvmf))
println(io, "args = ", args)
println(io, "i = ", i)
println(io, "args[i] = ", args[i])
println(io, "party = ", party)
end
args[i] = calling_conv_fixup(B, args[i], party, LLVM.UndefValue(party), Cuint[], Cuint[], msg)
end

res = LLVM.call!(B, LLVM.function_type(llvmf), llvmf, args)
Expand Down
31 changes: 23 additions & 8 deletions src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ end

# Fix calling convention within julia that Tuple{Float,Float} ->[2 x float] rather than {float, float}
# and that Bool -> i8, not i1
function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev::LLVM.Value=LLVM.UndefValue(tape), lidxs::Vector{Cuint}=Cuint[], ridxs::Vector{Cuint}=Cuint[])::LLVM.Value
function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev::LLVM.Value=LLVM.UndefValue(tape), lidxs::Vector{Cuint}=Cuint[], ridxs::Vector{Cuint}=Cuint[], emesg=nothing)::LLVM.Value
ctype = recursive_eltype(val, lidxs)
if ctype == tape
if length(lidxs) != 0
Expand All @@ -200,7 +200,7 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev:
push!(ln, i-1)
rn = copy(ridxs)
push!(rn, i-1)
prev = calling_conv_fixup(builder, val, ty, prev, ln, rn)
prev = calling_conv_fixup(builder, val, ty, prev, ln, rn, emesg)
end
return prev
end
Expand All @@ -211,7 +211,7 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev:
push!(ln, i-1)
rn = copy(ridxs)
push!(rn, i-1)
prev = calling_conv_fixup(builder, val, ty, prev, ln, rn)
prev = calling_conv_fixup(builder, val, ty, prev, ln, rn, emesg)
end
return prev
end
Expand All @@ -223,7 +223,7 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev:
push!(ln, i-1)
rn = copy(ridxs)
push!(rn, i-1)
prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn)
prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn, emesg)
end
return prev
end
Expand All @@ -234,7 +234,7 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev:
push!(ln, i-1)
rn = copy(ridxs)
push!(rn, i-1)
prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn)
prev = calling_conv_fixup(builder, val, eltype(tape), prev, ln, rn, emesg)
end
return prev
end
Expand Down Expand Up @@ -265,8 +265,23 @@ function calling_conv_fixup(builder, val::LLVM.Value, tape::LLVM.LLVMType, prev:
if isa(ctype, LLVM.ArrayType) && length(ctype) == 1 && eltype(ctype) == tape
lhs_n = copy(lidxs)
push!(lhs_n, 0)
return calling_conv_fixup(builder, val, tape, prev, lhs_n, ridxs)
return calling_conv_fixup(builder, val, tape, prev, lhs_n, ridxs, emesg)
end
@show ctype, tape, val, prev, lidxs, ridxs, tape_type(tape), convert(LLVM.LLVMType, tape_type(tape); allow_boxed=true)
@assert false


msg2 = sprint() do io
println(io, "Enzyme Internal Error: Illegal calling convention fixup")
if emesg !== nothing
emesg(io)
end
println(io, "ctype = ", ctype)
println(io, "tape = ", tape)
println(io, "val = ", val)
println(io, "prev = ", prev)
println(io, "lidxs = ", lidxs)
println(io, "ridxs = ", ridxs)
println(io, "tape_type(tape) = ", tape_type(tape))
println(io, "convert(LLVMType, tape_type(tape)) = ", convert(LLVM.LLVMType, tape_type(tape); allow_boxed=true))
end
throw(AssertionError(msg2))
end

0 comments on commit d471a8d

Please sign in to comment.