Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: refactor the core loops to use less memory #45276

Merged
merged 16 commits into from
May 30, 2022
444 changes: 269 additions & 175 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ include("compiler/utilities.jl")
include("compiler/validation.jl")
include("compiler/methodtable.jl")

function argextype end # imported by EscapeAnalysis
function stmt_effect_free end # imported by EscapeAnalysis
function alloc_array_ndims end # imported by EscapeAnalysis
function try_compute_field end # imported by EscapeAnalysis
include("compiler/ssair/basicblock.jl")
include("compiler/ssair/domtree.jl")
include("compiler/ssair/ir.jl")

include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")

Expand Down
75 changes: 46 additions & 29 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,21 @@ mutable struct InferenceState
sptypes::Vector{Any}
slottypes::Vector{Any}
src::CodeInfo
cfg::CFG

#= intermediate states for local abstract interpretation =#
currbb::Int
currpc::Int
ip::BitSetBoundedMinPrioritySet # current active instruction pointers
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_at::Vector{Int} # current exception handler info
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
stmt_types::Vector{Union{Nothing, VarTable}}
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Union{Nothing,Vector{Any}}}
stmt_info::Vector{Any}

#= interprocedural intermediate states for abstract interpretation =#
#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
Expand All @@ -125,36 +129,37 @@ mutable struct InferenceState
interp::AbstractInterpreter

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult,
src::CodeInfo, cache::Symbol, interp::AbstractInterpreter)
function InferenceState(result::InferenceResult, src::CodeInfo, cache::Symbol,
interp::AbstractInterpreter)
linfo = result.linfo
world = get_world_counter(interp)
def = linfo.def
mod = isa(def, Method) ? def.module : def
sptypes = sptypes_from_meth_instance(linfo)

code = src.code::Vector{Any}
nstmts = length(code)
currpc = 1
ip = BitSetBoundedMinPrioritySet(nstmts)
handler_at = compute_trycatch(code, ip.elems)
push!(ip, 1)
cfg = compute_basic_blocks(code)

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at = compute_trycatch(code, BitSet())
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
stmt_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ]
nstmts = length(code)
stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
stmt_info = Any[ nothing for i = 1:nstmts ]

nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
argtypes = result.argtypes
nargs = length(argtypes)
stmt_types[1] = stmt_type1 = VarTable(undef, nslots)
for i in 1:nslots
argtyp = (i > nargs) ? Bottom : argtypes[i]
stmt_type1[i] = VarState(argtyp, i > nargs)
nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
end
src.ssavaluetypes = ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]

pclimitations = IdSet{InferenceState}()
limitations = IdSet{InferenceState}()
Expand Down Expand Up @@ -183,15 +188,14 @@ mutable struct InferenceState
cached = cache === :global

frame = new(
linfo, world, mod, sptypes, slottypes, src,
currpc, ip, handler_at, ssavalue_uses, stmt_types, stmt_edges, stmt_info,
linfo, world, mod, sptypes, slottypes, src, cfg,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred,
result, valid_worlds, bestguess, ipo_effects,
params, restrict_abstract_call_sites, cached,
interp)

# some more setups
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)
result.result = frame
cache !== :no && push!(get_inference_cache(interp), result)
Expand Down Expand Up @@ -226,6 +230,8 @@ function any_inbounds(code::Vector{Any})
return false
end

# Report whether the statement at `pc` has an inferred type recorded in
# `sv.ssavaluetypes`, i.e. whether its entry differs from the `NOT_FOUND`
# sentinel.
function was_reached(sv::InferenceState, pc::Int)
    return sv.ssavaluetypes[pc] !== NOT_FOUND
end

function compute_trycatch(code::Vector{Any}, ip::BitSet)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
Expand Down Expand Up @@ -422,23 +428,28 @@ end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)

function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.ssavaluetypes
old = ssavaluetypes[ssa_id]
if old === NOT_FOUND || !(new ⊑ old)
# typically, we expect that old ⊑ new (that output information only
# gets less precise with worse input information), but to actually
# guarantee convergence we need to use tmerge here to ensure that is true
ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : tmerge(old, new)
W = frame.ip
s = frame.stmt_types
for r in frame.ssavalue_uses[ssa_id]
if s[r] !== nothing # s[r] === nothing => unreached statement
push!(W, r)
if was_reached(frame, r)
usebb = block_for_inst(frame.cfg, r)
# We're guaranteed to visit the statement if it's in the current
# basic block, since SSA values can only ever appear after their
# def.
if usebb != frame.currbb
push!(W, usebb)
end
end
end
end
nothing
return nothing
end

function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
Expand All @@ -457,7 +468,7 @@ function add_backedge!(li::MethodInstance, caller::InferenceState)
edges = caller.stmt_edges[caller.currpc] = []
end
push!(edges, li)
nothing
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
Expand All @@ -469,7 +480,13 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe
end
push!(edges, mt)
push!(edges, typ)
nothing
return nothing
end

# Clear, in place, the edge list stored in `frame.stmt_edges` for statement
# `currpc` (defaults to `frame.currpc`). A slot holding `nothing` is left
# untouched. Always returns `nothing`.
function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
    stmt_edge_list = frame.stmt_edges[currpc]
    if stmt_edge_list !== nothing
        empty!(stmt_edge_list)
    end
    return nothing
