From 7fde557caf581dbd8e1efbc168e5cb4ebc934a6d Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Mon, 14 Mar 2022 12:16:48 -0400
Subject: [PATCH] Use an early IR transform to replace GlobalRefs

---
Review notes (commentary only, not part of the commit message):
- `early_transform!` does not rewrite anything yet: for each statement (or the
  right-hand side of an assignment) that is a `GlobalRef` it only prints the
  result of `static_eval` via `@show`; the rewrite itself is still a TODO. The
  loop index `i` is unused.
- `static_eval` returns `Some(value)` for a resolved and defined binding and
  `nothing` otherwise.
- `ir_element` and `is_ir_element` are added but not called anywhere in this
  patch.
- `transform` hands a copy of the code info to `early_transform!`; its `interp`
  argument is currently unused.
- The new `InferenceState` method runs `validate_code_in_debug_mode` both before
  ("lowered") and after ("transformed") the transform.
- src/early_transform.jl is added without a trailing newline.

 src/GPUCompiler.jl     |  1 +
 src/early_transform.jl | 49 ++++++++++++++++++++++++++++++++++++++++++
 src/jlgen.jl           | 19 ++++++++++++++++
 3 files changed, 69 insertions(+)
 create mode 100644 src/early_transform.jl

diff --git a/src/GPUCompiler.jl b/src/GPUCompiler.jl
index 3503295e..478c6eeb 100644
--- a/src/GPUCompiler.jl
+++ b/src/GPUCompiler.jl
@@ -29,6 +29,7 @@ include("bpf.jl")
 include("runtime.jl")
 
 # compiler implementation
+include("early_transform.jl")
 include("jlgen.jl")
 include("irgen.jl")
 include("optim.jl")
diff --git a/src/early_transform.jl b/src/early_transform.jl
new file mode 100644
index 00000000..8053265c
--- /dev/null
+++ b/src/early_transform.jl
@@ -0,0 +1,49 @@
+function static_eval(mod, name)
+    if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
+        return Some(getfield(mod, name))
+    else
+        return nothing
+    end
+end
+static_eval(gr::GlobalRef) = static_eval(gr.mod, gr.name)
+
+function ir_element(x, code::Vector)
+    while isa(x, Core.SSAValue)
+        x = code[x.id]
+    end
+    return x
+end
+
+"""
+    is_ir_element(x, y, code::Vector)
+
+Return `true` if `x === y` or if `x` is an `SSAValue` such that
+`is_ir_element(code[x.id], y, code)` is `true`.
+"""
+function is_ir_element(x, y, code::Vector)
+    result = false
+    while true # break by default
+        if x === y #
+            result = true
+            break
+        elseif isa(x, Core.SSAValue)
+            x = code[x.id]
+        else
+            break
+        end
+    end
+    return result
+end
+
+
+function early_transform!(mi, src)
+    for (i, x) in enumerate(src.code)
+        stmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
+        if stmt isa GlobalRef
+            @show static_eval(stmt)
+        end
+        # TODO: Walk stmt.args and find other uses of `:GlobalRef`
+        # TODO: decide which GlobalRef to rewrite?
+    end
+    return nothing
+end
\ No newline at end of file
diff --git a/src/jlgen.jl b/src/jlgen.jl
index aa74d2b9..e4ad98f5 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -213,6 +213,25 @@ Core.Compiler.code_cache(interp::GPUInterpreter) = WorldView(interp.global_cache
 Core.Compiler.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing
 Core.Compiler.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing
 
+import Core.Compiler: retrieve_code_info, validate_code_in_debug_mode, InferenceState
+# Replace usage sites of `retrieve_code_info`, OptimizationState is one such, but in all interesting use-cases
+# it is derived from an InferenceState. There is a third one in `typeinf_ext` in case the module forbids inference.
+function InferenceState(result::InferenceResult, cached::Symbol, interp::GPUInterpreter)
+    src = retrieve_code_info(result.linfo)
+    src === nothing && return nothing
+    validate_code_in_debug_mode(result.linfo, src, "lowered")
+    src = transform(interp, result.linfo, src)
+    validate_code_in_debug_mode(result.linfo, src, "transformed")
+    return InferenceState(result, src, cached, interp)
+end
+
+function transform(interp, mi, src)
+    src = copy(src)
+    early_transform!(mi, src)
+    return src
+end
+
+
 function Core.Compiler.add_remark!(interp::GPUInterpreter, sv::InferenceState, msg)
     @safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg"
 end