Skip to content

Commit

Permalink
inference: refactor the core loops to use less memory (#45276)
Browse files Browse the repository at this point in the history
Currently inference uses `O(<number of statements>*<number of slots>)` state
in the core inference loop. This is usually fine, because users don't tend
to write functions that are particularly long. However, MTK (ModelingToolkit.jl)
does generate functions that are excessively long, and we've observed MTK models
that spend 99% of their inference time just allocating and copying this state.
It is possible to get away with significantly smaller state, and this PR is
a first step in that direction, reducing the state to `O(<number of basic blocks>*<number of slots>)`.
Further improvements are possible by making use of slot liveness information
and only storing those slots that are live across a particular basic block.

The core change here is to keep a full set of `slottypes` only at
basic block boundaries rather than at each statement. For statements
in between, the full variable state can be recovered by scanning
linearly through the basic block, taking note of slot assignments
(together with the SSA type) and `NewvarNode`s.

Co-Authored-By: Keno Fischer <keno@juliacomputing.com>
  • Loading branch information
aviatesk and Keno authored May 30, 2022
1 parent 0057d75 commit 5a32626
Show file tree
Hide file tree
Showing 11 changed files with 518 additions and 362 deletions.
444 changes: 269 additions & 175 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ include("compiler/utilities.jl")
include("compiler/validation.jl")
include("compiler/methodtable.jl")

function argextype end # imported by EscapeAnalysis
function stmt_effect_free end # imported by EscapeAnalysis
function alloc_array_ndims end # imported by EscapeAnalysis
function try_compute_field end # imported by EscapeAnalysis
include("compiler/ssair/basicblock.jl")
include("compiler/ssair/domtree.jl")
include("compiler/ssair/ir.jl")

include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")

Expand Down
75 changes: 46 additions & 29 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,21 @@ mutable struct InferenceState
sptypes::Vector{Any}
slottypes::Vector{Any}
src::CodeInfo
cfg::CFG

#= intermediate states for local abstract interpretation =#
currbb::Int
currpc::Int
ip::BitSetBoundedMinPrioritySet # current active instruction pointers
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_at::Vector{Int} # current exception handler info
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
stmt_types::Vector{Union{Nothing, VarTable}}
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Union{Nothing,Vector{Any}}}
stmt_info::Vector{Any}

#= interprocedural intermediate states for abstract interpretation =#
#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
Expand All @@ -125,36 +129,37 @@ mutable struct InferenceState
interp::AbstractInterpreter

# src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results
function InferenceState(result::InferenceResult,
src::CodeInfo, cache::Symbol, interp::AbstractInterpreter)
function InferenceState(result::InferenceResult, src::CodeInfo, cache::Symbol,
interp::AbstractInterpreter)
linfo = result.linfo
world = get_world_counter(interp)
def = linfo.def
mod = isa(def, Method) ? def.module : def
sptypes = sptypes_from_meth_instance(linfo)

code = src.code::Vector{Any}
nstmts = length(code)
currpc = 1
ip = BitSetBoundedMinPrioritySet(nstmts)
handler_at = compute_trycatch(code, ip.elems)
push!(ip, 1)
cfg = compute_basic_blocks(code)

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at = compute_trycatch(code, BitSet())
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
stmt_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ]
nstmts = length(code)
stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ]
stmt_info = Any[ nothing for i = 1:nstmts ]

nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
argtypes = result.argtypes
nargs = length(argtypes)
stmt_types[1] = stmt_type1 = VarTable(undef, nslots)
for i in 1:nslots
argtyp = (i > nargs) ? Bottom : argtypes[i]
stmt_type1[i] = VarState(argtyp, i > nargs)
nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
end
src.ssavaluetypes = ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]

pclimitations = IdSet{InferenceState}()
limitations = IdSet{InferenceState}()
Expand Down Expand Up @@ -183,15 +188,14 @@ mutable struct InferenceState
cached = cache === :global

frame = new(
linfo, world, mod, sptypes, slottypes, src,
currpc, ip, handler_at, ssavalue_uses, stmt_types, stmt_edges, stmt_info,
linfo, world, mod, sptypes, slottypes, src, cfg,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred,
result, valid_worlds, bestguess, ipo_effects,
params, restrict_abstract_call_sites, cached,
interp)

# some more setups
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)
result.result = frame
cache !== :no && push!(get_inference_cache(interp), result)
Expand Down Expand Up @@ -226,6 +230,8 @@ function any_inbounds(code::Vector{Any})
return false
end

# Report whether statement `pc` of `sv` has been given an SSA type yet, i.e. whether
# its entry in `sv.ssavaluetypes` is anything other than the `NOT_FOUND` placeholder.
function was_reached(sv::InferenceState, pc::Int)
    return sv.ssavaluetypes[pc] !== NOT_FOUND
end

function compute_trycatch(code::Vector{Any}, ip::BitSet)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
Expand Down Expand Up @@ -422,23 +428,28 @@ end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)

function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.ssavaluetypes
old = ssavaluetypes[ssa_id]
if old === NOT_FOUND || !(new ⊑ old)
# typically, we expect that old ⊑ new (that output information only
# gets less precise with worse input information), but to actually
# guarantee convergence we need to use tmerge here to ensure that is true
ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : tmerge(old, new)
W = frame.ip
s = frame.stmt_types
for r in frame.ssavalue_uses[ssa_id]
if s[r] !== nothing # s[r] === nothing => unreached statement
push!(W, r)
if was_reached(frame, r)
usebb = block_for_inst(frame.cfg, r)
# We're guaranteed to visit the statement if it's in the current
# basic block, since SSA values can only ever appear after their
# def.
if usebb != frame.currbb
push!(W, usebb)
end
end
end
end
nothing
return nothing
end

