Skip to content

Commit

Permalink
Merge branch 'main' into mhauru/distributions-integration-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru authored Dec 6, 2024
2 parents baf01e5 + 8178345 commit c14e27f
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3338,7 +3338,7 @@ function GPUCompiler.codegen(
if !has_fn_attr(f, EnumAttribute("alwaysinline"))
continue
end
if !has_fn_attr(f, EnumAttribute("returnstwice"))
if !has_fn_attr(f, EnumAttribute("returns_twice"))
push!(function_attributes(f), EnumAttribute("returns_twice"))
push!(toremove, name(f))
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
(Ptr{Cvoid}, Cstring, Ptr{Cvoid}),
arg1,
fname,
reinterpret(Ptr{Cvoid}, JIT.lookup(hnd).ptr),
pointer(JIT.lookup(hnd)),
)
else
res = ccall(
Expand All @@ -762,7 +762,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
(Ptr{Cvoid}, Cstring, Ptr{Cvoid}),
arg1,
fname,
reinterpret(Ptr{Cvoid}, JIT.lookup(hnd).ptr),
pointer(JIT.lookup(hnd)),
)
end
replaceWith = LLVM.ConstantInt(
Expand Down
82 changes: 58 additions & 24 deletions src/errors.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
# Global toggle consulted by the `showerror` methods in this file: when its value
# is `true` they also print extra debugging detail such as the "Current scope" IR.
const VERBOSE_ERRORS = Ref(false)

# Abstract supertype for the compilation-failure exception types declared below.
abstract type CompilationException <: Base.Exception end

# Exception whose message is held as a C string pointer (`Cstring`) rather than
# a Julia `String`; the text is decoded only when the error is displayed.
struct EnzymeRuntimeException <: Base.Exception
    msg::Cstring
end

# Display an `EnzymeRuntimeException`: a fixed header line followed by the
# message decoded from the stored C string and a trailing newline.
function Base.showerror(io::IO, ece::EnzymeRuntimeException)
    print(io, "Enzyme execution failed.\n")
    # Decode after the header has been written, keeping the original ordering.
    print(io, Base.unsafe_string(ece.msg), '\n')
    return nothing
end

struct NoDerivativeException <: CompilationException
msg::String
ir::Union{Nothing,String}
Expand All @@ -8,10 +21,22 @@ end
function Base.showerror(io::IO, ece::NoDerivativeException)
print(io, "Enzyme compilation failed.\n")
if ece.ir !== nothing
print(io, "Current scope: \n")
print(io, ece.ir)
if VERBOSE_ERRORS[]
print(io, "Current scope: \n")
print(io, ece.ir)
else
print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
end
end
# For "unknown binary operator" failures, show only the line(s) of the message
# that mention the operator; otherwise show the whole message.
if occursin("cannot handle unknown binary operator", ece.msg)
    # `split` is a free function in Julia. The previous `ece.msg.split('\n')`
    # is a field access on a String and throws instead of splitting.
    for msg in split(ece.msg, '\n')
        if occursin("cannot handle unknown binary operator", msg)
            # Write to the caller-supplied `io`; omitting it prints to stdout.
            print(io, '\n', msg, '\n')
        end
    end
else
    print(io, '\n', ece.msg, '\n')
end
print(io, '\n', ece.msg, '\n')
if ece.bt !== nothing
Base.show_backtrace(io, ece.bt)
println(io)
Expand All @@ -27,13 +52,18 @@ end

function Base.showerror(io::IO, ece::IllegalTypeAnalysisException)
print(io, "Enzyme compilation failed due to illegal type analysis.\n")
if ece.ir !== nothing
print(io, "Current scope: \n")
print(io, ece.ir)
print(io, " This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].\n")
print(io, " Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.\n")
print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
if VERBOSE_ERRORS[]
if ece.ir !== nothing
print(io, "Current scope: \n")
print(io, ece.ir)
end
print(io, "\n Type analysis state: \n")
write(io, ece.sval)
print(io, '\n', ece.msg, '\n')
end
print(io, "\n Type analysis state: \n")
write(io, ece.sval)
print(io, '\n', ece.msg, '\n')
if ece.bt !== nothing
print(io, "\nCaused by:")
Base.show_backtrace(io, ece.bt)
Expand All @@ -48,10 +78,14 @@ struct IllegalFirstPointerException <: CompilationException
end

function Base.showerror(io::IO, ece::IllegalFirstPointerException)
print(io, "Enzyme compilation failed.\n")
if ece.ir !== nothing
# Header and guidance lines for the first-pointer internal error.
print(io, "Enzyme compilation failed due to an internal error (first pointer exception).\n")
# This line now ends with a newline so the VERBOSE_ERRORS hint below starts on
# its own line instead of being appended to the URL.
print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl\n")
print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
# The IR dump is shown only in verbose mode, and only when IR was captured.
if VERBOSE_ERRORS[] && ece.ir !== nothing
    print(io, "Current scope: \n")
    print(io, ece.ir)
end
print(io, '\n', ece.msg, '\n')
if ece.bt !== nothing
Expand All @@ -67,28 +101,28 @@ struct EnzymeInternalError <: CompilationException
end

function Base.showerror(io::IO, ece::EnzymeInternalError)
print(io, "Enzyme compilation failed.\n")
if ece.ir !== nothing
# Header and guidance lines for a generic internal error.
print(io, "Enzyme compilation failed due to an internal error.\n")
# This line now ends with a newline so the VERBOSE_ERRORS hint below starts on
# its own line instead of being appended to the URL.
print(io, " Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl\n")
print(io, " To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)\n")
if VERBOSE_ERRORS[]
    # Verbose mode: dump the captured IR (if any) and the full message.
    if ece.ir !== nothing
        print(io, "Current scope: \n")
        print(io, ece.ir)
    end
    print(io, '\n', ece.msg, '\n')
else
    # Terse mode: surface only the message lines matching the pattern below.
    # `split` is a free function in Julia; `ece.msg.split('\n')` would throw.
    for msg in split(ece.msg, '\n')
        # NOTE(review): the spelling "ficticious" presumably mirrors the text of
        # the message being matched; confirm before "correcting" it.
        if occursin("Illegal replace ficticious phi for", msg)
            # Write to the caller-supplied `io`; omitting it prints to stdout.
            print(io, '\n', msg, '\n')
        end
    end
end
print(io, '\n', ece.msg, '\n')
if ece.bt !== nothing
Base.show_backtrace(io, ece.bt)
println(io)
end
end

# Exception whose message is held as a C string pointer (`Cstring`) rather than
# a Julia `String`; the text is decoded only when the error is displayed.
struct EnzymeRuntimeException <: Base.Exception
    msg::Cstring
end

# Display an `EnzymeRuntimeException`: a fixed header line followed by the
# message decoded from the stored C string and a trailing newline.
function Base.showerror(io::IO, ece::EnzymeRuntimeException)
    print(io, "Enzyme execution failed.\n")
    # Decode after the header has been written, keeping the original ordering.
    print(io, Base.unsafe_string(ece.msg), '\n')
    return nothing
end

# Exception type with a single field: the message, stored as a C string pointer.
struct EnzymeMutabilityException <: Base.Exception
msg::Cstring
end
Expand Down
32 changes: 32 additions & 0 deletions src/llvm/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,13 @@ function propagate_returned!(mod::LLVM.Module)
end
else
for u in LLVM.uses(un)
u = LLVM.user(u)
if u isa LLVM.CallInst
op = LLVM.called_operand(u)
if op isa LLVM.Function && LLVM.name(op) == "llvm.enzymefakeread"
continue
end
end
hasAnyUse = true
break
end
Expand Down Expand Up @@ -1611,6 +1618,12 @@ end

function delete_writes_into_removed_args(fn::LLVM.Function, toremove::Vector{Int64}, keepret::Bool)
args = collect(parameters(fn))
if !keepret
for u in LLVM.uses(fn)
u = LLVM.user(u)
replace_uses!(u, LLVM.UndefValue(value_type(u)))
end
end
for tr in toremove
tr = tr + 1
todorep = Tuple{LLVM.Instruction, LLVM.Value}[]
Expand Down Expand Up @@ -2038,6 +2051,25 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
if isempty(blocks(fn))
continue
end

rt = LLVM.return_type(LLVM.function_type(fn))
if rt isa LLVM.PointerType && addrspace(rt) == 10
for u in LLVM.uses(fn)
u = LLVM.user(u)
if isa(u, LLVM.CallInst)
B = IRBuilder()
nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
position!(B, nextInst)
cl = call!(B, funcT, rfunc, LLVM.Value[u])
LLVM.API.LLVMAddCallSiteAttribute(
cl,
LLVM.API.LLVMAttributeIndex(1),
EnumAttribute("nocapture"),
)
end
end
end

# Ensure that interprocedural optimizations do not delete the use of returnRoots (or shadows)
# if inactive sret, this will only occur on 2. If active sret, inactive retRoot, can on 3, and
# active both can occur on 4. If the original sret is removed (at index 1) we no longer need
Expand Down
80 changes: 80 additions & 0 deletions test/passes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Enzyme, LLVM, Test


# `inner` allocates an object and stores both arguments into it; `caller` loads
# through the returned pointer. After `removeDeadArgs!`, both stores must still
# be present in the entry block of `inner`.
@testset "Partial return preservation" begin
    LLVM.Context() do ctx
        # The IR text is kept flush-left so the literal is unchanged.
        ir_text = """
source_filename = "start"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-linux-gnu"
declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5
define internal fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30
%a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*
%a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1
store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8
%a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*
%a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8
%a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)*
store atomic {} addrspace(10)* %v2, {} addrspace(10)* addrspace(11)* %a35 release, align 8
ret {} addrspace(10)* %newstruct
}
define {} addrspace(10)* @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%ac = call fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2)
%b = addrspacecast {} addrspace(10)* %ac to {} addrspace(10)* addrspace(11)*
%c = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %b unordered, align 8
ret {} addrspace(10)* %c
}
attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" }
"""
        ir_module = parse(LLVM.Module, ir_text)

        Enzyme.Compiler.removeDeadArgs!(ir_module, Enzyme.Compiler.JIT.get_tm())

        # Count the store instructions left in the entry block of `inner`.
        inner_fn = LLVM.functions(ir_module)["inner"]
        entry_insts = collect(instructions(first(blocks(inner_fn))))
        store_total = count(inst -> inst isa LLVM.StoreInst, entry_insts)
        @test store_total == 2
    end
end


# `caller` invokes the allocating function but never uses its result. After
# `removeDeadArgs!`, the entry block of `caller` is expected to hold exactly one
# instruction.
@testset "Dead return removal" begin
    LLVM.Context() do ctx
        # The IR text is kept flush-left so the literal is unchanged.
        ir_text = """
source_filename = "start"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-linux-gnu"
declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5
define internal fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30
%a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*
%a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1
store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8
%a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*
%a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8
%a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)*
store atomic {} addrspace(10)* %v2, {} addrspace(10)* addrspace(11)* %a35 release, align 8
ret {} addrspace(10)* %newstruct
}
define void @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%ac = call fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2)
ret void
}
attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" }
"""
        ir_module = parse(LLVM.Module, ir_text)

        Enzyme.Compiler.removeDeadArgs!(ir_module, Enzyme.Compiler.JIT.get_tm())

        # Count the instructions remaining in the entry block of `caller`.
        caller_fn = LLVM.functions(ir_module)["caller"]
        entry_block = first(blocks(caller_fn))
        @test length(collect(instructions(entry_block))) == 1
    end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ end

include("abi.jl")
include("typetree.jl")
include("passes.jl")
include("optimize.jl")
include("make_zero.jl")

Expand Down

0 comments on commit c14e27f

Please sign in to comment.