Skip to content

Commit

Permalink
refactor unreachability analysis
Browse files Browse the repository at this point in the history
Separated from #43999.
xref: 
<#43999 (comment)>
  • Loading branch information
aviatesk committed May 4, 2022
1 parent e4d21d4 commit 56ab8f0
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 176 deletions.
19 changes: 13 additions & 6 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1835,7 +1835,13 @@ function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(
elseif isa(e, SSAValue)
return abstract_eval_ssavalue(e, sv)
elseif isa(e, SlotNumber) || isa(e, Argument)
return vtypes[slot_id(e)].typ
sn = slot_id(e)
s = vtypes[sn]
if s.undef === true || # may not be defined
s.typ === Bottom # never analyzed
sv.src.slotflags[sn] |= SLOT_USEDUNDEF | SLOT_STATICUNDEF
end
return s.typ
elseif isa(e, GlobalRef)
return abstract_eval_global(e.mod, e.name, sv)
end
Expand Down Expand Up @@ -2022,10 +2028,11 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
sym = e.args[1]
t = Bool
if isa(sym, SlotNumber)
vtyp = vtypes[slot_id(sym)]
sn = slot_id(sym)
vtyp = vtypes[sn]
if vtyp.typ === Bottom
t = Const(false) # never assigned previously
elseif !vtyp.undef
elseif vtyp.undef === false
t = Const(true) # definitely assigned previously
end
elseif isa(sym, Symbol)
Expand Down Expand Up @@ -2332,9 +2339,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), changes, false)
end
elseif hd === :code_coverage_effect ||
(hd !== :boundscheck && # :boundscheck can be narrowed to Bool
hd !== nothing && is_meta_expr_head(hd))
elseif hd === :code_coverage_effect || (hd !== nothing &&
hd !== :boundscheck && # :boundscheck can be narrowed to Bool
is_meta_expr_head(hd))
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# (only used in abstractinterpret, doesn't appear in optimize)
struct VarState
typ
undef::Bool
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
undef::Union{Nothing,Bool} # nothing if unanalyzed
VarState(@nospecialize(typ), undef::Union{Nothing,Bool}) = new(typ, undef)
end