end

function print_callstack(sv::InferenceState)
Expand Down
45 changes: 33 additions & 12 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,20 @@ mutable struct OptimizationState
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
inlining::InliningState
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
cfg::Union{Nothing,CFG}
function OptimizationState(frame::InferenceState, params::OptimizationParams,
interp::AbstractInterpreter, recompute_cfg::Bool=true)
s_edges = frame.stmt_edges[1]::Vector{Any}
inlining = InliningState(params,
EdgeTracker(s_edges, frame.valid_worlds),
WorldView(code_cache(interp), frame.world),
interp)
return new(frame.linfo,
frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining)
cfg = recompute_cfg ? nothing : frame.cfg
return new(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
interp::AbstractInterpreter)
# prepare src for running optimization passes
# if it isn't already
nssavalues = src.ssavaluetypes
Expand All @@ -115,6 +118,7 @@ mutable struct OptimizationState
else
nssavalues = length(src.ssavaluetypes::Vector{Any})
end
sptypes = sptypes_from_meth_instance(linfo)
nslots = length(src.slotflags)
slottypes = src.slottypes
if slottypes === nothing
Expand All @@ -130,9 +134,8 @@ mutable struct OptimizationState
nothing,
WorldView(code_cache(interp), get_world_counter()),
interp)
return new(linfo,
src, nothing, stmt_info, mod,
sptypes_from_meth_instance(linfo), slottypes, inlining)
return new(linfo, src, nothing, stmt_info, mod,
sptypes, slottypes, inlining, nothing)
end
end

Expand Down Expand Up @@ -603,8 +606,8 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
meta = Expr[]
idx = 1
oldidx = 1
ssachangemap = fill(0, length(code))
labelchangemap = coverage ? fill(0, length(code)) : ssachangemap
nstmts = length(code)
ssachangemap = labelchangemap = nothing
prevloc = zero(eltype(ci.codelocs))
while idx <= length(code)
codeloc = codelocs[idx]
Expand All @@ -615,6 +618,12 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
insert!(ssavaluetypes, idx, Nothing)
insert!(stmtinfo, idx, nothing)
insert!(ssaflags, idx, IR_FLAG_NULL)
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = coverage ? fill(0, nstmts) : ssachangemap
end
ssachangemap[oldidx] += 1
if oldidx < length(labelchangemap)
labelchangemap[oldidx + 1] += 1
Expand All @@ -630,6 +639,12 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
insert!(ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, nothing)
insert!(ssaflags, idx + 1, ssaflags[idx])
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = coverage ? fill(0, nstmts) : ssachangemap
end
if oldidx < length(ssachangemap)
ssachangemap[oldidx + 1] += 1
coverage && (labelchangemap[oldidx + 1] += 1)
Expand All @@ -641,15 +656,21 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
oldidx += 1
end

renumber_ir_elements!(code, ssachangemap, labelchangemap)
cfg = sv.cfg
if ssachangemap !== nothing && labelchangemap !== nothing
renumber_ir_elements!(code, ssachangemap, labelchangemap)
cfg = nothing # recompute CFG
end

for i = 1:length(code)
code[i] = process_meta!(meta, code[i])
end
strip_trailing_junk!(ci, code, stmtinfo)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags)
cfg = compute_basic_blocks(code)
if cfg === nothing
cfg = compute_basic_blocks(code)
end
return IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes)
end

Expand Down
8 changes: 0 additions & 8 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@ else
end
end

function argextype end # imported by EscapeAnalysis
function stmt_effect_free end # imported by EscapeAnalysis
function alloc_array_ndims end # imported by EscapeAnalysis
function try_compute_field end # imported by EscapeAnalysis

include("compiler/ssair/basicblock.jl")
include("compiler/ssair/domtree.jl")
include("compiler/ssair/ir.jl")
include("compiler/ssair/slot2ssa.jl")
include("compiler/ssair/inlining.jl")
include("compiler/ssair/verify.jl")
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAV
elseif xinfo !== nothing
return !xinfo.attach_after
else
return yinfo.attach_after
return (yinfo::NewNodeInfo).attach_after
end
end
return x′.id < y′.id
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1413,7 +1413,7 @@ function type_lift_pass!(ir::IRCode)
end
else
while isa(node, PiNode)
id = node.val.id
id = (node.val::SSAValue).id
node = insts[id][:inst]
end
if isa(node, Union{PhiNode, PhiCNode})
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1825,7 +1825,8 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
effect_free = true
elseif f === getglobal && length(argtypes) >= 3
nothrow = getglobal_nothrow(argtypes[2:end])
ipo_consistent = nothrow && isconst((argtypes[2]::Const).val, (argtypes[3]::Const).val)
ipo_consistent = nothrow && isconst( # types are already checked in `getglobal_nothrow`
(argtypes[2]::Const).val::Module, (argtypes[3]::Const).val::Symbol)
effect_free = true
else
ipo_consistent = contains_is(_CONSISTENT_BUILTINS, f)
Expand Down
Loading