Skip to content

Commit

Permalink
Canonicalize TT (#792)
Browse files Browse the repository at this point in the history
* Canonicalize TT

* fix

* attempt fix

* fix

* fix

* fix
  • Loading branch information
wsmoses authored May 1, 2023
1 parent d1498e3 commit 4a0cb2b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 18 deletions.
5 changes: 5 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ function EnzymeCheckedMergeTypeTree(dst, src)
end
EnzymeTypeTreeOnlyEq(dst, x) = ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64), dst, x)
EnzymeTypeTreeLookupEq(dst, x, dl) = ccall((:EnzymeTypeTreeLookupEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl)
EnzymeTypeTreeCanonicalizeInPlace(dst, x, dl) = ccall((:EnzymeTypeTreeCanonicalizeInPlace, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl)
EnzymeTypeTreeData0Eq(dst) = ccall((:EnzymeTypeTreeData0Eq, libEnzyme), Cvoid, (CTypeTreeRef,), dst)
EnzymeTypeTreeInner0(dst) = ccall((:EnzymeTypeTreeInner0, libEnzyme), CConcreteType, (CTypeTreeRef,), dst)
EnzymeTypeTreeShiftIndiciesEq(dst, dl, offset, maxSize, addOffset) =
Expand Down Expand Up @@ -456,6 +457,10 @@ function moveBefore(i1, i2, BR)
ccall((:EnzymeMoveBefore, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef), i1, i2, BR)
end

function EnzymeCloneFunctionDISubprogramInto(i1, i2)
ccall((:EnzymeCloneFunctionDISubprogramInto, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2)
end

function EnzymeCopyMetadata(i1, i2)
ccall((:EnzymeCopyMetadata, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2)
end
Expand Down
19 changes: 11 additions & 8 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6680,6 +6680,12 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C
op_idx = arg.codegen.i
rest = typetree(arg.typ, ctx, dl)
if arg.cc == GPUCompiler.BITS_REF
# adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader
# object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int}
# aka the next field after this in the bigger object isn't guaranteed to also be the same.
if allocatedinline(arg.typ)
shift!(rest, dl, 0, sizeof(arg.typ), 0)
end
merge!(rest, TypeTree(API.DT_Pointer, ctx))
only!(rest, -1)
else
Expand Down Expand Up @@ -7167,10 +7173,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
T_ret = returnRoots ? T_void : jltype
FT = LLVM.FunctionType(T_ret, T_wrapperargs)
llvm_f = LLVM.Function(mod, safe_name(LLVM.name(enzymefn)*"wrap"), FT)
sfn = LLVM.get_subprogram(enzymefn)
if sfn !== nothing
LLVM.set_subprogram!(llvm_f, sfn)
end
API.EnzymeCloneFunctionDISubprogramInto(llvm_f, enzymefn)
dl = datalayout(mod)

params = [parameters(llvm_f)...]
Expand Down Expand Up @@ -7253,8 +7256,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end

val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms)
if LLVM.get_subprogram(enzymefn) !== nothing
metadata(val)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(enzymefn) )
if LLVM.get_subprogram(llvm_f) !== nothing
metadata(val)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(llvm_f) )
end

if Mode == API.DEM_ReverseModePrimal
Expand Down Expand Up @@ -7284,8 +7287,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
push!(function_attributes(cf), EnumAttribute("alwaysinline", 0; ctx))
for shadowv in shadows
c = call!(builder, LLVM.function_type(cf), cf, [shadowv])
if LLVM.get_subprogram(enzymefn) !== nothing
metadata(c)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(enzymefn) )
if LLVM.get_subprogram(llvm_f) !== nothing
metadata(c)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(llvm_f) )
end
end
end
Expand Down
10 changes: 5 additions & 5 deletions src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ function data0!(tt::TypeTree)
API.EnzymeTypeTreeData0Eq(tt)
end

function canonicalize!(tt::TypeTree, size, dl)
API.EnzymeTypeTreeCanonicalizeInPlace(tt, size, dl)
end
function shift!(tt::TypeTree, dl, offset, maxSize, addOffset)
API.EnzymeTypeTreeShiftIndiciesEq(tt, dl, offset, maxSize, addOffset)
end
Expand All @@ -55,11 +58,7 @@ function merge!(dst::TypeTree, src::TypeTree; consume=true)
end

function typetree(::Type{T}, ctx, dl, seen=nothing) where T <: Integer
tt = TypeTree()
for i in 1:sizeof(T)
merge!(tt, TypeTree(API.DT_Integer, i-1, ctx))
end
return tt
return TypeTree(API.DT_Integer, -1, ctx)
end

function typetree(::Type{Float16}, ctx, dl, seen=nothing)
Expand Down Expand Up @@ -203,6 +202,7 @@ function typetree(@nospecialize(T), ctx, dl, seen=nothing)

merge!(tt, subtree)
end
canonicalize!(tt, sizeof(T), dl)
return tt
end

Expand Down
6 changes: 1 addition & 5 deletions test/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ end
@test tt(Symbol) == "{}"
@test tt(String) == "{}"
@test tt(AbstractChannel) == "{}"
if sizeof(Int) == sizeof(Int64)
@test tt(Base.ImmutableDict{Symbol, Any}) == "{[0]:Pointer, [8]:Pointer, [16]:Pointer}"
else
@test tt(Base.ImmutableDict{Symbol, Any}) == "{[0]:Pointer, [4]:Pointer, [8]:Pointer}"
end
@test tt(Base.ImmutableDict{Symbol, Any}) == "{[-1]:Pointer}"
@test tt(Atom) == "{[0]:Float@float, [4]:Float@float, [8]:Float@float, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer}"
@test tt(Composite) == "{[0]:Float@float, [4]:Float@float, [8]:Float@float, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@float, [20]:Float@float, [24]:Float@float, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer}"
@test tt(Tuple{Any,Any}) == "{[-1]:Pointer}"
Expand Down

0 comments on commit 4a0cb2b

Please sign in to comment.