"""
Expand Down Expand Up @@ -152,7 +152,7 @@ mutable struct InferenceState
stmt_types[1] = stmt_type1 = VarTable(undef, nslots)
for i in 1:nslots
argtyp = (i > nargs) ? Bottom : argtypes[i]
stmt_type1[i] = VarState(argtyp, i > nargs)
stmt_type1[i] = VarState(argtyp, i > nargs && nothing)
slottypes[i] = argtyp
end

Expand Down
117 changes: 74 additions & 43 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,26 @@ mutable struct OptimizationState
linfo::MethodInstance
src::CodeInfo
ir::Union{Nothing, IRCode}
was_reached::Union{Nothing, BitSet}
stmt_info::Vector{Any}
mod::Module
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
inlining::InliningState
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
was_reached = BitSet()
for i = 1:length(frame.stmt_types)
if isa(frame.stmt_types[i], VarTable)
push!(was_reached, i)
end
end
s_edges = frame.stmt_edges[1]::Vector{Any}
inlining = InliningState(params,
EdgeTracker(s_edges, frame.valid_worlds),
WorldView(code_cache(interp), frame.world),
interp)
return new(frame.linfo,
frame.src, nothing, frame.stmt_info, frame.mod,
frame.src, nothing, was_reached, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
Expand Down Expand Up @@ -131,11 +138,13 @@ mutable struct OptimizationState
WorldView(code_cache(interp), get_world_counter()),
interp)
return new(linfo,
src, nothing, stmt_info, mod,
src, nothing, nothing, stmt_info, mod,
sptypes_from_meth_instance(linfo), slottypes, inlining)
end
end

was_reached((; was_reached)::OptimizationState, pc::Int) = was_reached === nothing || pc in was_reached

function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
src = retrieve_code_info(linfo)
src === nothing && return nothing
Expand Down Expand Up @@ -399,7 +408,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
(; def, specTypes) = linfo

analyzed = nothing # `ConstAPI` if this call can use constant calling convention
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
force_noinline = _any(x::Expr -> x.head === :meta && x.args[1] === :noinline, ir.meta)

# compute inlining and other related optimizations
result = caller.result
Expand Down Expand Up @@ -554,30 +563,53 @@ function run_passes(ci::CodeInfo, sv::OptimizationState, caller::InferenceResult
end

function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
code = copy_exprargs(ci.code)
coverage = coverage_enabled(sv.mod)
# Go through and add an unreachable node after every
# Union{} call. Then reindex labels.
idx = 1
oldidx = 1
changemap = fill(0, length(code))
prevloc = zero(eltype(ci.codelocs))
stmtinfo = sv.stmt_info
codelocs = ci.codelocs
ssavaluetypes = ci.ssavaluetypes::Vector{Any}
ssaflags = ci.ssaflags
linetable = ci.linetable
if !isa(linetable, Vector{LineInfoNode})
linetable = collect(LineInfoNode, linetable::Vector{Any})::Vector{LineInfoNode}
end
if !coverage && JLOptions().code_coverage == 3 # path-specific coverage mode
for line in ci.linetable
line = line::LineInfoNode
for line in linetable
if is_file_tracked(line.file)
# if any line falls in a tracked file enable coverage for all
coverage = true
break
end
end
end
# Go through and add an unreachable node after every
# Union{} call. Then reindex labels
code = copy_exprargs(ci.code)
stmtinfo = sv.stmt_info
codelocs = ci.codelocs
ssavaluetypes = ci.ssavaluetypes::Vector{Any}
ssaflags = ci.ssaflags
meta = Expr[]
idx = 1
oldidx = 1
changemap = fill(0, length(code))
labelmap = coverage ? fill(0, length(code)) : changemap
prevloc = zero(eltype(ci.codelocs))
while idx <= length(code)
stmt = code[idx]
if process_meta!(meta, stmt) || !(is_meta_expr(stmt) || was_reached(sv, oldidx))
if oldidx < length(labelmap)
changemap[oldidx] != 0 && (changemap[oldidx+1] = changemap[oldidx])
if coverage && labelmap[oldidx] != 0
labelmap[oldidx + 1] = labelmap[oldidx]
end
changemap[oldidx] = -1
coverage && (labelmap[oldidx] = -1)
end
# TODO: It would be more efficient to do this in bulk
deleteat!(code, idx)
deleteat!(codelocs, idx)
deleteat!(ssavaluetypes, idx)
deleteat!(stmtinfo, idx)
deleteat!(ssaflags, idx)
oldidx += 1
continue
end
codeloc = codelocs[idx]
if coverage && codeloc != prevloc && codeloc != 0
# insert a side-effect instruction before the current instruction in the same basic block
Expand All @@ -593,7 +625,16 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
idx += 1
prevloc = codeloc
end
if code[idx] isa Expr && ssavaluetypes[idx] === Union{}
if isa(stmt, GotoIfNot)
# replace GotoIfNot with:
# - GotoNode if the fallthrough target is unreachable
# - no-op if the branch target is unreachable
if !was_reached(sv, oldidx + 1)
code[idx] = GotoNode(stmt.dest)
elseif !was_reached(sv, stmt.dest)
code[idx] = nothing
end
elseif stmt isa Expr && ssavaluetypes[idx] === Union{}
if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
# insert unreachable in the same basic block after the current instruction (splitting it)
insert!(code, idx + 1, ReturnNode())
Expand All @@ -611,34 +652,22 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
idx += 1
oldidx += 1
end

renumber_ir_elements!(code, changemap, labelmap)

meta = Any[]
for i = 1:length(code)
code[i] = remove_meta!(code[i], meta)
end
strip_trailing_junk!(ci, code, stmtinfo)
cfg = compute_basic_blocks(code)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags)
linetable = ci.linetable
isa(linetable, Vector{LineInfoNode}) || (linetable = collect(LineInfoNode, linetable::Vector{Any}))
ir = IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes)
return ir
cfg = compute_basic_blocks(code)
return IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes)
end

function remove_meta!(@nospecialize(stmt), meta::Vector{Any})
if isa(stmt, Expr)
head = stmt.head
if head === :meta
args = stmt.args
if length(args) > 0
push!(meta, stmt)
end
return nothing
end
function process_meta!(meta::Vector{Expr}, @nospecialize stmt)
if isexpr(stmt, :meta) && length(stmt.args) 1
push!(meta, stmt)
return true
end
return stmt
return false
end

function slot2reg(ir::IRCode, ci::CodeInfo, sv::OptimizationState)
Expand Down Expand Up @@ -800,7 +829,9 @@ end

function cumsum_ssamap!(ssamap::Vector{Int})
rel_change = 0
any_change = false
for i = 1:length(ssamap)
any_change = any_change || ssamap[i] != 0
rel_change += ssamap[i]
if ssamap[i] == -1
# Keep a marker that this statement was deleted
Expand All @@ -809,16 +840,15 @@ function cumsum_ssamap!(ssamap::Vector{Int})
ssamap[i] = rel_change
end
end
return any_change
end

function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, labelchangemap::Vector{Int})
cumsum_ssamap!(labelchangemap)
any_change = cumsum_ssamap!(labelchangemap)
if ssachangemap !== labelchangemap
cumsum_ssamap!(ssachangemap)
end
if labelchangemap[end] == 0 && ssachangemap[end] == 0
return
any_change |= cumsum_ssamap!(ssachangemap)
end
any_change || return
for i = 1:length(body)
el = body[i]
if isa(el, GotoNode)
Expand All @@ -828,7 +858,8 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
if isa(cond, SSAValue)
cond = SSAValue(cond.id + ssachangemap[cond.id])
end
body[i] = GotoIfNot(cond, el.dest + labelchangemap[el.dest])
was_deleted = labelchangemap[el.dest] == typemin(Int)
body[i] = was_deleted ? cond : GotoIfNot(cond, el.dest + labelchangemap[el.dest])
elseif isa(el, ReturnNode)
if isdefined(el, :val)
val = el.val
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ struct IRCode
linetable::Vector{LineInfoNode}
cfg::CFG
new_nodes::NewNodeStream
meta::Vector{Any}
meta::Vector{Expr}

function IRCode(stmts::InstructionStream, cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Any}, sptypes::Vector{Any})
function IRCode(stmts::InstructionStream, cfg::CFG, linetable::Vector{LineInfoNode}, argtypes::Vector{Any}, meta::Vector{Expr}, sptypes::Vector{Any})
return new(stmts, argtypes, sptypes, linetable, cfg, NewNodeStream(), meta)
end
function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
ssavaluetypes isa Vector{Any} ? copy(ssavaluetypes) : Any[ Any for i = 1:(ssavaluetypes::Int) ]
end
stmts = InstructionStream(code, ssavaluetypes, Any[nothing for i = 1:nstmts], copy(ci.codelocs), copy(ci.ssaflags))
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), argtypes, Any[], sptypes)
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), argtypes, Expr[], sptypes)
return ir
end

Expand Down
Loading

0 comments on commit 56ab8f0

Please sign in to comment.