function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
Expand All @@ -457,7 +468,7 @@ function add_backedge!(li::MethodInstance, caller::InferenceState)
edges = caller.stmt_edges[caller.currpc] = []
end
push!(edges, li)
nothing
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
Expand All @@ -469,7 +480,13 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe
end
push!(edges, mt)
push!(edges, typ)
nothing
return nothing
end

# Discard every edge recorded so far for statement `currpc` of `frame`
# (`currpc` defaults to the frame's current program counter).
# The entry in `frame.stmt_edges` is left untouched when it is still `nothing`,
# i.e. when no edge list exists for that statement.
function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
    recorded = frame.stmt_edges[currpc]
    if recorded !== nothing
        empty!(recorded)
    end
    return nothing
end

function print_callstack(sv::InferenceState)
Expand Down
45 changes: 33 additions & 12 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,20 @@ mutable struct OptimizationState
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
inlining::InliningState
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
cfg::Union{Nothing,CFG}
function OptimizationState(frame::InferenceState, params::OptimizationParams,
interp::AbstractInterpreter, recompute_cfg::Bool=true)
s_edges = frame.stmt_edges[1]::Vector{Any}
inlining = InliningState(params,
EdgeTracker(s_edges, frame.valid_worlds),
WorldView(code_cache(interp), frame.world),
interp)
return new(frame.linfo,
frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining)
cfg = recompute_cfg ? nothing : frame.cfg
return new(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
interp::AbstractInterpreter)
# prepare src for running optimization passes
# if it isn't already
nssavalues = src.ssavaluetypes
Expand All @@ -115,6 +118,7 @@ mutable struct OptimizationState
else
nssavalues = length(src.ssavaluetypes::Vector{Any})
end
sptypes = sptypes_from_meth_instance(linfo)
nslots = length(src.slotflags)
slottypes = src.slottypes
if slottypes === nothing
Expand All @@ -130,9 +134,8 @@ mutable struct OptimizationState
nothing,
WorldView(code_cache(interp), get_world_counter()),
interp)
return new(linfo,
src, nothing, stmt_info, mod,
sptypes_from_meth_instance(linfo), slottypes, inlining)
return new(linfo, src, nothing, stmt_info, mod,
sptypes, slottypes, inlining, nothing)
end
end

Expand Down Expand Up @@ -603,8 +606,8 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
meta = Expr[]
idx = 1
oldidx = 1
ssachangemap = fill(0, length(code))
labelchangemap = coverage ? fill(0, length(code)) : ssachangemap
nstmts = length(code)
ssachangemap = labelchangemap = nothing
prevloc = zero(eltype(ci.codelocs))
while idx <= length(code)
codeloc = codelocs[idx]
Expand All @@ -615,6 +618,12 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
insert!(ssavaluetypes, idx, Nothing)
insert!(stmtinfo, idx, nothing)
insert!(ssaflags, idx, IR_FLAG_NULL)
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = coverage ? fill(0, nstmts) : ssachangemap
end
ssachangemap[oldidx] += 1
if oldidx < length(labelchangemap)
labelchangemap[oldidx + 1] += 1
Expand All @@ -630,6 +639,12 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
insert!(ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, nothing)
insert!(ssaflags, idx + 1, ssaflags[idx])
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = coverage ? fill(0, nstmts) : ssachangemap
end
if oldidx < length(ssachangemap)
ssachangemap[oldidx + 1] += 1
coverage && (labelchangemap[oldidx + 1] += 1)
Expand All @@ -641,15 +656,21 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
oldidx += 1
end

renumber_ir_elements!(code, ssachangemap, labelchangemap)
cfg = sv.cfg
if ssachangemap !== nothing && labelchangemap !== nothing
renumber_ir_elements!(code, ssachangemap, labelchangemap)
cfg = nothing # recompute CFG
end

for i = 1:length(code)
code[i] = process_meta!(meta, code[i])
end
strip_trailing_junk!(ci, code, stmtinfo)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags)
cfg = compute_basic_blocks(code)
if cfg === nothing
cfg = compute_basic_blocks(code)
end
return IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes)
end

Expand Down
8 changes: 0 additions & 8 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@ else
end
end

function argextype end # imported by EscapeAnalysis
function stmt_effect_free end # imported by EscapeAnalysis
function alloc_array_ndims end # imported by EscapeAnalysis
function try_compute_field end # imported by EscapeAnalysis

include("compiler/ssair/basicblock.jl")
include("compiler/ssair/domtree.jl")
include("compiler/ssair/ir.jl")
include("compiler/ssair/slot2ssa.jl")
include("compiler/ssair/inlining.jl")
include("compiler/ssair/verify.jl")
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAV
elseif xinfo !== nothing
return !xinfo.attach_after
else
return yinfo.attach_after
return (yinfo::NewNodeInfo).attach_after
end
end
return x′.id < y′.id
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1413,7 +1413,7 @@ function type_lift_pass!(ir::IRCode)
end
else
while isa(node, PiNode)
id = node.val.id
id = (node.val::SSAValue).id
node = insts[id][:inst]
end
if isa(node, Union{PhiNode, PhiCNode})
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1827,7 +1827,8 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
effect_free = true
elseif f === getglobal && length(argtypes) >= 3
nothrow = getglobal_nothrow(argtypes[2:end])
ipo_consistent = nothrow && isconst((argtypes[2]::Const).val, (argtypes[3]::Const).val)
ipo_consistent = nothrow && isconst( # types are already checked in `getglobal_nothrow`
(argtypes[2]::Const).val::Module, (argtypes[3]::Const).val::Symbol)
effect_free = true
else
ipo_consistent = contains_is(_CONSISTENT_BUILTINS, f)
Expand Down
Loading

0 comments on commit 5a32626

Please sign in to comment.