From 4a0cb2b2a1bc71e7aaab641ccbb66126c7055fc1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 1 May 2023 01:47:50 -0400 Subject: [PATCH] Canonicalize TT (#792) * Canonicalize TT * fix * attempt fix * fix * fix * fix --- src/api.jl | 5 +++++ src/compiler.jl | 19 +++++++++++-------- src/typetree.jl | 10 +++++----- test/typetree.jl | 6 +----- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/api.jl b/src/api.jl index 301eddf545..49952b0b3e 100644 --- a/src/api.jl +++ b/src/api.jl @@ -70,6 +70,7 @@ function EnzymeCheckedMergeTypeTree(dst, src) end EnzymeTypeTreeOnlyEq(dst, x) = ccall((:EnzymeTypeTreeOnlyEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64), dst, x) EnzymeTypeTreeLookupEq(dst, x, dl) = ccall((:EnzymeTypeTreeLookupEq, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl) +EnzymeTypeTreeCanonicalizeInPlace(dst, x, dl) = ccall((:EnzymeTypeTreeCanonicalizeInPlace, libEnzyme), Cvoid, (CTypeTreeRef, Int64, Cstring), dst, x, dl) EnzymeTypeTreeData0Eq(dst) = ccall((:EnzymeTypeTreeData0Eq, libEnzyme), Cvoid, (CTypeTreeRef,), dst) EnzymeTypeTreeInner0(dst) = ccall((:EnzymeTypeTreeInner0, libEnzyme), CConcreteType, (CTypeTreeRef,), dst) EnzymeTypeTreeShiftIndiciesEq(dst, dl, offset, maxSize, addOffset) = @@ -456,6 +457,10 @@ function moveBefore(i1, i2, BR) ccall((:EnzymeMoveBefore, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef, LLVM.API.LLVMBuilderRef), i1, i2, BR) end +function EnzymeCloneFunctionDISubprogramInto(i1, i2) + ccall((:EnzymeCloneFunctionDISubprogramInto, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2) +end + function EnzymeCopyMetadata(i1, i2) ccall((:EnzymeCopyMetadata, libEnzyme),Cvoid,(LLVM.API.LLVMValueRef,LLVM.API.LLVMValueRef), i1, i2) end diff --git a/src/compiler.jl b/src/compiler.jl index c0764eb1ab..ba0222dedb 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -6680,6 +6680,12 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C op_idx = arg.codegen.i rest = typetree(arg.typ, ctx, dl) if arg.cc == GPUCompiler.BITS_REF + # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader + # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} + # aka the next field after this in the bigger object isn't guaranteed to also be the same. + if allocatedinline(arg.typ) + shift!(rest, dl, 0, sizeof(arg.typ), 0) + end merge!(rest, TypeTree(API.DT_Pointer, ctx)) only!(rest, -1) else @@ -7167,10 +7173,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, T_ret = returnRoots ? T_void : jltype FT = LLVM.FunctionType(T_ret, T_wrapperargs) llvm_f = LLVM.Function(mod, safe_name(LLVM.name(enzymefn)*"wrap"), FT) - sfn = LLVM.get_subprogram(enzymefn) - if sfn !== nothing - LLVM.set_subprogram!(llvm_f, sfn) - end + API.EnzymeCloneFunctionDISubprogramInto(llvm_f, enzymefn) dl = datalayout(mod) params = [parameters(llvm_f)...] @@ -7253,8 +7256,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) - if LLVM.get_subprogram(enzymefn) !== nothing - metadata(val)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(enzymefn) ) + if LLVM.get_subprogram(llvm_f) !== nothing + metadata(val)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(llvm_f) ) end if Mode == API.DEM_ReverseModePrimal @@ -7284,8 +7287,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(function_attributes(cf), EnumAttribute("alwaysinline", 0; ctx)) for shadowv in shadows c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) - if LLVM.get_subprogram(enzymefn) !== nothing - metadata(c)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(enzymefn) ) + if LLVM.get_subprogram(llvm_f) !== nothing + metadata(c)[LLVM.MD_dbg] = DILocation(ctx, 0, 0, LLVM.get_subprogram(llvm_f) ) end end end diff --git a/src/typetree.jl b/src/typetree.jl index d579850b21..e0a2e77ea7 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -44,6 +44,9 @@ function data0!(tt::TypeTree) API.EnzymeTypeTreeData0Eq(tt) end +function canonicalize!(tt::TypeTree, size, dl) + API.EnzymeTypeTreeCanonicalizeInPlace(tt, size, dl) +end function shift!(tt::TypeTree, dl, offset, maxSize, addOffset) API.EnzymeTypeTreeShiftIndiciesEq(tt, dl, offset, maxSize, addOffset) end @@ -55,11 +58,7 @@ function merge!(dst::TypeTree, src::TypeTree; consume=true) end function typetree(::Type{T}, ctx, dl, seen=nothing) where T <: Integer - tt = TypeTree() - for i in 1:sizeof(T) - merge!(tt, TypeTree(API.DT_Integer, i-1, ctx)) - end - return tt + return TypeTree(API.DT_Integer, -1, ctx) end function typetree(::Type{Float16}, ctx, dl, seen=nothing) @@ -203,6 +202,7 @@ function typetree(@nospecialize(T), ctx, dl, seen=nothing) merge!(tt, subtree) end + canonicalize!(tt, sizeof(T), dl) return tt end diff --git a/test/typetree.jl b/test/typetree.jl index 174ea3ff5b..5df21c10ed 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -28,11 +28,7 @@ end @test tt(Symbol) == "{}" @test tt(String) == "{}" @test tt(AbstractChannel) == "{}" - if sizeof(Int) == sizeof(Int64) - @test tt(Base.ImmutableDict{Symbol, Any}) == "{[0]:Pointer, [8]:Pointer, [16]:Pointer}" - else - @test tt(Base.ImmutableDict{Symbol, Any}) == "{[0]:Pointer, [4]:Pointer, [8]:Pointer}" - end + @test tt(Base.ImmutableDict{Symbol, Any}) == "{[-1]:Pointer}" @test tt(Atom) == "{[0]:Float@float, [4]:Float@float, [8]:Float@float, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer}" @test tt(Composite) == "{[0]:Float@float, [4]:Float@float, [8]:Float@float, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@float, [20]:Float@float, [24]:Float@float, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer}" @test tt(Tuple{Any,Any}) == "{[-1]:Pointer}"