Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try out lowered IR validation. #331

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ end
if validate
@timeit_debug to "validation" begin
check_invocation(job)
check_ir(job, ir)
check_llvm_ir(job, ir)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false

# provide a specific interpreter to use.
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(ci_cache(job), method_table(job), job.source.world)
GPUInterpreter(job, ci_cache(job), method_table(job),)

# does this target support throwing Julia exceptions with jl_throw?
# if not, calls to throw will be replaced with calls to the GPU runtime
Expand Down
36 changes: 24 additions & 12 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,31 +170,29 @@ using Core.Compiler:
AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams

struct GPUInterpreter <: AbstractInterpreter
job::CompilerJob

global_cache::CodeCache
method_table::Union{Nothing,Core.MethodTable}

# Cache of inference results for this particular interpreter
local_cache::Vector{InferenceResult}
# The world age we're working inside of
world::UInt

# Parameters for inference and optimization
inf_params::InferenceParams
opt_params::OptimizationParams

function GPUInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt)
@assert world <= Base.get_world_counter()
function GPUInterpreter(job::CompilerJob, cache::CodeCache, mt::Union{Nothing,Core.MethodTable})
@assert job.source.world <= Base.get_world_counter()

return new(
job,
cache,
mt,

# Initially empty cache
Vector{InferenceResult}(),

# world age counter
world,

# parameters for inference and optimization
InferenceParams(unoptimize_throw_blocks=false),
VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() :
Expand All @@ -205,14 +203,28 @@ end

Core.Compiler.InferenceParams(interp::GPUInterpreter) = interp.inf_params
Core.Compiler.OptimizationParams(interp::GPUInterpreter) = interp.opt_params
Core.Compiler.get_world_counter(interp::GPUInterpreter) = interp.world
Core.Compiler.get_world_counter(interp::GPUInterpreter) = interp.job.source.world
Core.Compiler.get_inference_cache(interp::GPUInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::GPUInterpreter) = WorldView(interp.global_cache, interp.world)
Core.Compiler.code_cache(interp::GPUInterpreter) =
WorldView(interp.global_cache, Core.Compiler.get_world_counter(interp))

# No need to do any locking since we're not putting our results into the runtime cache
Core.Compiler.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing

import Core.Compiler: retrieve_code_info, validate_code_in_debug_mode, InferenceState
# Replace usage sites of `retrieve_code_info`, OptimizationState is one such use, but in all
# interesting use-cases it is derived from an InferenceState. There is a third one in
# `typeinf_ext` in case the module forbids inference.
function InferenceState(result::InferenceResult, cached::Symbol, interp::GPUInterpreter)
src = retrieve_code_info(result.linfo)
src === nothing && return nothing
validate_code_in_debug_mode(result.linfo, src, "lowered")
check_julia_ir(interp, result.linfo, src)
return InferenceState(result, src, cached, interp)
end


function Core.Compiler.add_remark!(interp::GPUInterpreter, sv::InferenceState, msg)
@safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg"
end
Expand All @@ -228,14 +240,14 @@ if isdefined(Base.Experimental, Symbol("@overlay"))
using Core.Compiler: OverlayMethodTable
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::GPUInterpreter) =
OverlayMethodTable(interp.world, interp.method_table)
OverlayMethodTable(Core.Compiler.get_world_counter(interp), interp.method_table)
else
Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
OverlayMethodTable(interp.world, interp.method_table)
OverlayMethodTable(Core.Compiler.get_world_counter(interp), interp.method_table)
end
else
Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
WorldOverlayMethodTable(interp.world)
WorldOverlayMethodTable(Core.Compiler.get_world_counter(interp))
end


Expand Down
67 changes: 59 additions & 8 deletions src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function check_method(@nospecialize(job::CompilerJob))
if job.source.kernel
cache = ci_cache(job)
mt = method_table(job)
interp = GPUInterpreter(cache, mt, world)
interp = GPUInterpreter(job, cache, mt)
rt = return_type(only(ms); interp)

if rt != Nothing
Expand Down Expand Up @@ -102,6 +102,55 @@ struct InvalidIRError <: Exception
errors::Vector{IRError}
end

# Julia IR

const UNDEFINED_GLOBAL = "use of an undefined global binding"
const MUTABLE_GLOBAL = "use of a mutable global binding"

function check_julia_ir(interp, mi, src)
# pseudo (single-frame) backtrace pointing to a source code location
function backtrace(i)
loc = src.linetable[i]
[StackTraces.StackFrame(loc.method, loc.file, loc.line, mi, false, false, C_NULL)]
end

function check(i, x, errors::Vector{IRError})
if x isa Expr
for y in x.args
check(i, y, errors)
end
elseif x isa GlobalRef
Base.isbindingresolved(x.mod, x.name) || return
# XXX: when does this happen? do we miss any cases by bailing out early?
# why doesn't calling `Base.resolve(x, force=true)` work?
if !Base.isdefined(x.mod, x.name)
push!(errors, (UNDEFINED_GLOBAL, backtrace(i), x))
end
if !Base.isconst(x.mod, x.name)
push!(errors, (MUTABLE_GLOBAL, backtrace(i), x))
end

# TODO: make the validation conditional, but make sure we don't cache invalid IR

# TODO: perform more validation? e.g. disallow Arrays and other CPU values?
end

return
end

errors = IRError[]
for (i, x) in enumerate(src.code)
check(i, x, errors)
end
if !isempty(errors)
throw(InvalidIRError(interp.job, errors))
end

return
end

# LLVM IR

const RUNTIME_FUNCTION = "call to the Julia runtime"
const UNKNOWN_FUNCTION = "call to an unknown function"
const POINTER_FUNCTION = "call through a literal pointer"
Expand All @@ -117,6 +166,8 @@ function Base.showerror(io::IO, err::InvalidIRError)
print(io, " (call to ", meta, ")")
elseif kind == DELAYED_BINDING
print(io, " (use of '", meta, "')")
else
print(io, " (", meta, ")")
end
end
Base.show_backtrace(io, bt)
Expand All @@ -132,8 +183,8 @@ function Base.showerror(io::IO, err::InvalidIRError)
return
end

function check_ir(job, args...)
errors = check_ir!(job, IRError[], args...)
function check_llvm_ir(job, args...)
errors = check_llvm_ir!(job, IRError[], args...)
unique!(errors)
if !isempty(errors)
throw(InvalidIRError(job, errors))
Expand All @@ -142,18 +193,18 @@ function check_ir(job, args...)
return
end

function check_ir!(job, errors::Vector{IRError}, mod::LLVM.Module)
function check_llvm_ir!(job, errors::Vector{IRError}, mod::LLVM.Module)
for f in functions(mod)
check_ir!(job, errors, f)
check_llvm_ir!(job, errors, f)
end

return errors
end

function check_ir!(job, errors::Vector{IRError}, f::LLVM.Function)
function check_llvm_ir!(job, errors::Vector{IRError}, f::LLVM.Function)
for bb in blocks(f), inst in instructions(bb)
if isa(inst, LLVM.CallInst)
check_ir!(job, errors, inst)
check_llvm_ir!(job, errors, inst)
end
end

Expand All @@ -162,7 +213,7 @@ end

const libjulia = Ref{Ptr{Cvoid}}(C_NULL)

function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
function check_llvm_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
bt = backtrace(inst)
dest = called_value(inst)
if isa(dest, LLVM.Function)
Expand Down