From 4f160a0de183a88eaacdff1944bd6dc5a509d72a Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 10:06:25 -0600 Subject: [PATCH 1/8] Complex bessel (#2179) * Complex bessel * fix * more tests * Update EnzymeSpecialFunctionsExt.jl * Update Project.toml --- Project.toml | 2 +- ext/EnzymeSpecialFunctionsExt.jl | 3 +++ src/compiler.jl | 5 +++++ test/ext/specialfunctions.jl | 10 ++++------ 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 46b37ccb12..fa5a2c7e99 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.8" -Enzyme_jll = "0.0.167" +Enzyme_jll = "0.0.168" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/ext/EnzymeSpecialFunctionsExt.jl b/ext/EnzymeSpecialFunctionsExt.jl index 65d87dc118..09e62e98c4 100644 --- a/ext/EnzymeSpecialFunctionsExt.jl +++ b/ext/EnzymeSpecialFunctionsExt.jl @@ -5,6 +5,9 @@ using Enzyme function __init__() Enzyme.Compiler.known_ops[typeof(SpecialFunctions._logabsgamma)] = (:logabsgamma, 1, (:digamma, typeof(SpecialFunctions.digamma))) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.bessely)] = (:cmplx_jn, 2, nothing) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselj)] = (:cmplx_jn, 2, nothing) + Enzyme.Compiler.cmplx_known_ops[typeof(SpecialFunctions.besselk)] = (:cmplx_kn, 2, nothing) end end diff --git a/src/compiler.jl b/src/compiler.jl index 2f4cbfc8da..5fcc53dbde 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -146,6 +146,11 @@ const known_ops = Dict{DataType,Tuple{Symbol,Int,Union{Nothing,Tuple{Symbol,Data name, arity, toinject = cmplx_known_ops[func] Tys = (Complex{Float32}, Complex{Float64}) if length(sparam_vals) == arity + if name == :cmplx_jn || name == :cmplx_yn + if (sparam_vals[2] ∈ Tys) && sparam_vals[2].parameters[1] == sparam_vals[1] + return name, toinject, sparam_vals[2] + end + end T = first(sparam_vals) if (T isa Type) T = T::Type diff --git a/test/ext/specialfunctions.jl b/test/ext/specialfunctions.jl index 1a87cf2d2b..a64c214489 100644 --- a/test/ext/specialfunctions.jl +++ b/test/ext/specialfunctions.jl @@ -16,11 +16,9 @@ using SpecialFunctions # test_scalar(SpecialFunctions.airyaiprime, x) # test_scalar(SpecialFunctions.airybi, x) # test_scalar(SpecialFunctions.airybiprime, x) - if x isa Real - test_scalar(SpecialFunctions.besselj0, x) - test_scalar(SpecialFunctions.besselj1, x) - test_scalar((y) -> SpecialFunctions.besselj(2, y), x) - end + test_scalar(SpecialFunctions.besselj0, x) + test_scalar(SpecialFunctions.besselj1, x) + test_scalar((y) -> SpecialFunctions.besselj(2, y), x) # test_scalar((y) -> SpecialFunctions.sphericalbessely(y, 0.5), 0.3) # test_scalar(SpecialFunctions.dawson, x) @@ -36,7 +34,7 @@ using SpecialFunctions # test_scalar(SpecialFunctions.erfcinv, x) end - if x isa Real && x > 0 + if !(x isa Real) || x > 0 test_scalar(SpecialFunctions.bessely0, x) test_scalar(SpecialFunctions.bessely1, x) test_scalar((y) -> SpecialFunctions.bessely(2, y), x) From 66ded5f20ecb69ad146dce90d25253a8b686fc84 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 13:04:33 -0600 Subject: [PATCH 2/8] workaround i1 issue in llvm.jl (#2181) --- src/absint.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/absint.jl b/src/absint.jl index 3b9034bef6..50282e745c 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -205,7 +205,11 @@ end end function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool - sz = sizeof(dl, arg_t) + sz = if arg_t == LLVM.IntType(1) + 1 + else + sizeof(dl, arg_t) + end if byref != GPUCompiler.BITS_VALUE if sz != sizeof(Int) throw(AssertionError("non bits type $byref of $typ2 has size $sz != sizeof(Int) from arg type $arg_t")) From 3b36ea25157efd37ebe63d1e025decb6defaeb43 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 13:04:57 -0600 Subject: [PATCH 3/8] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fa5a2c7e99..f7bcad34fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.19" +version = "0.13.20" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 2fa5bb1352e9771d5f42ae5e054dd459e8af5409 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Dec 2024 19:19:44 -0600 Subject: [PATCH 4/8] Nofree for math methods (#2184) * Nofree for math methods * fix --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 5fcc53dbde..0155e5da34 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3961,6 +3961,9 @@ end lowerConvention = false end k_name = LLVM.name(llvmfn) + if !has_fn_attr(llvmfn, EnumAttribute("nofree")) + push!(LLVM.function_attributes(llvmfn), EnumAttribute("nofree")) + end end name = string(name) From 865cced8bf96e6300663fe8d8775637957ec056f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 00:02:44 -0600 Subject: [PATCH 5/8] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f7bcad34fc..940446335e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.20" +version = "0.13.21" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" From 3edec409c4e43590320df0b02a3463a24638ae0e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 00:03:35 -0600 Subject: [PATCH 6/8] Fix higher order codegen (#2161) * Fix higher order codegen * fix * fix * working * Update validation.jl * handle, again * Update validation.jl --- src/compiler.jl | 26 +++- src/compiler/interpreter.jl | 25 +--- src/compiler/validation.jl | 231 ++++++------------------------------ src/llvm/transforms.jl | 188 +++++++++++++++++++++++++++++ src/rules/parallelrules.jl | 4 +- 5 files changed, 245 insertions(+), 229 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0155e5da34..36d9c1473d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5226,12 +5226,12 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType)) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), - TapeType, + TapeType ) end @@ -5269,7 +5269,7 @@ end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool = true) +function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String} mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf @@ -5287,7 +5287,12 @@ function _thunk(job, postopt::Bool = true) end # Run post optimization pipeline - if postopt + prepost = if postopt + mstr = if job.config.params.ABI <: InlineABI + "" + else + string(mod) + end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) if DumpPostOpt[] @@ -5296,12 +5301,17 @@ function _thunk(job, postopt::Bool = true) else propagate_returned!(mod) end + mstr + else + "" end - return (mod, adjoint_name, primal_name, meta.TapeType) + return (mod, adjoint_name, primal_name, meta.TapeType, prepost) end const cache = Dict{UInt,CompileResult}() +const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}() + const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult key = hash(job) @@ -5313,6 +5323,12 @@ const cache_lock = ReentrantLock() if obj === nothing asm = _thunk(job) obj = _link(job, asm...) + if obj.adjoint isa Ptr{Nothing} + autodiff_cache[obj.adjoint] = (asm[2], asm[5]) + end + if obj.primal isa Ptr{Nothing} && asm[3] isa String + autodiff_cache[obj.primal] = (asm[3], asm[5]) + end cache[key] = obj end obj diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2d02604eda..2f9d1fbf60 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -44,7 +44,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter forward_rules::Bool reverse_rules::Bool - deferred_lower::Bool broadcast_rewrite::Bool handler::T end @@ -55,7 +54,6 @@ function EnzymeInterpreter( world::UInt, forward_rules::Bool, reverse_rules::Bool, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing ) @@ -83,7 +81,6 @@ function EnzymeInterpreter( IdDict{Any, Bool}(), forward_rules, reverse_rules, - deferred_lower, broadcast_rewrite, handler ) @@ -94,10 +91,9 @@ EnzymeInterpreter( mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler) Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params @@ -865,25 +861,6 @@ function abstract_call_known( end end - if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4 - if widenconst(argtypes[2]) <: Enzyme.Mode && - widenconst(argtypes[3]) <: Enzyme.Annotation && - widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : - [:(Enzyme.autodiff_deferred), fargs[2:end]...], - [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], - ) - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - Enzyme.autodiff_deferred::Any, - arginfo2::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) - end - end if interp.handler != nothing return interp.handler(interp, f, arginfo, si, sv, max_methods) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index e109415d0f..525e4d874c 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -129,9 +129,7 @@ end function memoize!(ptr::Ptr{Cvoid}, fn::String)::String fn = get(ptr_map, ptr, fn) - if !haskey(ptr_map, ptr) - ptr_map[ptr] = fn - else + if haskey(ptr_map, ptr) @assert ptr_map[ptr] == fn end return fn @@ -185,194 +183,6 @@ function check_ir(@nospecialize(job::CompilerJob), mod::LLVM.Module) end end -# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA -function rewrite_ccalls!(mod::LLVM.Module) - for f in collect(functions(mod)) - replaceAndErase = Tuple{Instruction,Instruction}[] - for bb in blocks(f), inst in instructions(bb) - if isa(inst, LLVM.CallInst) - fn = called_operand(inst) - changed = false - B = IRBuilder() - position!(B, inst) - if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" - uservals = LLVM.Value[] - for lval in collect(arguments(inst)) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - changed = true - end - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - if !isdefined(LLVM, :OperandBundleDef) - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(operand_bundles(inst)), - prevname, - ) - else - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), - prevname, - ) - end - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - continue - end - if !isdefined(LLVM, :OperandBundleDef) - newbundles = OperandBundle[] - else - newbundles = OperandBundleDef[] - end - for bunduse in operand_bundles(inst) - if isdefined(LLVM, :OperandBundleDef) - bunduse = LLVM.OperandBundleDef(bunduse) - end - - if !isdefined(LLVM, :OperandBundleDef) - if LLVM.tag(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - else - if LLVM.tag_name(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - end - uservals = LLVM.Value[] - subchanged = false - for lval in LLVM.inputs(bunduse) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - subchanged = true - end - if !subchanged - push!(newbundles, bunduse) - continue - end - changed = true - if !isdefined(LLVM, :OperandBundleDef) - push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) - else - push!( - newbundles, - OperandBundleDef(LLVM.tag_name(bunduse), uservals), - ) - end - end - changed = false - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - newinst = call!( - B, - called_type(inst), - called_operand(inst), - collect(arguments(inst)), - newbundles, - prevname, - ) - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - end - end - for (inst, newinst) in replaceAndErase - replace_uses!(inst, newinst) - LLVM.API.LLVMInstructionEraseFromParent(inst) - end - end -end - function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod::LLVM.Module) imported = Set(String[]) if haskey(functions(mod), "malloc") @@ -390,14 +200,14 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) eraseInst(mod, f) end - rewrite_ccalls!(mod) + Compiler.rewrite_ccalls!(mod) del = LLVM.Function[] for f in collect(functions(mod)) if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -408,7 +218,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -417,7 +227,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod return errors end -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}, mod::LLVM.Module) calls = LLVM.CallInst[] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 mod = LLVM.parent(f) @@ -643,7 +453,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp while length(calls) > 0 inst = pop!(calls) - check_ir!(job, errors, imported, inst, calls) + check_ir!(job, errors, imported, inst, calls, mod) end return errors end @@ -690,7 +500,7 @@ end import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}, mod::LLVM.Module) world = job.world interp = GPUCompiler.get_interpreter(job) method_table = Core.Compiler.method_table(interp) @@ -1211,13 +1021,36 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp ptr_val = convert(Int, ptr_arg) ptr = Ptr{Cvoid}(ptr_val) + if haskey(autodiff_cache, ptr) + pname, pmod = autodiff_cache[ptr] + + @assert !haskey(functions(mod), pname) + + pmod = parse(LLVM.Module, pmod) + + @assert haskey(functions(pmod), pname) + + for fn in functions(pmod) + if !isempty(LLVM.blocks(fn)) + linkage!(fn, LLVM.name(fn) != pname ? LLVM.API.LLVMInternalLinkage : LLVM.API.LLVMExternalLinkage) + end + end + + GPUCompiler.link_library!(mod, pmod) + + replaceWith = functions(mod)[pname] + push!(function_attributes(replaceWith), EnumAttribute("alwaysinline")) + linkage!(functions(mod)[pname], LLVM.API.LLVMInternalLinkage) + replace_uses!(ptr_arg, LLVM.const_pointercast(replaceWith, value_type(ptr_arg))) + return errors + end + # look it up in the Julia JIT cache frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint), ptr, 0) if length(frames) >= 1 fn, file, line, linfo, fromC, inlined = last(frames) - # Remember pointer in our global map fn = FFI.memoize!(ptr, string(fn)) if length(fn) > 1 && fromC @@ -1229,6 +1062,8 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp fn, LLVM.API.LLVMGetCalledFunctionType(inst), ) + # Remember pointer for subsequent restoration + push!(function_attributes(LLVM.Function(lfn)), StringAttribute("enzymejl_needs_restoration", string(reinterpret(UInt, ptr)))) else lfn = LLVM.API.LLVMConstBitCast( lfn, diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index aebb8bab5c..2f9c61c0b4 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -1,4 +1,192 @@ +# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA +function rewrite_ccalls!(mod::LLVM.Module) + for f in collect(functions(mod)) + replaceAndErase = Tuple{Instruction,Instruction}[] + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.CallInst) + fn = called_operand(inst) + changed = false + B = IRBuilder() + position!(B, inst) + if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" + uservals = LLVM.Value[] + for lval in collect(arguments(inst)) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + changed = true + end + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + if !isdefined(LLVM, :OperandBundleDef) + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(operand_bundles(inst)), + prevname, + ) + else + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), + prevname, + ) + end + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + continue + end + if !isdefined(LLVM, :OperandBundleDef) + newbundles = OperandBundle[] + else + newbundles = OperandBundleDef[] + end + for bunduse in operand_bundles(inst) + if isdefined(LLVM, :OperandBundleDef) + bunduse = LLVM.OperandBundleDef(bunduse) + end + + if !isdefined(LLVM, :OperandBundleDef) + if LLVM.tag(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + else + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + end + uservals = LLVM.Value[] + subchanged = false + for lval in LLVM.inputs(bunduse) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + subchanged = true + end + if !subchanged + push!(newbundles, bunduse) + continue + end + changed = true + if !isdefined(LLVM, :OperandBundleDef) + push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) + else + push!( + newbundles, + OperandBundleDef(LLVM.tag_name(bunduse), uservals), + ) + end + end + changed = false + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!( + B, + called_type(inst), + called_operand(inst), + collect(arguments(inst)), + newbundles, + prevname, + ) + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + end + end + for (inst, newinst) in replaceAndErase + replace_uses!(inst, newinst) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end +end + function force_recompute!(mod::LLVM.Module) for f in functions(mod), bb in blocks(f) iter = LLVM.API.LLVMGetFirstInstruction(bb) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 78c9cd9ce8..d4356aba61 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -275,7 +275,7 @@ end world, ) - cmod, fwdmodenm, _, _ = _thunk(ejob, false) #=postopt=# + cmod, fwdmodenm, _, _, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) @@ -334,7 +334,7 @@ end world, ) - cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, false) #=postopt=# + cmod, adjointnm, augfwdnm, TapeType, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) From 551ddd1ca94f1a2f7b1fe018e6415af191115ec9 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 00:59:01 -0600 Subject: [PATCH 7/8] Update errors.jl --- src/errors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index 4f7b6f9480..83946f1dbd 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -29,7 +29,7 @@ function Base.showerror(io::IO, ece::NoDerivativeException) end end if occursin("cannot handle unknown binary operator", ece.msg) - for msg in ece.msg.split('\n') + for msg in split(ece.msg, '\n') if occursin("cannot handle unknown binary operator", msg) print('\n', msg, '\n') end @@ -111,7 +111,7 @@ function Base.showerror(io::IO, ece::EnzymeInternalError) end print(io, '\n', ece.msg, '\n') else - for msg in ece.msg.split('\n') + for msg in split(ece.msg, '\n') if occursin("Illegal replace ficticious phi for", msg) print('\n', msg, '\n') end From 8e10a0a37d42db7ced533461bdc3d986ce22e3af Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 11:05:03 -0600 Subject: [PATCH 8/8] World backedge holder (#2183) * World backedge holder * fix * fixup * Update compiler.jl * Update compiler.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * Update interpreter.jl * try2 * nothing works rip * more test * hn * keep trying * more * fix * fix * isapplic * fix2 * mark broken --- src/compiler.jl | 17 +++- src/compiler/interpreter.jl | 176 ++++++++++++++++++++++++++++++++---- test/ruleinvalidation.jl | 7 +- 3 files changed, 181 insertions(+), 19 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 36d9c1473d..ddccde1a24 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5573,7 +5573,22 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: # new_ci.min_world = min_world[] new_ci.min_world = world new_ci.max_world = max_world[] - new_ci.edges = Core.MethodInstance[mi] + + edges = Core.MethodInstance[mi] + + if Mode == API.DEM_ForwardMode + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, world)) + Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) + else + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, world)) + end + + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, world)) + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{Val{0}}, world)) + Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(Val(0))) + + new_ci.edges = edges + # XXX: setting this edge does not give us proper method invalidation, see # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. # invoking `code_llvm` also does the necessary codegen, as does calling the diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2f9d1fbf60..bd57ec92dd 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -8,6 +8,7 @@ using Core.Compiler: OptimizationParams, MethodInstance using GPUCompiler: @safe_debug +using GPUCompiler if VERSION < v"1.11.0-DEV.1552" using GPUCompiler: CodeCache, WorldView, @safe_debug end @@ -23,6 +24,141 @@ else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end +function rule_backedge_holder_generator(world::UInt, source, self, ft::Type) + @nospecialize + sig = Tuple{typeof(Base.identity), Int} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + has_ambig = Ptr{Int32}(C_NULL) + mthds = Base._methods_by_ftype( + sig, + nothing, + -1, #=lim=# + world, + false, #=ambig=# + min_world, + max_world, + has_ambig, + ) + mtypes, msp, m = mthds[1] + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + m, + mtypes, + msp, + ) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + + ### TODO: backedge from inactive, augmented_primal, forward, reverse + edges = Any[] + + @static if false + if ft == typeof(EnzymeRules.augmented_primal) + # this is illegal + # sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.augmented_primal), Tuple{<:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world)) + elseif ft == typeof(EnzymeRules.forward) + # this is illegal + # sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.forward), Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world)) + else + # sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive), Tuple{Vararg{Annotation}}, world)) + + # sig = Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Annotation}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_noinl), Tuple{Vararg{Annotation}}, world)) + + # sig = Tuple{typeof(EnzymeRules.noalias), Vararg{Any}} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.noalias), Tuple{Vararg{Any}}, world)) + + # sig = Tuple{typeof(EnzymeRules.inactive_type), Type} + # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) + push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_type), Tuple{Type}, world)) + end + end + + new_ci.edges = edges + + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. + + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :ft] + new_ci.slotflags = UInt8[0x00 for i = 1:2] + + # return the codegen world age + push!(new_ci.code, Core.Compiler.ReturnNode(0)) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below + end + new_ci.ssavaluetypes += 1 + + return new_ci +end + +@eval Base.@assume_effects :removable :foldable :nothrow @inline function rule_backedge_holder(ft) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, rule_backedge_holder_generator)) +end + +begin + # Forward-rule catch all + fwd_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}) + # Reverse-rule catch all + rev_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}) + # Inactive-rule catch all + ina_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}) + # All other derivative-related catch all (just for autodiff, not inference), including inactive_noinl, noalias, and inactive_type + gen_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{Val{0}}) + + + fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + EnzymeRules.add_mt_backedge!(fwd_rule_be, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable, fwd_sig) + + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + EnzymeRules.add_mt_backedge!(rev_rule_be, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable, rev_sig) + + + for ina_sig in ( + Tuple{typeof(EnzymeRules.inactive), Vararg{Any}}, + ) + EnzymeRules.add_mt_backedge!(ina_rule_be, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable, ina_sig) + end + + for gen_sig in ( + Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, + Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_type), Type}, + ) + EnzymeRules.add_mt_backedge!(gen_rule_be, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable, gen_sig) + end +end + struct EnzymeInterpreter{T} <: AbstractInterpreter @static if HAS_INTEGRATED_CACHE token::Any @@ -40,8 +176,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams - rules_cache::IdDict{Any, Bool} - forward_rules::Bool reverse_rules::Bool broadcast_rewrite::Bool @@ -78,7 +212,6 @@ function EnzymeInterpreter( # parameters for inference and optimization parms, OptimizationParams(), - IdDict{Any, Bool}(), forward_rules, reverse_rules, broadcast_rewrite, @@ -99,6 +232,7 @@ Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache + @static if HAS_INTEGRATED_CACHE Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token else @@ -221,25 +355,35 @@ function Core.Compiler.abstract_call_gf_by_type( elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) else - # 1. Check if function is inactive - if is_inactive_from_sig(interp, specTypes, sv) + method_table = Core.Compiler.method_table(interp) + if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else - # 2. Check if rule is defined - has_rule = get!(interp.rules_cache, specTypes) do - if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) - return true - elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) - return true - else - return false + if interp.forward_rules + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :frule) + end + end + + if interp.reverse_rules + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end - if has_rule - callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) - end end + + if interp.forward_rules + Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, interp.world)::Core.MethodInstance) + Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) + end + if interp.reverse_rules + Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance) + Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.augmented_primal)) + end + Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, interp.world)::Core.MethodInstance) + Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(typeof(EnzymeRules.inactive))) end + @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) else diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 37cb21b08f..501b0aac10 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -42,6 +42,9 @@ end # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 - +@static if VERSION < v"1.11-" + @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +else + @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +end end # module