Skip to content

Commit

Permalink
1.11: the adventure continues, destroy (#1986)
Browse files Browse the repository at this point in the history
* 1.11: the adventure continues, destroy

* fix

* fixup

* fix

* cleanup

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 20, 2024
1 parent 9e945a5 commit 72763e9
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 207 deletions.
50 changes: 35 additions & 15 deletions src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ function absint(arg::LLVM.Value, partial::Bool = false)
if nm == "julia.pointer_from_objref"
return absint(operands(arg)[1], partial)
end
if nm == "julia.gc_loaded"
return absint(operands(arg)[2], partial)
end
if nm == "jl_typeof" || nm == "ijl_typeof"
vals = abs_typeof(operands(arg)[1], partial)
return (vals[1], vals[2])
Expand Down Expand Up @@ -158,7 +161,13 @@ function absint(arg::LLVM.Value, partial::Bool = false)
end

function actual_size(@nospecialize(typ2))
if typ2 <: Array || typ2 <: AbstractString || typ2 <: Symbol
@static if VERSION < v"1.11-"
if typ2 <: Array
return sizeof(Int)
end
else
end
if typ2 <: AbstractString || typ2 <: Symbol
return sizeof(Int)
elseif Base.isconcretetype(typ2)
return sizeof(typ2)
Expand Down Expand Up @@ -256,6 +265,11 @@ function abs_typeof(
return abs_typeof(operands(arg)[1], partial)
end

if nm == "julia.gc_loaded"
legal, res, byref = abs_typeof(operands(arg)[2], partial)
return legal, res, byref
end

for (fname, ty) in (
("jl_box_int64", Int64),
("ijl_box_int64", Int64),
Expand Down Expand Up @@ -453,7 +467,7 @@ function abs_typeof(
fo = fieldoffset(typ, i)
if fieldoffset(typ, i) == offset
offset = 0
typ = fieldtype(typ, i)
typ = typed_fieldtype(typ, i)
if !Base.allocatedinline(typ)
if byref != GPUCompiler.BITS_VALUE
legal = false
Expand All @@ -464,7 +478,7 @@ function abs_typeof(
break
elseif fieldoffset(typ, i) > offset
offset = offset - fieldoffset(typ, lasti)
typ = fieldtype(typ, lasti)
typ = typed_fieldtype(typ, lasti)
@assert Base.isconcretetype(typ)
if !Base.allocatedinline(typ)
legal = false
Expand All @@ -477,15 +491,15 @@ function abs_typeof(
lasti = i
end
end
if !seen && fieldcount(typ) > 0
offset = offset - fieldoffset(typ, lasti)
typ = fieldtype(typ, lasti)
@assert Base.isconcretetype(typ)
if !Base.allocatedinline(typ)
legal = false
end
seen = true
end
if !seen && fieldcount(typ) > 0
offset = offset - fieldoffset(typ, lasti)
typ = typed_fieldtype(typ, lasti)
@assert Base.isconcretetype(typ)
if !Base.allocatedinline(typ)
legal = false
end
seen = true
end
if !seen
legal = false
end
Expand All @@ -495,8 +509,14 @@ function abs_typeof(
while legal && should_recurse(typ2, value_type(arg), byref, dl)
idx, _ = first_non_ghost(typ2)
if idx != -1
typ2 = fieldtype(typ2, idx)
if !Base.allocatedinline(typ2)
typ2 = typed_fieldtype(typ2, idx)
if Base.allocatedinline(typ2)
if byref == GPUCompiler.BITS_VALUE
continue
end
legal = false
break
else
if byref != GPUCompiler.BITS_VALUE
legal = false
break
Expand Down Expand Up @@ -532,7 +552,7 @@ function abs_typeof(
@assert Base.isconcretetype(typ)
cnt = 0
for i = 1:fieldcount(typ)
styp = fieldtype(typ, i)
styp = typed_fieldtype(typ, i)
if isghostty(styp)
continue
end
Expand Down
16 changes: 14 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ end
return Val(AnyState)
end

subT = fieldtype(T, f)
subT = typed_fieldtype(T, f)

if justActive && !allocatedinline(subT)
return Val(AnyState)
Expand Down Expand Up @@ -2441,7 +2441,7 @@ function zero_single_allocation(builder, jlType, LLVMType, nobj, zeroAll, idx)
if isa(ty, LLVM.StructType)
i = 1
for ii = 1:fieldcount(jlty)
jlet = fieldtype(jlty, ii)
jlet = typed_fieldtype(jlty, ii)
if isghostty(jlet) || Core.Compiler.isconstType(jlet)
continue
end
Expand Down Expand Up @@ -3816,6 +3816,18 @@ function enzyme!(
LLVM.API.LLVMValueRef,
)
),
"julia.gc_loaded" => @cfunction(
inoutgcloaded_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"julia.pointer_from_objref" => @cfunction(
inout_rule,
UInt8,
Expand Down
73 changes: 9 additions & 64 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -823,60 +823,18 @@ function nodecayed_phis!(mod::LLVM.Module)
if isa(ld, LLVM.LoadInst)
v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true)
rhs = LLVM.ConstantInt(offty, sizeof(Int))
if o2 != rhs
msg = sprint() do io::IO
println(
io,
"Enzyme internal error addr13 load doesn't keep offset 0",
)
println(io, "mod=", string(LLVM.parent(f)))
println(io, "f=", string(f))
println(io, "v=", string(v))
println(io, "opv[1]=", string(operands(v)[1]))
println(io, "opv[2]=", string(operands(v)[2]))
println(io, "ld=", string(ld))
println(io, "ld_op[1]=", string(operands(ld)[1]))

println(io, "v2=", string(v2))
println(io, "o2=", string(o2))
println(io, "hl2=", string(hl2))

println(io, "offty=", string(offty))
println(io, "rhs=", string(rhs))
end
throw(AssertionError(msg))
end

# Fast path: gc_loaded(mem, ptr) where ptr = (({size_t, {}*}*)mem)->second
# [aka a load of the second element of mem]. Any other ptr falls through to the
# explicit offset computation after the check below.
base_2, off_2, _ = get_base_and_offset(v2)
base_1, off_1, _ = get_base_and_offset(operands(v)[1])
if base_1 != base_2 || off_1 != off_2
msg = sprint() do io::IO
println(
io,
"Enzyme internal error addr13 load data isn't offset of mem",
)
println(io, "f=", string(f))
println(io, "v=", string(v))
println(io, "opv[1]=", string(operands(v)[1]))
println(io, "opv[2]=", string(operands(v)[2]))
println(io, "ld=", string(ld))
println(io, "ld_op[1]=", string(operands(ld)[1]))

println(io, "v2=", string(v2))
println(io, "o2=", string(o2))
println(io, "hl2=", string(hl2))

println(io, "base_1=", string(base_1))
println(io, "base_2=", string(base_2))
println(io, "off_1=", string(off_1))
println(io, "off_2=", string(off_2))
end
throw(AssertionError(msg))

if o2 == rhs && base_1 == base_2 && off_1 == off_2
return v2, offset, true
end

return v2, offset, true
rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty)
lhs = ptrtoint!(b, operands(v)[2], offty)
off2 = nuwsub!(b, rhs, lhs)
return v2, nuwadd!(b, offset, off2), true
end
end
end
Expand Down Expand Up @@ -1127,24 +1085,11 @@ function nodecayed_phis!(mod::LLVM.Module)
else
base_obj = nphi

# %value_phi11 = phi {} addrspace(10)* [ %55, %L78 ], [ %54, %L76 ]

# %.phi.trans.insert77 = bitcast {} addrspace(10)* %value_phi11 to { i64, {} addrspace(10)** } addrspace(10)*
# %.phi.trans.insert78 = addrspacecast { i64, {} addrspace(10)** } addrspace(10)* %.phi.trans.insert77 to { i64, {} addrspace(10)** } addrspace(11)*
# %.phi.trans.insert79 = getelementptr inbounds { i64, {} addrspace(10)** }, { i64, {} addrspace(10)** } addrspace(11)* %.phi.trans.insert78, i64 0, i32 1
# %.pre80 = load {} addrspace(10)**, {} addrspace(10)** addrspace(11)* %.phi.trans.insert79, align 8, !dbg !532, !tbaa !19, !alias.scope !26, !noalias !29

# %154 = call {} addrspace(10)* addrspace(13)* @julia.gc_loaded({} addrspace(10)* %value_phi11, {} addrspace(10)** %.pre80), !dbg !532

jlt = LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), 10)
pjlt = LLVM.PointerType(jlt)
gent = LLVM.StructType([convert(LLVMType, Int), pjlt])
pgent = LLVM.PointerType(LLVM.StructType([convert(LLVMType, Int), pjlt]), 10)

nphi = bitcast!(nb, nphi, pgent)
nphi = addrspacecast!(nb, nphi, LLVM.PointerType(gent, 11))
nphi = inbounds_gep!(nb, gent, nphi, [LLVM.ConstantInt(Int64(0)), LLVM.ConstantInt(Int32(1))])
nphi = load!(nb, pjlt, nphi)
nphi = get_memory_data(nb, nphi)
nphi = bitcast!(nb, nphi, pjlt)

GTy = LLVM.FunctionType(LLVM.PointerType(jlt, 13), LLVM.LLVMType[jlt, pjlt])
gcloaded, _ = get_function!(
Expand Down
Loading

0 comments on commit 72763e9

Please sign in to comment.