diff --git a/src/absint.jl b/src/absint.jl index ae9c35a09b..10dc024013 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -323,12 +323,17 @@ function abs_typeof(arg::LLVM.Value, partial::Bool=false)::Union{Tuple{Bool, Typ end function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} - if isa(arg, ConstantExpr) ce = arg while isa(ce, ConstantExpr) if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr ce = operands(ce)[1] + elseif opcode(ce) == LLVM.API.LLVMGetElementPtr + if all(x -> isa(x, LLVM.ConstantInt) && convert(UInt, x) == 0, operands(ce)[2:end]) + ce = operands(ce)[1] + else + break + end else break end @@ -336,7 +341,7 @@ function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String} if isa(ce, LLVM.GlobalVariable) ce = LLVM.initializer(ce) if (isa(ce, LLVM.ConstantArray) || isa(ce, LLVM.ConstantDataArray)) && eltype(value_type(ce)) == LLVM.IntType(8) - return (true, String(map((x)->convert(UInt8, x), collect(flib)[1:(end-1)]))) + return (true, String(map((x)->convert(UInt8, x), collect(ce)[1:(end-1)]))) end end diff --git a/src/compiler.jl b/src/compiler.jl index f2d1829571..3f57743cc0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3238,7 +3238,11 @@ function annotate!(mod, mode) for fname in ("julia.typeof",) if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readnone", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end @@ -3246,15 +3250,18 @@ function annotate!(mod, mode) for fname in ("jl_excstack_state","ijl_excstack_state") if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.StringAttribute("inaccessiblememonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end for fname in ("jl_types_equal", "ijl_types_equal") if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end @@ -3278,7 +3285,12 @@ function annotate!(mod, mode) if operands(c)[1] != fn continue end - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("readonly", 0)) + attr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("readonly") + else + EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + end + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) end end end @@ -3287,7 +3299,11 @@ function annotate!(mod, mode) if haskey(fns, fname) fn = fns[fname] # TODO per discussion w keno perhaps this should change to readonly / inaccessiblememonly - push!(function_attributes(fn), LLVM.EnumAttribute("readnone", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end push!(function_attributes(fn), LLVM.StringAttribute("enzyme_shouldrecompute")) end end @@ -3320,7 +3336,11 @@ function annotate!(mod, mode) for fname in ("julia.pointer_from_objref",) if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readnone", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end end end @@ -3336,8 +3356,13 @@ function annotate!(mod, mode) fn = fns[boxfn] push!(return_attributes(fn), LLVM.EnumAttribute("noalias", 0)) push!(function_attributes(fn), no_escaping_alloc) + accattr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("inaccessiblememonly") + else + EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + end if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + push!(function_attributes(fn), accattr) end for u in LLVM.uses(fn) c = LLVM.user(u) @@ -3348,7 +3373,7 @@ function annotate!(mod, mode) if cf == fn LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), accattr) end end if !isa(cf, LLVM.Function) @@ -3363,7 +3388,12 @@ function annotate!(mod, mode) LLVM.API.LLVMAddCallSiteAttribute(c, LLVM.API.LLVMAttributeReturnIndex, LLVM.EnumAttribute("noalias", 0)) LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), no_escaping_alloc) if !(boxfn in ("jl_array_copy", "ijl_array_copy", "jl_idtable_rehash", "ijl_idtable_rehash")) - LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), LLVM.EnumAttribute("inaccessiblememonly", 0)) + attr = if LLVM.version().major <= 15 + LLVM.EnumAttribute("inaccessiblememonly") + else + EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data) + end + LLVM.API.LLVMAddCallSiteAttribute(c, reinterpret(LLVM.API.LLVMAttributeIndex, LLVM.API.LLVMAttributeFunctionIndex), attr) end end end @@ -3372,14 +3402,22 @@ function annotate!(mod, mode) for gc in ("llvm.julia.gc_preserve_begin", "llvm.julia.gc_preserve_end") if haskey(fns, gc) fn = fns[gc] - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end for rfn in ("jl_object_id_", "jl_object_id", "ijl_object_id_", "ijl_object_id") if haskey(fns, rfn) fn = fns[rfn] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("memory", NoEffects.data)) + end end end @@ -3388,8 +3426,12 @@ function annotate!(mod, mode) if haskey(fns, rfn) fn = fns[rfn] push!(parameter_attributes(fn, 2), LLVM.StringAttribute("enzyme_inactive")) - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end # Key of jl_eqtable_get/put is inactive, definitionally @@ -3400,15 +3442,23 @@ function annotate!(mod, mode) push!(parameter_attributes(fn, 4), LLVM.StringAttribute("enzyme_inactive")) push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("writeonly")) push!(parameter_attributes(fn, 4), LLVM.EnumAttribute("nocapture")) - push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("argmemonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end for rfn in ("jl_in_threaded_region_", "jl_in_threaded_region") if haskey(fns, rfn) fn = fns[rfn] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) - push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly")) + push!(function_attributes(fn), LLVM.EnumAttribute("inaccessiblememonly")) + else + push!(function_attributes(fn), EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data)) + end end end end @@ -4893,17 +4943,26 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if kind(prev) == kind(StringAttribute("enzyme_shouldrecompute")) push!(attributes, prev) end - if kind(prev) == kind(EnumAttribute("readonly")) - push!(attributes, prev) - end - if kind(prev) == kind(EnumAttribute("readnone")) - push!(attributes, prev) + if LLVM.version().major <= 15 + if kind(prev) == kind(EnumAttribute("readonly")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("readnone")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("argmemonly")) + push!(attributes, prev) + end + if kind(prev) == kind(EnumAttribute("inaccessiblememonly")) + push!(attributes, prev) + end end - if kind(prev) == kind(EnumAttribute("argmemonly")) - push!(attributes, prev) + if LLVM.version().major > 15 + if kind(prev) == kind(EnumAttribute("memory")) + old = MemoryEffect(value(attr)) + mem = MemoryEffect(( set_writing(getModRef(old, ArgMem)) << getLocationPos(ArgMem)) | (getModRef(old, InaccessibleMem) << getLocationPos(InaccessibleMem)) | (getModRef(old, Other) << getLocationPos(Other))) + push!(attributes, EnumAttribute("memory", mem.data)) end - if kind(prev) == kind(EnumAttribute("inaccessiblememonly")) - push!(attributes, prev) end if kind(prev) == kind(EnumAttribute("speculatable")) push!(attributes, prev) @@ -5382,44 +5441,85 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals if func == typeof(Base.eps) || func == typeof(Base.nextfloat) || func == typeof(Base.prevfloat) - handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), - EnumAttribute("readnone", 0), - EnumAttribute("speculatable", 0), + if LLVM.version().major <= 15 + handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), + EnumAttribute("readnone"), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute") + ]) + else + handleCustom(llvmfn, "jl_inactive_inout", [StringAttribute("enzyme_inactive"), + EnumAttribute("memory", NoEffects.data), + EnumAttribute("speculatable"), StringAttribute("enzyme_shouldrecompute") ]) + end continue end if func == typeof(Base.to_tuple_type) - handleCustom(llvmfn, "jl_to_tuple_type", - [EnumAttribute("readonly", 0), - EnumAttribute("inaccessiblememonly", 0), - EnumAttribute("speculatable", 0), - StringAttribute("enzyme_shouldrecompute"), - StringAttribute("enzyme_inactive"), - ]) + if LLVM.version().major <= 15 + handleCustom(llvmfn, "jl_to_tuple_type", + [EnumAttribute("readonly"), + EnumAttribute("inaccessiblememonly", 0), + EnumAttribute("speculatable", 0), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + ]) + else + handleCustom(llvmfn, "jl_to_tuple_type", + [ + EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + EnumAttribute("inaccessiblememonly", 0), + EnumAttribute("speculatable", 0), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + ]) + end continue end if func == typeof(Base.mightalias) - handleCustom(llvmfn, "jl_mightalias", - [EnumAttribute("readonly", 0), - StringAttribute("enzyme_shouldrecompute"), - StringAttribute("enzyme_inactive"), - StringAttribute("enzyme_no_escaping_allocation"), - EnumAttribute("nofree"), - StringAttribute("enzyme_ta_norecur"), - ], true, false) + if LLVM.version().major <= 15 + handleCustom(llvmfn, "jl_mightalias", + [EnumAttribute("readonly"), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation"), + EnumAttribute("nofree"), + StringAttribute("enzyme_ta_norecur"), + ], true, false) + else + handleCustom(llvmfn, "jl_mightalias", + [ + EnumAttribute("memory", ReadOnlyEffects.data), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation"), + EnumAttribute("nofree"), + StringAttribute("enzyme_ta_norecur"), + ], true, false) + end continue end if func == typeof(Base.Threads.threadid) || func == typeof(Base.Threads.nthreads) name = (func == typeof(Base.Threads.threadid)) ? "jl_threadid" : "jl_nthreads" - handleCustom(llvmfn, name, - [EnumAttribute("readonly", 0), - EnumAttribute("inaccessiblememonly", 0), - EnumAttribute("speculatable", 0), - StringAttribute("enzyme_shouldrecompute"), - StringAttribute("enzyme_inactive"), - StringAttribute("enzyme_no_escaping_allocation") - ]) + if LLVM.version().major <= 15 + handleCustom(llvmfn, name, + [EnumAttribute("readonly"), + EnumAttribute("inaccessiblememonly"), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation") + ]) + else + handleCustom(llvmfn, name, + [EnumAttribute("memory", MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))).data), + EnumAttribute("speculatable"), + StringAttribute("enzyme_shouldrecompute"), + StringAttribute("enzyme_inactive"), + StringAttribute("enzyme_no_escaping_allocation") + ]) + end continue end # Since this is noreturn and it can't write to any operations in the function @@ -5428,7 +5528,13 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # fn, but it doesn't presently so for now we will ensure this by hand if func == typeof(Base.Checked.throw_overflowerr_binaryop) llvmfn = functions(mod)[k.specfunc] - handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) + if LLVM.version().major <= 15 + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), EnumAttribute("readonly"), StringAttribute("enzyme_ta_norecur")]) + else + handleCustom(llvmfn, "enz_noop", [StringAttribute("enzyme_inactive"), + EnumAttribute("memory", ReadOnlyEffects.data), + StringAttribute("enzyme_ta_norecur")]) + end continue end if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && has_method(Tuple{typeof(EnzymeRules.inactive), specTypes.parameters...}, world, method_table) @@ -5576,9 +5682,14 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; res = call!(builder, LLVM.function_type(llvmfn), llvmfn, collect(parameters(wrapper_f))) + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for idx in length(collect(parameters(llvmfn))) for attr in collect(parameter_attributes(llvmfn, idx)) - if kind(attr) == kind(EnumAttribute("sret")) + if kind(attr) == sretkind LLVM.API.LLVMAddCallSiteAttribute(res, LLVM.API.LLVMAttributeIndex(idx), attr) end end @@ -5708,6 +5819,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeReturnIndex, StringAttribute("enzyme_inactive")) end + TapeType::Type = Cvoid if params.run_enzyme diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 842ebe0fe3..dbda1b0880 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -1,4 +1,4 @@ -mutable struct PipelineConfig +struct PipelineConfig Speedup::Cint Size::Cint lower_intrinsics::Cint @@ -17,7 +17,8 @@ end const RunAttributor = Ref(true) -function pipeline_options(; lower_intrinsics=true, dump_native=false, external_use=false, llvm_only=false, always_inline=true, enalbe_early_simplifications=true, +function pipeline_options(; lower_intrinsics=true, dump_native=false, external_use=false, llvm_only=false, always_inline=true, enable_early_simplifications=true, + enable_early_optimizations=true, enable_scalar_optimizations=true, enable_loop_optimizations=true, enable_vector_pipeline=true, @@ -26,6 +27,255 @@ function pipeline_options(; lower_intrinsics=true, dump_native=false, external_u return PipelineConfig(Speedup, Size, lower_intrinsics, dump_native, external_use, llvm_only, always_inline, enable_early_simplifications, enable_early_optimizations, enable_scalar_optimizations, enable_loop_optimizations, enable_vector_pipeline, remove_ni, cleanup) end +function run_jl_pipeline(pm, tm; kwargs...) + config = Ref(pipeline_options(;kwargs...)) + function jl_pipeline(m) + @dispose pb=PassBuilder(tm) begin + NewPMModulePassManager(pb) do mpm + @ccall jl_build_newpm_pipeline(mpm.ref::Ptr{Cvoid}, pb.ref::Ptr{Cvoid}, config::Ptr{PipelineConfig})::Cvoid + run!(mpm, m, tm) + end + end + return true + end + add!(pm, ModulePass("JLPipeline", jl_pipeline)) +end + +@static if VERSION < v"1.11.0-DEV.428" +else + barrier_noop!(pm) = nothing +end + +@static if VERSION < v"1.11-" + function gc_invariant_verifier_tm!(pm, tm, cond) + gc_invariant_verifier!(pm, cond) + end +else + function gc_invariant_verifier_tm!(pm, tm, cond) + function gc_invariant_verifier(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, GCInvariantVerifierPass(GCInvariantVerifierPassOptions(;strong=cond))) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("GCInvariantVerifier", gc_invariant_verifier)) + end +end + +@static if VERSION < v"1.11-" + function propagate_julia_addrsp_tm!(pm, tm) + propagate_julia_addrsp!(pm) + end +else + function propagate_julia_addrsp_tm!(pm, tm) + function prop_julia_addr(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, PropagateJuliaAddrspacesPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("PropagateJuliaAddrSpace", prop_julia_addr)) + end +end + +@static if VERSION < v"1.11-" + function alloc_opt_tm!(pm, tm) + alloc_opt!(pm) + end +else + function alloc_opt_tm!(pm, tm) + function alloc_opt(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, AllocOptPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("AllocOpt", alloc_opt)) + end +end + +@static if VERSION < v"1.11-" + function remove_ni_tm!(pm, tm) + remove_ni!(pm) + end +else + function remove_ni_tm!(pm, tm) + function remove_ni(f) + @dispose pb=PassBuilder(tm) begin + NewPMModulePassManager(pb) do fpm + add!(fpm, RemoveNIPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, ModulePass("RemoveNI", remove_ni)) + end +end + +@static if VERSION < v"1.11-" + function julia_licm_tm!(pm, tm) + julia_licm!(pm) + end +else + function julia_licm_tm!(pm, tm) + function julia_licm(f) + @dispose pb=PassBuilder(tm) begin + NewPMLoopPassManager(pb) do fpm + add!(fpm, JuliaLICMPass()) + run!(fpm, f, tm) + end + end + return true + end + # really looppass + add!(pm, FunctionPass("JuliaLICM", julia_licm)) + end +end + +@static if VERSION < v"1.11-" + function lower_simdloop_tm!(pm, tm) + lower_simdloop!(pm) + end +else + function lower_simdloop_tm!(pm, tm) + function lower_simdloop(f) + @dispose pb=PassBuilder(tm) begin + NewPMLoopPassManager(pb) do fpm + add!(fpm, LowerSIMDLoopPass()) + run!(fpm, f, tm) + end + end + return true + end + # really looppass + add!(pm, FunctionPass("LowerSIMDLoop", lower_simdloop)) + end +end + +@static if VERSION < v"1.11-" + function demote_float16_tm!(pm, tm) + demote_float16!(pm) + end +else + function demote_float16_tm!(pm, tm) + function demote_float16(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, DemoteFloat16Pass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("DemoteFloat16", demote_float16)) + end +end + +@static if VERSION < v"1.11-" + function lower_exc_handlers_tm!(pm, tm) + lower_exc_handlers!(pm) + end +else + function lower_exc_handlers_tm!(pm, tm) + function lower_exc_handlers(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, LowerExcHandlersPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("LowerExcHandlers", lower_exc_handlers)) + end +end + +@static if VERSION < v"1.11-" + function lower_ptls_tm!(pm, tm, dump_native) + lower_ptls!(pm, dump_native) + end +else + function lower_ptls_tm!(pm, tm, dump_native) + function lower_ptls(f) + @dispose pb=PassBuilder(tm) begin + NewPMModulePassManager(pb) do fpm + add!(fpm, LowerPTLSPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, ModulePass("LowerPTLS", lower_ptls)) + end +end + +@static if VERSION < v"1.11-" + function combine_mul_add_tm!(pm, tm) + combine_mul_add!(pm) + end +else + function combine_mul_add_tm!(pm, tm) + function combine_mul_add(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, CombineMulAddPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("CombineMulAdd", combine_mul_add)) + end +end + +@static if VERSION < v"1.11-" + function late_lower_gc_frame_tm!(pm, tm) + late_lower_gc_frame!(pm) + end +else + function late_lower_gc_frame_tm!(pm, tm) + function late_lower_gc_frame(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, LateLowerGCPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("LateLowerGCFrame", late_lower_gc_frame)) + end +end + +@static if VERSION < v"1.11-" + function final_lower_gc_tm!(pm, tm) + final_lower_gc!(pm) + end +else + function final_lower_gc_tm!(pm, tm) + function final_lower_gc(f) + @dispose pb=PassBuilder(tm) begin + NewPMFunctionPassManager(pb) do fpm + add!(fpm, FinalLowerGCPass()) + run!(fpm, f, tm) + end + end + return true + end + add!(pm, FunctionPass("FinalLowerGCFrame", final_lower_gc)) + end +end + function addNA(inst, node::LLVM.Metadata, MD) md = metadata(inst) next = nothing @@ -626,6 +876,11 @@ function fix_decayaddr!(mod::LLVM.Module) mayread = false maywrite = false sret = true + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for (i, v) in enumerate(operands(st)[1:end-1]) if v == inst readnone = false @@ -633,7 +888,7 @@ function fix_decayaddr!(mod::LLVM.Module) writeonly = false t_sret = false for a in collect(parameter_attributes(fop, i)) - if kind(a) == kind(EnumAttribute("sret")) + if kind(a) == sretkind t_sret = true end if kind(a) == kind(StringAttribute("enzyme_sret")) @@ -803,7 +1058,7 @@ function prop_global!(g) end # From https://llvm.org/doxygen/IR_2Instruction_8cpp_source.html#l00959 -function mayWriteToMemory(inst::LLVM.Instruction)::Bool +function mayWriteToMemory(inst::LLVM.Instruction; err_is_readonly=false)::Bool # we will ignore fense here if isa(inst, LLVM.StoreInst) return true @@ -838,9 +1093,14 @@ function mayWriteToMemory(inst::LLVM.Instruction)::Bool return false end # Note out of spec, and only legal in context of removing unused calls - if kind(attr) == kind(StringAttribute("enzyme_error")) + if kind(attr) == kind(StringAttribute("enzyme_error")) && err_is_readonly return false end + if kind(attr) == kind(StringAttribute("memory")) + if is_readonly(MemoryEffect(value(attr))) + return false + end + end end Libc.free(Attrs) return true @@ -887,8 +1147,7 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end push!(done, cur) - attrs = collect(function_attributes(cur)) - if any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) || any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + if is_readonly(cur) continue end @@ -901,7 +1160,7 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end for bb in blocks(cur) for inst in instructions(bb) - if !mayWriteToMemory(inst) + if !mayWriteToMemory(inst; err_is_readonly=true) continue end if isa(inst, LLVM.CallInst) @@ -917,17 +1176,7 @@ function remove_readonly_unused_calls!(fn::LLVM.Function, next::Set{String}) end end - changed = false - attrs = collect(function_attributes(fn)) - if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) - if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs) - delete!(function_attributes(fn), EnumAttribute("writeonly")) - push!(function_attributes(fn), EnumAttribute("readnone")) - else - push!(function_attributes(fn), EnumAttribute("readonly")) - end - changed = true - end + changed = set_readonly!(fn) if length(calls) == 0 || hasUser return changed @@ -1345,6 +1594,11 @@ function validate_return_roots!(mod) enzyme_srets_v = Int[] rroots = Int[] rroots_v = Int[] + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for (i, a) in enumerate(parameters(f)) for attr in collect(parameter_attributes(f, i)) if isa(attr, StringAttribute) @@ -1361,7 +1615,7 @@ function validate_return_roots!(mod) push!(enzyme_srets, i) end end - if kind(attr) == kind(EnumAttribute("sret")) + if kind(attr) == sretkind push!(srets, (i, attr)) end end @@ -1519,7 +1773,7 @@ end cse!(pm) = LLVM.API.LLVMAddEarlyCSEPass(pm) -function removeDeadArgs!(mod::LLVM.Module) +function removeDeadArgs!(mod::LLVM.Module, tm) # We need to run globalopt first. This is because remove dead args will otherwise # take internal functions and replace their args with undef. Then on LLVM up to # and including 12 (but fixed 13+), Attributor will incorrectly change functions that @@ -1531,10 +1785,16 @@ function removeDeadArgs!(mod::LLVM.Module) end # Prevent dead-arg-elimination of functions which we may require args for in the derivative funcT = LLVM.FunctionType(LLVM.VoidType(), LLVMType[], vararg=true) - func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"), EnumAttribute("nofree")]) - rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) - sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) - + if LLVM.version().major <= 15 + func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"), EnumAttribute("nofree")]) + rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) + sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("readonly"), EnumAttribute("nofree"), EnumAttribute("argmemonly")]) + else + func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("memory", NoEffects.data), EnumAttribute("nofree")]) + rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) + sfunc, _ = get_function!(mod, "llvm.enzyme.sret_use", funcT, [EnumAttribute("memory", ReadOnlyArgMemEffects.data), EnumAttribute("nofree")]) + end + for fn in functions(mod) if isempty(blocks(fn)) continue @@ -1561,12 +1821,17 @@ function removeDeadArgs!(mod::LLVM.Module) end end end + sretkind = kind(if LLVM.version().major >= 12 + TypeAttribute("sret", LLVM.Int32Type()) + else + EnumAttribute("sret") + end) for idx in (1, 2) if length(collect(parameters(fn))) < idx continue end attrs = collect(parameter_attributes(fn, idx)) - if any( ( kind(attr) == kind(EnumAttribute("sret")) || kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs) + if any( ( kind(attr) == sretkind || kind(attr) == kind(StringAttribute("enzyme_sret")) || kind(attr) == kind(StringAttribute("enzyme_sret_v")) ) for attr in attrs) for u in LLVM.uses(fn) u = LLVM.user(u) if isa(u, LLVM.ConstantExpr) @@ -1602,7 +1867,7 @@ function removeDeadArgs!(mod::LLVM.Module) ModulePassManager() do pm instruction_combining!(pm) jl_inst_simplify!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? cse!(pm) run!(pm, mod) @@ -1621,7 +1886,7 @@ function removeDeadArgs!(mod::LLVM.Module) ModulePassManager() do pm instruction_combining!(pm) jl_inst_simplify!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? if RunAttributor[] if LLVM.version().major >= 13 @@ -1666,7 +1931,7 @@ function optimize!(mod::LLVM.Module, tm) add_library_info!(pm, triple(mod)) add_transform_info!(pm, tm) - propagate_julia_addrsp!(pm) + propagate_julia_addrsp_tm!(pm, tm) scoped_no_alias_aa!(pm) type_based_alias_analysis!(pm) basic_alias_analysis!(pm) @@ -1678,7 +1943,7 @@ end scalar_repl_aggregates_ssa!(pm) # SSA variant? mem_cpy_opt!(pm) always_inliner!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) LLVM.API.LLVMAddGlobalOptimizerPass(pm) # Extra gvn!(pm) # Extra instruction_combining!(pm) @@ -1693,22 +1958,28 @@ end jl_inst_simplify!(pm) reassociate!(pm) early_cse!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) loop_idiom!(pm) loop_rotate!(pm) - lower_simdloop!(pm) - licm!(pm) - if LLVM.version() >= v"15" - simple_loop_unswitch_legacy!(pm) + + if VERSION < v"1.11-" + lower_simdloop_tm!(pm, tm) + licm!(pm) + if LLVM.version() >= v"15" + simple_loop_unswitch_legacy!(pm) + else + loop_unswitch!(pm) + end else - loop_unswitch!(pm) + run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) end + instruction_combining!(pm) jl_inst_simplify!(pm) ind_var_simplify!(pm) loop_deletion!(pm) loop_unroll!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) scalar_repl_aggregates_ssa!(pm) # SSA variant? gvn!(pm) @@ -1722,7 +1993,7 @@ end jl_inst_simplify!(pm) jump_threading!(pm) dead_store_elimination!(pm) - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) cfgsimplification!(pm) loop_idiom!(pm) loop_deletion!(pm) @@ -1740,7 +2011,7 @@ end # GC passes barrier_noop!(pm) - gc_invariant_verifier!(pm, false) + gc_invariant_verifier_tm!(pm, tm, false) # FIXME: Currently crashes printing cfgsimplification!(pm) @@ -1750,7 +2021,7 @@ end gvn!(pm) # Exxtra run!(pm, mod) end - removeDeadArgs!(mod) + removeDeadArgs!(mod, tm) detect_writeonly!(mod) nodecayed_phis!(mod) end @@ -1762,12 +2033,12 @@ function addTargetPasses!(pm, tm, trip) end # https://github.com/JuliaLang/julia/blob/2eb5da0e25756c33d1845348836a0a92984861ac/src/aotcompile.cpp#L620 -function addOptimizationPasses!(pm) +function addOptimizationPasses!(pm, tm) add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) constant_merge!(pm) - propagate_julia_addrsp!(pm) + propagate_julia_addrsp_tm!(pm, tm) scoped_no_alias_aa!(pm) type_based_alias_analysis!(pm) basic_alias_analysis!(pm) @@ -1783,7 +2054,7 @@ function addOptimizationPasses!(pm) # merging the `alloca` for the unboxed data and the `alloca` created by the `alloc_opt` # pass. - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) # consider AggressiveInstCombinePass at optlevel > 2 instruction_combining!(pm) @@ -1801,24 +2072,46 @@ function addOptimizationPasses!(pm) # Load forwarding above can expose allocations that aren't actually used # remove those before optimizing loops. - alloc_opt!(pm) - loop_rotate!(pm) - # moving IndVarSimplify here prevented removing the loop in perf_sumcartesian(10:-1:1) - loop_idiom!(pm) - - # LoopRotate strips metadata from terminator, so run LowerSIMD afterwards - lower_simdloop!(pm) # Annotate loop marked with "loopinfo" as LLVM parallel loop - licm!(pm) - julia_licm!(pm) - # Subsequent passes not stripping metadata from terminator - instruction_combining!(pm) # TODO: createInstSimplifyLegacy - jl_inst_simplify!(pm) - ind_var_simplify!(pm) - loop_deletion!(pm) - loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll + alloc_opt_tm!(pm, tm) + + + if VERSION < v"1.11-" + loop_rotate!(pm) + # moving IndVarSimplify here prevented removing the loop in perf_sumcartesian(10:-1:1) + loop_idiom!(pm) + + # LoopRotate strips metadata from terminator, so run LowerSIMD afterwards + lower_simdloop_tm!(pm, tm) # Annotate loop marked with "loopinfo" as LLVM parallel loop + licm!(pm) + julia_licm_tm!(pm, tm) + # Subsequent passes not stripping metadata from terminator + instruction_combining!(pm) # TODO: createInstSimplifyLegacy + jl_inst_simplify!(pm) + + ind_var_simplify!(pm) + loop_deletion!(pm) + loop_unroll!(pm) # TODO: in Julia createSimpleLoopUnroll + else + # LowerSIMDLoopPass + # LoopRotatePass [opt >= 2] + # LICMPass + # JuliaLICMPass + # SimpleLoopUnswitchPass + # LICMPass + # JuliaLICMPass + # IRCEPass + # LoopInstSimplifyPass + # - in ours this is instcombine with jlinstsimplify + # LoopIdiomRecognizePass + # IndVarSimplifyPass + # LoopDeletionPass + # LoopFullUnrollPass + run_jl_pipeline(pm, tm; lower_intrinsics=false, dump_native=false, external_use=false, llvm_only=false, always_inline=false, enable_early_simplifications=false, enable_early_optimizations=false, enable_scalar_optimizations=false, enable_loop_optimizations=true, enable_vector_pipeline=false, remove_ni=false, cleanup=false) + end + # Run our own SROA on heap objects before LLVM's - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) # Re-run SROA after loop-unrolling (useful for small loops that operate, # over the structure of an aggregate) scalar_repl_aggregates!(pm) @@ -1840,7 +2133,7 @@ function addOptimizationPasses!(pm) # More dead allocation (store) deletion before loop optimization # consider removing this: - alloc_opt!(pm) + alloc_opt_tm!(pm, tm) # see if all of the constant folding has exposed more loops # to simplification and deletion @@ -1859,31 +2152,31 @@ function addOptimizationPasses!(pm) aggressive_dce!(pm) end -function addMachinePasses!(pm) - combine_mul_add!(pm) +function addMachinePasses!(pm, tm) + combine_mul_add_tm!(pm, tm) # TODO: createDivRemPairs[] - demote_float16!(pm) + demote_float16_tm!(pm, tm) gvn!(pm) end -function addJuliaLegalizationPasses!(pm, lower_intrinsics=true) +function addJuliaLegalizationPasses!(pm, tm, lower_intrinsics=true) if lower_intrinsics # LowerPTLS removes an indirect call. As a result, it is likely to trigger # LLVM's devirtualization heuristics, which would result in the entire # pass pipeline being re-exectuted. Prevent this by inserting a barrier. barrier_noop!(pm) add!(pm, FunctionPass("ReinsertGCMarker", reinsert_gcmarker_pass!)) - lower_exc_handlers!(pm) + lower_exc_handlers_tm!(pm, tm) # BUDE.jl demonstrates a bug here TODO - gc_invariant_verifier!(pm, false) + gc_invariant_verifier_tm!(pm, tm, false) verifier!(pm) # Needed **before** LateLowerGCFrame on LLVM < 12 # due to bug in `CreateAlignmentAssumption`. - remove_ni!(pm) - late_lower_gc_frame!(pm) - final_lower_gc!(pm) + remove_ni_tm!(pm, tm) + late_lower_gc_frame_tm!(pm, tm) + final_lower_gc_tm!(pm, tm) # We need these two passes and the instcombine below # after GC lowering to let LLVM do some constant propagation on the tags. # and remove some unnecessary write barrier checks. @@ -1891,20 +2184,20 @@ function addJuliaLegalizationPasses!(pm, lower_intrinsics=true) sccp!(pm) # Remove dead use of ptls dce!(pm) - lower_ptls!(pm, #=dump_native=# false) + lower_ptls_tm!(pm, tm, #=dump_native=# false) instruction_combining!(pm) jl_inst_simplify!(pm) # Clean up write barrier and ptls lowering cfgsimplification!(pm) else barrier_noop!(pm) - remove_ni!(pm) + remove_ni_tm!(pm, tm) end end function post_optimze!(mod, tm, machine=true) addr13NoAlias(mod) - removeDeadArgs!(mod) + removeDeadArgs!(mod, tm) for f in collect(functions(mod)) API.EnzymeFixupJuliaCallingConvention(f) end @@ -1914,15 +2207,15 @@ function post_optimze!(mod, tm, machine=true) end LLVM.ModulePassManager() do pm addTargetPasses!(pm, tm, LLVM.triple(mod)) - addOptimizationPasses!(pm) + addOptimizationPasses!(pm, tm) run!(pm, mod) end if machine # TODO enable validate_return_roots # validate_return_roots!(mod) LLVM.ModulePassManager() do pm - addJuliaLegalizationPasses!(pm, true) - addMachinePasses!(pm) + addJuliaLegalizationPasses!(pm, tm, true) + addMachinePasses!(pm, tm) run!(pm, mod) end end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 4b38256e61..b5bdb3afa2 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -1,3 +1,185 @@ +struct MemoryEffect + data::UInt32 +end + +@enum(ModRefInfo, + MRI_NoModRef = 0, + MRI_Ref = 1, + MRI_Mod = 2, + MRI_ModRef = 3) + +@enum(IRMemLocation, + ArgMem = 0, + InaccessibleMem = 1, + Other = 2) + +const BitsPerLoc = UInt32(2) +const LocMask = UInt32((1 << BitsPerLoc) - 1) +function getLocationPos(Loc::IRMemLocation) + return UInt32(Loc) * BitsPerLoc +end +function Base.:<<(mr::ModRefInfo, rhs::UInt32) + UInt32(mr) << rhs +end +function Base.:|(lhs::ModRefInfo, rhs::ModRefInfo) + ModRefInfo(UInt32(lhs) | UInt32(rhs)) +end +function Base.:&(lhs::ModRefInfo, rhs::ModRefInfo) + ModRefInfo(UInt32(lhs) & UInt32(rhs)) +end +const AllEffects = MemoryEffect((MRI_ModRef << getLocationPos(ArgMem)) | (MRI_ModRef << getLocationPos(InaccessibleMem)) | (MRI_ModRef << getLocationPos(Other))) +const ReadOnlyEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_Ref << getLocationPos(InaccessibleMem)) | (MRI_Ref << getLocationPos(Other))) +const ReadOnlyArgMemEffects = MemoryEffect((MRI_Ref << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) +const NoEffects = MemoryEffect((MRI_NoModRef << getLocationPos(ArgMem)) | (MRI_NoModRef << getLocationPos(InaccessibleMem)) | (MRI_NoModRef << getLocationPos(Other))) + +# Get ModRefInfo for any location. +function getModRef(effect::MemoryEffect, loc::IRMemLocation)::ModRefInfo + ModRefInfo((effect.data >> getLocationPos(loc)) & LocMask) +end + +function getModRef(effect::MemoryEffect)::ModRefInfo + cur = MRI_NoModRef + for loc in (ArgMem, InaccessibleMem, Other) + cur |= getModRef(effect, loc) + end + return cur +end + +function setModRef(effect::MemoryEffect, Loc::IRMemLocation, MR::ModRefInfo)::MemoryEffect + data = effect.data + Data &= ~(LocMask << getLocationPos(Loc)) + Data |= MR << getLocationPos(Loc) + return MemoryEffect(data) +end + +function setModRef(effect::MemoryEffect)::MemoryEffect + for loc in (ArgMem, InaccessibleMem, Other) + effect = setModRef(effect, mri)= getModRef(effect, loc) + end + return effect +end + +function set_readonly(mri::ModRefInfo) + return mri & MRI_Ref +end +function set_writeonly(mri::ModRefInfo) + return mri & MRI_Mod +end +function set_reading(mri::ModRefInfo) + return mri | MRI_Ref +end +function set_writing(mri::ModRefInfo) + return mri | MRI_Mod +end + +function set_readonly(effect::MemoryEffect) + data = UInt32(0) + for loc in (ArgMem, InaccessibleMem, Other) + data = UInt32(set_readonly(getModRef(effect, loc))) << getLocationPos(loc) + end + return MemoryEffect(data) +end + +function is_readonly(mri::ModRefInfo) + return mri == MRI_NoModRef || mri == MRI_Ref +end + +function is_readnone(mri::ModRefInfo) + return mri == MRI_NoModRef +end + +function is_writeonly(mri::ModRefInfo) + return mri == MRI_NoModRef || mri == MRI_Mod +end + +for n in (:is_readonly, :is_readnone, :is_writeonly) +@eval begin + function $n(memeffect::MemoryEffect) + return $n(getModRef(memeffect)) + end +end +end + +function is_readonly(f::LLVM.Function) + for attr in collect(function_attributes(f)) + if kind(attr) == kind(EnumAttribute("readonly")) + return true + end + if kind(attr) == kind(EnumAttribute("readnone")) + return true + end + if LLVM.version().major > 15 + if kind(attr) == kind(EnumAttribute("memory")) + if is_readonly(MemoryEffect(value(attr))) + return true + end + end + end + end + return false +end + +function is_readnone(f::LLVM.Function) + for attr in collect(function_attributes(cur)) + if kind(attr) == kind(EnumAttribute("readnone")) + return true + end + if LLVM.version().major > 15 + if kind(attr) == kind(EnumAttribute("memory")) + if is_readnone(MemoryEffect(value(attr))) + return true + end + end + end + end + return false +end + +function is_writeonly(f::LLVM.Function) + for attr in collect(function_attributes(cur)) + if kind(attr) == kind(EnumAttribute("readnone")) + return true + end + if kind(attr) == kind(EnumAttribute("writeonly")) + return true + end + if LLVM.version().major > 15 + if kind(attr) == kind(EnumAttribute("memory")) + if is_writeonly(MemoryEffect(value(attr))) + return true + end + end + end + end + return false +end + +function set_readonly!(fn::LLVM.Function) + attrs = collect(function_attributes(fn)) + if LLVM.version().major <= 15 + if !any(kind(attr) == kind(EnumAttribute("readonly")) for attr in attrs) && !any(kind(attr) == kind(EnumAttribute("readnone")) for attr in attrs) + if any(kind(attr) == kind(EnumAttribute("writeonly")) for attr in attrs) + delete!(function_attributes(fn), EnumAttribute("writeonly")) + push!(function_attributes(fn), EnumAttribute("readnone")) + else + push!(function_attributes(fn), EnumAttribute("readonly")) + end + return true + end + return false + else + for attr in attrs + if kind(attr) == kind(EnumAttribute("memory")) + old = MemoryEffect(value(attr)) + eff = set_readonly(old) + push!(function_attributes(fn), EnumAttribute("memory", eff.data)) + return old != eff + end + end + push!(function_attributes(fn), EnumAttribute("memory", set_readonly(AllEffects).data)) + return true + end +end function get_function!(mod::LLVM.Module, name::AbstractString, FT::LLVM.FunctionType, attrs=[]) if haskey(functions(mod), name) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 8715ce1991..951d527d6d 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -385,27 +385,90 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) nval = ptrtoint!(b, call!(b, LLVM.function_type(mfn2), mfn2, [LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 0))]), value_type(inst)) replace_uses!(inst, nval) LLVM.API.LLVMInstructionEraseFromParent(inst) - elseif fn == "jl_load_and_lookup" + elseif fn == "jl_load_and_lookup" || fn == "ijl_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) - legal, flib = abs_cstring(operands(inst)[1]) - legal2, fname = abs_cstring(operands(inst)[2]) - legal &= legal2 + arg1 = operands(inst)[1] - hnd = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(inst, 2)) - if isa(hnd, LLVM.GlobalVariable) - hnd = LLVM.name(hnd) - else - legal = false + while isa(arg1, ConstantExpr) + if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast || opcode(arg1) == LLVM.API.LLVMBitCast || opcode(arg1) == LLVM.API.LLVMIntToPtr + arg1 = operands(arg1)[1] + else + break + end end + if isa(arg1, LLVM.ConstantInt) + arg1 = reinterpret(Ptr{Cvoid}, convert(UInt, arg1)) + legal2, fname = abs_cstring(operands(inst)[2]) + if legal2 + hnd = operands(inst)[3] + if isa(hnd, LLVM.GlobalVariable) + hnd = LLVM.name(hnd) + if fn == "jl_lazy_load_and_lookup" + res = ccall(:jl_load_and_lookup, Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) + else + res = ccall(:ijl_load_and_lookup, Ptr{Cvoid}, (Ptr{Cvoid}, Cstring, Ptr{Cvoid}), arg1, fname, reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr)) + end + replaceWith = LLVM.ConstantInt(LLVM.IntType(8*sizeof(Int)), reinterpret(UInt, res)) + for u in LLVM.uses(inst) + st = LLVM.user(u) + if isa(st, LLVM.StoreInst) && LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 0)) == inst + ptr = LLVM.Value(LLVM.LLVM.API.LLVMGetOperand(st, 1)) + for u in LLVM.uses(ptr) + ld = LLVM.user(u) + if isa(ld, LLVM.LoadInst) + b = IRBuilder() + position!(b, ld) + for u in LLVM.uses(ld) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + end + replace_uses!(ld, LLVM.inttoptr!(b, replaceWith, value_type(inst))) + end + end + end + end - if !legal - return + b = IRBuilder() + position!(b, inst) + replacement = LLVM.inttoptr!(b, replaceWith, value_type(inst)) + for u in LLVM.uses(inst) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.PHIInst) + if all(x->first(x) == inst || first(x) == replacement, LLVM.incoming(u)) + + for u in LLVM.uses(u) + u = LLVM.user(u) + if isa(u, LLVM.CallInst) + push!(calls, u) + end + if isa(u, LLVM.BitCastInst) + for u1 in LLVM.uses(u) + u1 = LLVM.user(u1) + if isa(u1, LLVM.CallInst) + push!(calls, u1) + end + end + replace_uses!(u, LLVM.inttoptr!(b, replaceWith, value_type(u))) + end + end + end + end + end + replace_uses!(inst, replacement) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end end - # res = ccall(:jl_load_and_lookup, Ptr{Cvoid}, (Cstring, Cstring, Ptr{Cvoid}), flib, fname, cglobal(Symbol(hnd))) - push!(errors, ("jl_load_and_lookup", bt, nothing)) + + elseif fn == "jl_lazy_load_and_lookup" || fn == "ijl_lazy_load_and_lookup" ofn = LLVM.parent(LLVM.parent(inst)) diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 37208f4afc..d0146a64e2 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -473,6 +473,54 @@ function arrayreshape_rev(B, orig, gutils, tape) return nothing end +function gcloaded_fwd(B, orig, gutils, normalR, shadowR) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) + return true + end + + origops = LLVM.operands(orig) + if is_constant_value(gutils, origops[1]) + emit_error(B, orig, "Enzyme: gcloaded has active return, but inactive input(1)") + end + if is_constant_value(gutils, origops[2]) + emit_error(B, orig, "Enzyme: gcloaded has active return, but inactive input(2)") + end + + width = get_width(gutils) + + shadowin1 = invert_pointer(gutils, origops[1], B) + shadowin2 = invert_pointer(gutils, origops[2], B) + if width == 1 + args = LLVM.Value[shadowin1, shadowin2] + shadowres = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) + else + shadowres = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) + for idx in 1:width + args = LLVM.Value[ + extract_value!(B, shadowin1, idx-1) + extract_value!(B, shadowin2, idx-1) + ] + tmp = call_samefunc_with_inverted_bundles!(B, gutils, orig, args, [API.VT_Shadow, API.VT_Shadow], #=lookup=#false) + shadowres = insert_value!(B, shadowres, tmp, idx-1) + end + end + unsafe_store!(shadowR, shadowres.ref) + + return false +end + +function gcloaded_augfwd(B, orig, gutils, normalR, shadowR, tapeR) + gcloaded_fwd(B, orig, gutils, normalR, shadowR) +end + +function gcloaded_rev(B, orig, gutils, tape) + return nothing +end + function boxfloat_fwd(B, orig, gutils, normalR, shadowR) origops = collect(operands(orig)) width = get_width(gutils) @@ -1206,6 +1254,12 @@ end @revfunc(jlcall2_rev), @fwdfunc(jlcall2_fwd), ) + register_handler!( + ("julia.gc_loaded",), + @augfunc(gcloaded_augfwd), + @revfunc(gcloaded_rev), + @fwdfunc(gcloaded_fwd), + ) register_handler!( ("jl_apply_generic", "ijl_apply_generic"), @augfunc(generic_augfwd), diff --git a/src/typetree.jl b/src/typetree.jl index 79ca41cd81..2c846ae49e 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -152,28 +152,46 @@ function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, return tt end -function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} - offset = 0 - - tt = copy(typetree(T, ctx, dl, seen)) - if !allocatedinline(T) +@static if VERSION < v"1.11-" + function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} + offset = 0 + + tt = copy(typetree(T, ctx, dl, seen)) + if !allocatedinline(T) + merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, 0) + end merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, 0) - end - merge!(tt, TypeTree(API.DT_Pointer, ctx)) - only!(tt, offset) + only!(tt, offset) + + offset += sizeof(Ptr{Cvoid}) - offset += sizeof(Ptr{Cvoid}) + sizeofstruct = offset + 2 + 2 + 4 + 2 * sizeof(Csize_t) + if true # STORE_ARRAY_LEN + sizeofstruct += sizeof(Csize_t) + end - sizeofstruct = offset + 2 + 2 + 4 + 2 * sizeof(Csize_t) - if true # STORE_ARRAY_LEN - sizeofstruct += sizeof(Csize_t) + for i in offset:(sizeofstruct-1) + merge!(tt, TypeTree(API.DT_Integer, i, ctx)) + end + return tt end +else + function typetree_inner(::Type{<:GenericMemory{kind, T}}, ctx, dl, seen::TypeTreeTable) where {kind, T} + offset = 0 + tt = copy(typetree(T, ctx, dl, seen)) + if !allocatedinline(T) + merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, 0) + end + merge!(tt, TypeTree(API.DT_Pointer, ctx)) + only!(tt, sizeof(Csize_t)) - for i in offset:(sizeofstruct-1) - merge!(tt, TypeTree(API.DT_Integer, i, ctx)) + for i in 0:(sizeof(Csize_t)-1) + merge!(tt, TypeTree(API.DT_Integer, i, ctx)) + end + return tt end - return tt end if VERSION >= v"1.7.0-DEV.204"