From 5a32626298b2de252d6d3152ccce25be18cb19ad Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Mon, 30 May 2022 13:40:44 +0900 Subject: [PATCH] inference: refactor the core loops to use less memory (#45276) Currently inference uses `O(*)` state in the core inference loop. This is usually fine, because users don't tend to write functions that are particularly long. However, MTK does generate functions that are excessively long and we've observed MTK models that spend 99% of their inference time just allocating and copying this state. It is possible to get away with significantly smaller state, and this PR is a first step in that direction, reducing the state to `O(*)`. Further improvements are possible by making use of slot liveness information and only storing those slots that are live across a particular basic block. The core change here is to keep a full set of `slottypes` only at basic block boundaries rather than at each statement. For statements in between, the full variable state can be fully recovered by linearly scanning throughout the basic block, taking note of slot assignments (together with the SSA type) and NewVarNodes. Co-Authored-By: Keno Fisher --- base/compiler/abstractinterpretation.jl | 444 ++++++++++++++---------- base/compiler/compiler.jl | 8 + base/compiler/inferencestate.jl | 75 ++-- base/compiler/optimize.jl | 45 ++- base/compiler/ssair/driver.jl | 8 - base/compiler/ssair/ir.jl | 2 +- base/compiler/ssair/passes.jl | 2 +- base/compiler/tfuncs.jl | 3 +- base/compiler/typeinfer.jl | 161 +++++---- base/compiler/typelattice.jl | 92 +++-- test/compiler/inference.jl | 40 ++- 11 files changed, 518 insertions(+), 362 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 5ea19bf900737..c9608f08f5b3a 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1555,7 +1555,8 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn match === nothing && return CallMeta(Any, Effects(), false) update_valid_age!(sv, valid_worlds) method = match.method - (ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector + tienv = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector + ti = tienv[1]; env = tienv[2]::SimpleVector (; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv) effects = result.edge_effects edge !== nothing && add_backedge!(edge::MethodInstance, sv) @@ -1738,7 +1739,7 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, body_call = abstract_call_opaque_closure(interp, ft, ArgInfo(arginfo.fargs, newargtypes), sv) # Analyze implicit type asserts on argument and return type ftt = ft.typ - (at, rt) = unwrap_unionall(ftt).parameters + (at, rt) = (unwrap_unionall(ftt)::DataType).parameters if isa(rt, TypeVar) rt = rewrap_unionall(rt.lb, ftt) else @@ -2039,7 +2040,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if isdefined(sym.mod, sym.name) t = Const(true) end - elseif isa(sym, Expr) && sym.head === :static_parameter + elseif isexpr(sym, :static_parameter) n = sym.args[1]::Int if 1 <= n <= length(sv.sptypes) spty = sv.sptypes[n] @@ -2058,7 +2059,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), t = Const(t.instance) end if !isempty(sv.pclimitations) - if t isa Const || t === Union{} + if t isa Const || t === Bottom empty!(sv.pclimitations) else t = LimitedAccuracy(t, sv.pclimitations) @@ -2097,9 +2098,10 @@ function handle_global_assignment!(interp::AbstractInterpreter, frame::Inference nothrow=nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN)) end -abstract_eval_ssavalue(s::SSAValue, sv::InferenceState) = abstract_eval_ssavalue(s, sv.src) -function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) - typ = (src.ssavaluetypes::Vector{Any})[s.id] +abstract_eval_ssavalue(s::SSAValue, sv::InferenceState) = abstract_eval_ssavalue(s, sv.ssavaluetypes) +abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any}) +function abstract_eval_ssavalue(s::SSAValue, ssavaluetypes::Vector{Any}) + typ = ssavaluetypes[s.id] if typ === NOT_FOUND return Bottom end @@ -2194,211 +2196,303 @@ function handle_control_backedge!(frame::InferenceState, from::Int, to::Int) return nothing end +struct BasicStmtChange + changes::Union{Nothing,StateUpdate} + type::Any # ::Union{Type, Nothing} - `nothing` if this statement may not be used as an SSA Value + # TODO effects::Effects + BasicStmtChange(changes::Union{Nothing,StateUpdate}, @nospecialize type) = new(changes, type) +end + +@inline function abstract_eval_basic_statement(interp::AbstractInterpreter, + @nospecialize(stmt), pc_vartable::VarTable, frame::InferenceState) + if isa(stmt, NewvarNode) + changes = StateUpdate(stmt.slot, VarState(Bottom, true), pc_vartable, false) + return BasicStmtChange(changes, nothing) + elseif !isa(stmt, Expr) + t = abstract_eval_statement(interp, stmt, pc_vartable, frame) + return BasicStmtChange(nothing, t) + end + changes = nothing + stmt = stmt::Expr + hd = stmt.head + if hd === :(=) + t = abstract_eval_statement(interp, stmt.args[2], pc_vartable, frame) + if t === Bottom + return BasicStmtChange(nothing, Bottom) + end + lhs = stmt.args[1] + if isa(lhs, SlotNumber) + changes = StateUpdate(lhs, VarState(t, false), pc_vartable, false) + elseif isa(lhs, GlobalRef) + handle_global_assignment!(interp, frame, lhs, t) + elseif !isa(lhs, SSAValue) + tristate_merge!(frame, EFFECTS_UNKNOWN) + end + return BasicStmtChange(changes, t) + elseif hd === :method + fname = stmt.args[1] + if isa(fname, SlotNumber) + changes = StateUpdate(fname, VarState(Any, false), pc_vartable, false) + end + return BasicStmtChange(changes, nothing) + elseif (hd === :code_coverage_effect || ( + hd !== :boundscheck && # :boundscheck can be narrowed to Bool + is_meta_expr(stmt))) + return BasicStmtChange(nothing, Nothing) + else + t = abstract_eval_statement(interp, stmt, pc_vartable, frame) + return BasicStmtChange(nothing, t) + end +end + +function update_bbstate!(frame::InferenceState, bb::Int, vartable::VarTable) + bbtable = frame.bb_vartables[bb] + if bbtable === nothing + # if a basic block hasn't been analyzed yet, + # we can update its state a bit more aggressively + frame.bb_vartables[bb] = copy(vartable) + return true + else + return stupdate!(bbtable, vartable) + end +end + +function init_vartable!(vartable::VarTable, frame::InferenceState) + nargtypes = length(frame.result.argtypes) + for i = 1:length(vartable) + vartable[i] = VarState(Bottom, i > nargtypes) + end + return vartable +end + # make as much progress on `frame` as possible (without handling cycles) function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) @assert !frame.inferred frame.dont_work_on_me = true # mark that this function is currently on the stack W = frame.ip - states = frame.stmt_types def = frame.linfo.def isva = isa(def, Method) && def.isva nargs = length(frame.result.argtypes) - isva slottypes = frame.slottypes - ssavaluetypes = frame.src.ssavaluetypes::Vector{Any} - while !isempty(W) - # make progress on the active ip set - local pc::Int = popfirst!(W) - local pc´::Int = pc + 1 # next program-counter (after executing instruction) - frame.currpc = pc - edges = frame.stmt_edges[pc] - edges === nothing || empty!(edges) - frame.stmt_info[pc] = nothing - stmt = frame.src.code[pc] - changes = states[pc]::VarTable - t = nothing - - hd = isa(stmt, Expr) ? stmt.head : nothing - - if isa(stmt, NewvarNode) - sn = slot_id(stmt.slot) - changes[sn] = VarState(Bottom, true) - elseif isa(stmt, GotoNode) - l = (stmt::GotoNode).label - handle_control_backedge!(frame, pc, l) - pc´ = l - elseif isa(stmt, GotoIfNot) - condx = stmt.cond - condt = abstract_eval_value(interp, condx, changes, frame) - if condt === Bottom - empty!(frame.pclimitations) - continue - end - if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condx, SlotNumber) - # if this non-`Conditional` object is a slot, we form and propagate - # the conditional constraint on it - condt = Conditional(condx, Const(true), Const(false)) - end - condval = maybe_extract_const_bool(condt) - l = stmt.dest::Int - if !isempty(frame.pclimitations) - # we can't model the possible effect of control - # dependencies on the return value, so we propagate it - # directly to all the return values (unless we error first) - condval isa Bool || union!(frame.limitations, frame.pclimitations) - empty!(frame.pclimitations) - end - # constant conditions - if condval === true - elseif condval === false - handle_control_backedge!(frame, pc, l) - pc´ = l - else - # general case - changes_else = changes - if isa(condt, Conditional) - changes_else = conditional_changes(changes_else, condt.elsetype, condt.var) - changes = conditional_changes(changes, condt.vtype, condt.var) - end - newstate_else = stupdate!(states[l], changes_else) - if newstate_else !== nothing - handle_control_backedge!(frame, pc, l) - # add else branch to active IP list - push!(W, l) - states[l] = newstate_else - end - end - elseif isa(stmt, ReturnNode) - bestguess = frame.bestguess - rt = abstract_eval_value(interp, stmt.val, changes, frame) - rt = widenreturn(rt, bestguess, nargs, slottypes, changes) - # narrow representation of bestguess slightly to prepare for tmerge with rt - if rt isa InterConditional && bestguess isa Const - let slot_id = rt.slot - old_id_type = slottypes[slot_id] - if bestguess.val === true && rt.elsetype !== Bottom - bestguess = InterConditional(slot_id, old_id_type, Bottom) - elseif bestguess.val === false && rt.vtype !== Bottom - bestguess = InterConditional(slot_id, Bottom, old_id_type) + ssavaluetypes = frame.ssavaluetypes + bbs = frame.cfg.blocks + nbbs = length(bbs) + + currbb = frame.currbb + if currbb != 1 + currbb = frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block + end + + states = frame.bb_vartables + currstate = copy(states[currbb]::VarTable) + while currbb <= nbbs + delete!(W, currbb) + bbstart = first(bbs[currbb].stmts) + bbend = last(bbs[currbb].stmts) + + for currpc in bbstart:bbend + frame.currpc = currpc + empty_backedges!(frame, currpc) + stmt = frame.src.code[currpc] + # If we're at the end of the basic block ... + if currpc == bbend + # Handle control flow + if isa(stmt, GotoNode) + succs = bbs[currbb].succs + @assert length(succs) == 1 + nextbb = succs[1] + ssavaluetypes[currpc] = Any + handle_control_backedge!(frame, currpc, stmt.label) + @goto branch + elseif isa(stmt, GotoIfNot) + condx = stmt.cond + condt = abstract_eval_value(interp, condx, currstate, frame) + if condt === Bottom + ssavaluetypes[currpc] = Bottom + empty!(frame.pclimitations) + @goto find_next_bb end - end - end - # copy limitations to return value - if !isempty(frame.pclimitations) - union!(frame.limitations, frame.pclimitations) - empty!(frame.pclimitations) - end - if !isempty(frame.limitations) - rt = LimitedAccuracy(rt, copy(frame.limitations)) - end - if tchanged(rt, bestguess) - # new (wider) return type for frame - bestguess = tmerge(bestguess, rt) - # TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end - frame.bestguess = bestguess - for (caller, caller_pc) in frame.cycle_backedges - # notify backedges of updated type information - typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before - if !((caller.src.ssavaluetypes::Vector{Any})[caller_pc] === Any) - # no reason to revisit if that call-site doesn't affect the final result - push!(caller.ip, caller_pc) + if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condx, SlotNumber) + # if this non-`Conditional` object is a slot, we form and propagate + # the conditional constraint on it + condt = Conditional(condx, Const(true), Const(false)) end + condval = maybe_extract_const_bool(condt) + if !isempty(frame.pclimitations) + # we can't model the possible effect of control + # dependencies on the return + # directly to all the return values (unless we error first) + condval isa Bool || union!(frame.limitations, frame.pclimitations) + empty!(frame.pclimitations) + end + ssavaluetypes[currpc] = Any + if condval === true + @goto fallthrough + else + succs = bbs[currbb].succs + if length(succs) == 1 + @assert condval === false || (stmt.dest === currpc + 1) + nextbb = succs[1] + @goto branch + end + @assert length(succs) == 2 + truebb = currbb + 1 + falsebb = succs[1] == truebb ? succs[2] : succs[1] + if condval === false + nextbb = falsebb + handle_control_backedge!(frame, currpc, stmt.dest) + @goto branch + else + # We continue with the true branch, but process the false + # branch here. + if isa(condt, Conditional) + else_change = conditional_change(currstate, condt.elsetype, condt.var) + if else_change !== nothing + false_vartable = stoverwrite1!(copy(currstate), else_change) + else + false_vartable = currstate + end + changed = update_bbstate!(frame, falsebb, false_vartable) + then_change = conditional_change(currstate, condt.vtype, condt.var) + if then_change !== nothing + stoverwrite1!(currstate, then_change) + end + else + changed = update_bbstate!(frame, falsebb, currstate) + end + if changed + handle_control_backedge!(frame, currpc, stmt.dest) + push!(W, falsebb) + end + @goto fallthrough + end + end + elseif isa(stmt, ReturnNode) + bestguess = frame.bestguess + rt = abstract_eval_value(interp, stmt.val, currstate, frame) + rt = widenreturn(rt, bestguess, nargs, slottypes, currstate) + # narrow representation of bestguess slightly to prepare for tmerge with rt + if rt isa InterConditional && bestguess isa Const + let slot_id = rt.slot + old_id_type = slottypes[slot_id] + if bestguess.val === true && rt.elsetype !== Bottom + bestguess = InterConditional(slot_id, old_id_type, Bottom) + elseif bestguess.val === false && rt.vtype !== Bottom + bestguess = InterConditional(slot_id, Bottom, old_id_type) + end + end + end + # copy limitations to return value + if !isempty(frame.pclimitations) + union!(frame.limitations, frame.pclimitations) + empty!(frame.pclimitations) + end + if !isempty(frame.limitations) + rt = LimitedAccuracy(rt, copy(frame.limitations)) + end + if tchanged(rt, bestguess) + # new (wider) return type for frame + bestguess = tmerge(bestguess, rt) + # TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end + frame.bestguess = bestguess + for (caller, caller_pc) in frame.cycle_backedges + if !(caller.ssavaluetypes[caller_pc] === Any) + # no reason to revisit if that call-site doesn't affect the final result + push!(caller.ip, block_for_inst(caller.cfg, caller_pc)) + end + end + end + ssavaluetypes[frame.currpc] = Any + @goto find_next_bb + elseif isexpr(stmt, :enter) + # Propagate entry info to exception handler + l = stmt.args[1]::Int + catchbb = block_for_inst(frame.cfg, l) + if update_bbstate!(frame, catchbb, currstate) + push!(W, catchbb) + end + ssavaluetypes[currpc] = Any + @goto fallthrough end - end - continue - elseif hd === :enter - stmt = stmt::Expr - l = stmt.args[1]::Int - # propagate type info to exception handler - old = states[l] - newstate_catch = stupdate!(old, changes) - if newstate_catch !== nothing - push!(W, l) - states[l] = newstate_catch - end - typeassert(states[l], VarTable) - elseif hd === :leave - else - if hd === :(=) - stmt = stmt::Expr - t = abstract_eval_statement(interp, stmt.args[2], changes, frame) - if t === Bottom - continue - end - ssavaluetypes[pc] = t - lhs = stmt.args[1] - if isa(lhs, SlotNumber) - changes = StateUpdate(lhs, VarState(t, false), changes, false) - elseif isa(lhs, GlobalRef) - handle_global_assignment!(interp, frame, lhs, t) - elseif !isa(lhs, SSAValue) - tristate_merge!(frame, EFFECTS_UNKNOWN) - end - elseif hd === :method - stmt = stmt::Expr - fname = stmt.args[1] - if isa(fname, SlotNumber) - changes = StateUpdate(fname, VarState(Any, false), changes, false) - end - elseif hd === :code_coverage_effect || ( - hd !== :boundscheck && # :boundscheck can be narrowed to Bool - is_meta_expr(stmt)) - # these do not generate code - else - t = abstract_eval_statement(interp, stmt, changes, frame) - if t === Bottom - continue - end - if !isempty(frame.ssavalue_uses[pc]) - record_ssa_assign(pc, t, frame) - else - ssavaluetypes[pc] = t - end - end - if isa(changes, StateUpdate) - let cur_hand = frame.handler_at[pc], l, enter + # Fall through terminator - treat as regular stmt + end + # Process non control-flow statements + (; changes, type) = abstract_eval_basic_statement(interp, + stmt, currstate, frame) + if type === Bottom + ssavaluetypes[currpc] = Bottom + @goto find_next_bb + end + if changes !== nothing + stoverwrite1!(currstate, changes) + let cur_hand = frame.handler_at[currpc], l, enter while cur_hand != 0 - enter = frame.src.code[cur_hand] - l = (enter::Expr).args[1]::Int + enter = frame.src.code[cur_hand]::Expr + l = enter.args[1]::Int + exceptbb = block_for_inst(frame.cfg, l) # propagate new type info to exception handler # the handling for Expr(:enter) propagates all changes from before the try/catch # so this only needs to propagate any changes - if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false - push!(W, l) + if stupdate1!(states[exceptbb]::VarTable, changes) + push!(W, exceptbb) end cur_hand = frame.handler_at[cur_hand] end end end - end + if type === nothing + ssavaluetypes[currpc] = Any + continue + end + if !isempty(frame.ssavalue_uses[currpc]) + record_ssa_assign!(currpc, type, frame) + else + ssavaluetypes[currpc] = type + end + end # for currpc in bbstart:bbend - @assert isempty(frame.pclimitations) "unhandled LimitedAccuracy" + # Case 1: Fallthrough termination + begin @label fallthrough + nextbb = currbb + 1 + end - if t === nothing - # mark other reached expressions as `Any` to indicate they don't throw - ssavaluetypes[pc] = Any + # Case 2: Directly branch to a different BB + begin @label branch + if update_bbstate!(frame, nextbb, currstate) + push!(W, nextbb) + end end - newstate = stupdate!(states[pc´], changes) - if newstate !== nothing - states[pc´] = newstate - push!(W, pc´) + # Case 3: Control flow ended along the current path (converged, return or throw) + begin @label find_next_bb + currbb = frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block + currbb == -1 && break # the working set is empty + currbb > nbbs && break + + nexttable = states[currbb] + if nexttable === nothing + init_vartable!(currstate, frame) + else + stoverwrite!(currstate, nexttable) + end end - end + end # while currbb <= nbbs + frame.dont_work_on_me = false nothing end -function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNumber) - vtype = changes[slot_id(var)] +function conditional_change(state::VarTable, @nospecialize(typ), var::SlotNumber) + vtype = state[slot_id(var)] oldtyp = vtype.typ # approximate test for `typ ∩ oldtyp` being better than `oldtyp` # since we probably formed these types with `typesubstract`, the comparison is likely simple if ignorelimited(typ) ⊑ ignorelimited(oldtyp) # typ is better unlimited, but we may still need to compute the tmeet with the limit "causes" since we ignored those in the comparison oldtyp isa LimitedAccuracy && (typ = tmerge(typ, LimitedAccuracy(Bottom, oldtyp.causes))) - return StateUpdate(var, VarState(typ, vtype.undef), changes, true) + return StateUpdate(var, VarState(typ, vtype.undef), state, true) end - return changes + return nothing end function bool_rt_to_conditional(@nospecialize(rt), slottypes::Vector{Any}, state::VarTable, slot_id::Int) diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 1132b8976e53c..82b43d5af03c2 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -128,6 +128,14 @@ include("compiler/utilities.jl") include("compiler/validation.jl") include("compiler/methodtable.jl") +function argextype end # imported by EscapeAnalysis +function stmt_effect_free end # imported by EscapeAnalysis +function alloc_array_ndims end # imported by EscapeAnalysis +function try_compute_field end # imported by EscapeAnalysis +include("compiler/ssair/basicblock.jl") +include("compiler/ssair/domtree.jl") +include("compiler/ssair/ir.jl") + include("compiler/inferenceresult.jl") include("compiler/inferencestate.jl") diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 24423deef8623..15057b45fa2a7 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -88,17 +88,21 @@ mutable struct InferenceState sptypes::Vector{Any} slottypes::Vector{Any} src::CodeInfo + cfg::CFG #= intermediate states for local abstract interpretation =# + currbb::Int currpc::Int - ip::BitSetBoundedMinPrioritySet # current active instruction pointers + ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers handler_at::Vector{Int} # current exception handler info ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info - stmt_types::Vector{Union{Nothing, VarTable}} - stmt_edges::Vector{Union{Nothing, Vector{Any}}} + # TODO: Could keep this sparsely by doing structural liveness analysis ahead of time. + bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet + ssavaluetypes::Vector{Any} + stmt_edges::Vector{Union{Nothing,Vector{Any}}} stmt_info::Vector{Any} - #= interprocedural intermediate states for abstract interpretation =# + #= intermediate states for interprocedural abstract interpretation =# pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller @@ -125,36 +129,37 @@ mutable struct InferenceState interp::AbstractInterpreter # src is assumed to be a newly-allocated CodeInfo, that can be modified in-place to contain intermediate results - function InferenceState(result::InferenceResult, - src::CodeInfo, cache::Symbol, interp::AbstractInterpreter) + function InferenceState(result::InferenceResult, src::CodeInfo, cache::Symbol, + interp::AbstractInterpreter) linfo = result.linfo world = get_world_counter(interp) def = linfo.def mod = isa(def, Method) ? def.module : def sptypes = sptypes_from_meth_instance(linfo) - code = src.code::Vector{Any} - nstmts = length(code) - currpc = 1 - ip = BitSetBoundedMinPrioritySet(nstmts) - handler_at = compute_trycatch(code, ip.elems) - push!(ip, 1) + cfg = compute_basic_blocks(code) + + currbb = currpc = 1 + ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1) + handler_at = compute_trycatch(code, BitSet()) nssavalues = src.ssavaluetypes::Int ssavalue_uses = find_ssavalue_uses(code, nssavalues) - stmt_types = Union{Nothing, VarTable}[ nothing for i = 1:nstmts ] + nstmts = length(code) stmt_edges = Union{Nothing, Vector{Any}}[ nothing for i = 1:nstmts ] stmt_info = Any[ nothing for i = 1:nstmts ] nslots = length(src.slotflags) slottypes = Vector{Any}(undef, nslots) + bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ] + bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots) argtypes = result.argtypes - nargs = length(argtypes) - stmt_types[1] = stmt_type1 = VarTable(undef, nslots) - for i in 1:nslots - argtyp = (i > nargs) ? Bottom : argtypes[i] - stmt_type1[i] = VarState(argtyp, i > nargs) + nargtypes = length(argtypes) + for i = 1:nslots + argtyp = (i > nargtypes) ? Bottom : argtypes[i] slottypes[i] = argtyp + bb_vartable1[i] = VarState(argtyp, i > nargtypes) end + src.ssavaluetypes = ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ] pclimitations = IdSet{InferenceState}() limitations = IdSet{InferenceState}() @@ -183,15 +188,14 @@ mutable struct InferenceState cached = cache === :global frame = new( - linfo, world, mod, sptypes, slottypes, src, - currpc, ip, handler_at, ssavalue_uses, stmt_types, stmt_edges, stmt_info, + linfo, world, mod, sptypes, slottypes, src, cfg, + currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info, pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred, result, valid_worlds, bestguess, ipo_effects, params, restrict_abstract_call_sites, cached, interp) # some more setups - src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ] params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at) result.result = frame cache !== :no && push!(get_inference_cache(interp), result) @@ -226,6 +230,8 @@ function any_inbounds(code::Vector{Any}) return false end +was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND + function compute_trycatch(code::Vector{Any}, ip::BitSet) # The goal initially is to record the frame like this for the state at exit: # 1: (enter 3) # == 0 @@ -422,8 +428,8 @@ end update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds) -function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState) - ssavaluetypes = frame.src.ssavaluetypes::Vector{Any} +function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceState) + ssavaluetypes = frame.ssavaluetypes old = ssavaluetypes[ssa_id] if old === NOT_FOUND || !(new ⊑ old) # typically, we expect that old ⊑ new (that output information only @@ -431,14 +437,19 @@ function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceStat # guarantee convergence we need to use tmerge here to ensure that is true ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : tmerge(old, new) W = frame.ip - s = frame.stmt_types for r in frame.ssavalue_uses[ssa_id] - if s[r] !== nothing # s[r] === nothing => unreached statement - push!(W, r) + if was_reached(frame, r) + usebb = block_for_inst(frame.cfg, r) + # We're guaranteed to visit the statement if it's in the current + # basic block, since SSA values can only ever appear after their + # def. + if usebb != frame.currbb + push!(W, usebb) + end end end end - nothing + return nothing end function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int) @@ -457,7 +468,7 @@ function add_backedge!(li::MethodInstance, caller::InferenceState) edges = caller.stmt_edges[caller.currpc] = [] end push!(edges, li) - nothing + return nothing end # used to temporarily accumulate our no method errors to later add as backedges in the callee method table @@ -469,7 +480,13 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe end push!(edges, mt) push!(edges, typ) - nothing + return nothing +end + +function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc) + edges = frame.stmt_edges[currpc] + edges === nothing || empty!(edges) + return nothing end function print_callstack(sv::InferenceState) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index af4ef61704c1d..e80f5353823ca 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -96,17 +96,20 @@ mutable struct OptimizationState sptypes::Vector{Any} # static parameters slottypes::Vector{Any} inlining::InliningState - function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter) + cfg::Union{Nothing,CFG} + function OptimizationState(frame::InferenceState, params::OptimizationParams, + interp::AbstractInterpreter, recompute_cfg::Bool=true) s_edges = frame.stmt_edges[1]::Vector{Any} inlining = InliningState(params, EdgeTracker(s_edges, frame.valid_worlds), WorldView(code_cache(interp), frame.world), interp) - return new(frame.linfo, - frame.src, nothing, frame.stmt_info, frame.mod, - frame.sptypes, frame.slottypes, inlining) + cfg = recompute_cfg ? nothing : frame.cfg + return new(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod, + frame.sptypes, frame.slottypes, inlining, cfg) end - function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter) + function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, + interp::AbstractInterpreter) # prepare src for running optimization passes # if it isn't already nssavalues = src.ssavaluetypes @@ -115,6 +118,7 @@ mutable struct OptimizationState else nssavalues = length(src.ssavaluetypes::Vector{Any}) end + sptypes = sptypes_from_meth_instance(linfo) nslots = length(src.slotflags) slottypes = src.slottypes if slottypes === nothing @@ -130,9 +134,8 @@ mutable struct OptimizationState nothing, WorldView(code_cache(interp), get_world_counter()), interp) - return new(linfo, - src, nothing, stmt_info, mod, - sptypes_from_meth_instance(linfo), slottypes, inlining) + return new(linfo, src, nothing, stmt_info, mod, + sptypes, slottypes, inlining, nothing) end end @@ -603,8 +606,8 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) meta = Expr[] idx = 1 oldidx = 1 - ssachangemap = fill(0, length(code)) - labelchangemap = coverage ? fill(0, length(code)) : ssachangemap + nstmts = length(code) + ssachangemap = labelchangemap = nothing prevloc = zero(eltype(ci.codelocs)) while idx <= length(code) codeloc = codelocs[idx] @@ -615,6 +618,12 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) insert!(ssavaluetypes, idx, Nothing) insert!(stmtinfo, idx, nothing) insert!(ssaflags, idx, IR_FLAG_NULL) + if ssachangemap === nothing + ssachangemap = fill(0, nstmts) + end + if labelchangemap === nothing + labelchangemap = coverage ? fill(0, nstmts) : ssachangemap + end ssachangemap[oldidx] += 1 if oldidx < length(labelchangemap) labelchangemap[oldidx + 1] += 1 @@ -630,6 +639,12 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) insert!(ssavaluetypes, idx + 1, Union{}) insert!(stmtinfo, idx + 1, nothing) insert!(ssaflags, idx + 1, ssaflags[idx]) + if ssachangemap === nothing + ssachangemap = fill(0, nstmts) + end + if labelchangemap === nothing + labelchangemap = coverage ? fill(0, nstmts) : ssachangemap + end if oldidx < length(ssachangemap) ssachangemap[oldidx + 1] += 1 coverage && (labelchangemap[oldidx + 1] += 1) @@ -641,7 +656,11 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) oldidx += 1 end - renumber_ir_elements!(code, ssachangemap, labelchangemap) + cfg = sv.cfg + if ssachangemap !== nothing && labelchangemap !== nothing + renumber_ir_elements!(code, ssachangemap, labelchangemap) + cfg = nothing # recompute CFG + end for i = 1:length(code) code[i] = process_meta!(meta, code[i]) @@ -649,7 +668,9 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState) strip_trailing_junk!(ci, code, stmtinfo) types = Any[] stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags) - cfg = compute_basic_blocks(code) + if cfg === nothing + cfg = compute_basic_blocks(code) + end return IRCode(stmts, cfg, linetable, sv.slottypes, meta, sv.sptypes) end diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index 7759d8d80b9cc..6c17bbc7868f2 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -8,14 +8,6 @@ else end end -function argextype end # imported by EscapeAnalysis -function stmt_effect_free end # imported by EscapeAnalysis -function alloc_array_ndims end # imported by EscapeAnalysis -function try_compute_field end # imported by EscapeAnalysis - -include("compiler/ssair/basicblock.jl") -include("compiler/ssair/domtree.jl") -include("compiler/ssair/ir.jl") include("compiler/ssair/slot2ssa.jl") include("compiler/ssair/inlining.jl") include("compiler/ssair/verify.jl") diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 548c19eb031e7..bc38e61fac630 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -704,7 +704,7 @@ function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAV elseif xinfo !== nothing return !xinfo.attach_after else - return yinfo.attach_after + return (yinfo::NewNodeInfo).attach_after end end return x′.id < y′.id diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 9d36e52fb9f86..20b276b5f3f3e 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1413,7 +1413,7 @@ function type_lift_pass!(ir::IRCode) end else while isa(node, PiNode) - id = node.val.id + id = (node.val::SSAValue).id node = insts[id][:inst] end if isa(node, Union{PhiNode, PhiCNode}) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 788f465a9d49d..903a3c5e871f1 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -1827,7 +1827,8 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt) effect_free = true elseif f === getglobal && length(argtypes) >= 3 nothrow = getglobal_nothrow(argtypes[2:end]) - ipo_consistent = nothrow && isconst((argtypes[2]::Const).val, (argtypes[3]::Const).val) + ipo_consistent = nothrow && isconst( # types are already checked in `getglobal_nothrow` + (argtypes[2]::Const).val::Module, (argtypes[3]::Const).val::Symbol) effect_free = true else ipo_consistent = contains_is(_CONSISTENT_BUILTINS, f) diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index fb4a732692833..97e8a0cfa1d29 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -485,7 +485,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter) limited_ret = me.bestguess isa LimitedAccuracy limited_src = false if !limited_ret - gt = me.src.ssavaluetypes::Vector{Any} + gt = me.ssavaluetypes for j = 1:length(gt) gt[j] = gtj = cycle_fix_limited(gt[j], me) if gtj isa LimitedAccuracy && me.parent !== nothing @@ -510,9 +510,9 @@ function finish(me::InferenceState, interp::AbstractInterpreter) # annotate fulltree with type information, # either because we are the outermost code, or we might use this later doopt = (me.cached || me.parent !== nothing) - type_annotate!(me, doopt) + recompute_cfg = type_annotate!(me, doopt) if doopt && may_optimize(interp) - me.result.src = OptimizationState(me, OptimizationParams(interp), interp) + me.result.src = OptimizationState(me, OptimizationParams(interp), interp, recompute_cfg) else me.result.src = me.src::CodeInfo # stash a convenience copy of the code (e.g. for reflection) end @@ -568,31 +568,22 @@ function widen_all_consts!(src::CodeInfo) return src end -function widen_ssavaluetypes!(sv::InferenceState) - ssavaluetypes = sv.src.ssavaluetypes::Vector{Any} - for j = 1:length(ssavaluetypes) - t = ssavaluetypes[j] - ssavaluetypes[j] = t === NOT_FOUND ? Bottom : widenconditional(t) - end - return nothing -end - function record_slot_assign!(sv::InferenceState) # look at all assignments to slots # and union the set of types stored there # to compute a lower bound on the storage required - states = sv.stmt_types body = sv.src.code::Vector{Any} slottypes = sv.slottypes::Vector{Any} - ssavaluetypes = sv.src.ssavaluetypes::Vector{Any} + ssavaluetypes = sv.ssavaluetypes for i = 1:length(body) expr = body[i] - st_i = states[i] # find all reachable assignments to locals - if isa(st_i, VarTable) && isexpr(expr, :(=)) + if was_reached(sv, i) && isexpr(expr, :(=)) lhs = expr.args[1] if isa(lhs, SlotNumber) - vt = widenconst(ssavaluetypes[i]) + typ = ssavaluetypes[i] + @assert typ !== NOT_FOUND "active slot in unreached region" + vt = widenconst(typ) if vt !== Bottom id = slot_id(lhs) otherTy = slottypes[id] @@ -618,17 +609,21 @@ function record_bestguess!(sv::InferenceState) return nothing end -function annotate_slot_load!(undefs::Vector{Bool}, vtypes::VarTable, sv::InferenceState, - @nospecialize x) +function annotate_slot_load!(undefs::Vector{Bool}, idx::Int, sv::InferenceState, @nospecialize x) if isa(x, SlotNumber) id = slot_id(x) - vt = vtypes[id] - if vt.undef - # mark used-undef variables - undefs[id] = true + pc = find_dominating_assignment(id, idx, sv) + if pc === nothing + block = block_for_inst(sv.cfg, idx) + state = sv.bb_vartables[block]::VarTable + vt = state[id] + undefs[id] |= vt.undef + typ = widenconditional(ignorelimited(vt.typ)) + else + typ = sv.ssavaluetypes[pc] + @assert typ !== NOT_FOUND "active slot in unreached region" end # add type annotations where needed - typ = widenconditional(ignorelimited(vt.typ)) if !(sv.slottypes[id] ⊑ typ) return TypedSlot(id, typ) end @@ -643,21 +638,35 @@ function annotate_slot_load!(undefs::Vector{Bool}, vtypes::VarTable, sv::Inferen i0 = 2 end for i = i0:length(x.args) - x.args[i] = annotate_slot_load!(undefs, vtypes, sv, x.args[i]) + x.args[i] = annotate_slot_load!(undefs, idx, sv, x.args[i]) end return x elseif isa(x, ReturnNode) && isdefined(x, :val) - return ReturnNode(annotate_slot_load!(undefs, vtypes, sv, x.val)) + return ReturnNode(annotate_slot_load!(undefs, idx, sv, x.val)) elseif isa(x, GotoIfNot) - return GotoIfNot(annotate_slot_load!(undefs, vtypes, sv, x.cond), x.dest) + return GotoIfNot(annotate_slot_load!(undefs, idx, sv, x.cond), x.dest) end return x end +# find the dominating assignment to the slot `id` in the block containing statement `idx`, +# returns `nothing` otherwise +function find_dominating_assignment(id::Int, idx::Int, sv::InferenceState) + block = block_for_inst(sv.cfg, idx) + for pc in reverse(sv.cfg.blocks[block].stmts) # N.B. reverse since the last assignement is dominating this block + pc < idx || continue # N.B. needs pc ≠ idx as `id` can be assigned at `idx` + stmt = sv.src.code[pc] + isexpr(stmt, :(=)) || continue + lhs = stmt.args[1] + isa(lhs, SlotNumber) || continue + slot_id(lhs) == id || continue + return pc + end + return nothing +end + # annotate types of all symbols in AST function type_annotate!(sv::InferenceState, run_optimizer::Bool) - widen_ssavaluetypes!(sv) - # compute the required type for each slot # to hold all of the items assigned into it record_slot_assign!(sv) @@ -667,68 +676,55 @@ function type_annotate!(sv::InferenceState, run_optimizer::Bool) # annotate variables load types # remove dead code optimization # and compute which variables may be used undef - states = sv.stmt_types stmt_info = sv.stmt_info src = sv.src - body = src.code::Vector{Any} + body = src.code nexpr = length(body) codelocs = src.codelocs - ssavaluetypes = src.ssavaluetypes + ssavaluetypes = sv.ssavaluetypes ssaflags = src.ssaflags slotflags = src.slotflags nslots = length(slotflags) undefs = fill(false, nslots) - # eliminate GotoIfNot if either of branch target is unreachable - if run_optimizer - for idx = 1:nexpr - stmt = body[idx] - if isa(stmt, GotoIfNot) && widenconst(argextype(stmt.cond, src, sv.sptypes)) === Bool - # replace live GotoIfNot with: - # - GotoNode if the fallthrough target is unreachable - # - no-op if the branch target is unreachable - if states[idx+1] === nothing - body[idx] = GotoNode(stmt.dest) - elseif states[stmt.dest] === nothing - body[idx] = nothing + # this statement traversal does five things: + # 1. introduce temporary `TypedSlot`s that are supposed to be replaced with π-nodes later + # 2. mark used-undef slots (required by the `slot2reg` conversion) + # 3. mark unreached statements for a bulk code deletion (see issue #7836) + # 4. widen `Conditional`s and remove `NOT_FOUND` from `ssavaluetypes` + # NOTE because of this, `was_reached` will no longer be available after this point + # 5. eliminate GotoIfNot if either branch target is unreachable + changemap = nothing # initialized if there is any dead region + for i = 1:nexpr + expr = body[i] + if was_reached(sv, i) + if run_optimizer + if isa(expr, GotoIfNot) && widenconst(argextype(expr.cond, src, sv.sptypes)) === Bool + # 5: replace this live GotoIfNot with: + # - GotoNode if the fallthrough target is unreachable + # - no-op if the branch target is unreachable + if !was_reached(sv, i+1) + expr = GotoNode(expr.dest) + elseif !was_reached(sv, expr.dest) + expr = nothing + end end end - end - end - - # dead code elimination for unreachable regions - i = 1 - oldidx = 0 - changemap = fill(0, nexpr) - while i <= nexpr - oldidx += 1 - st_i = states[i] - expr = body[i] - if isa(st_i, VarTable) - # introduce temporary TypedSlot for the later optimization passes - # and also mark used-undef slots - body[i] = annotate_slot_load!(undefs, st_i, sv, expr) - else # unreached statement (see issue #7836) - if is_meta_expr(expr) - # keep any lexically scoped expressions + body[i] = annotate_slot_load!(undefs, i, sv, expr) # 1&2 + ssavaluetypes[i] = widenconditional(ssavaluetypes[i]) # 4 + else # i.e. any runtime execution will never reach this statement + if is_meta_expr(expr) # keep any lexically scoped expressions + ssavaluetypes[i] = Any # 4 elseif run_optimizer - deleteat!(body, i) - deleteat!(states, i) - deleteat!(ssavaluetypes, i) - deleteat!(codelocs, i) - deleteat!(stmt_info, i) - deleteat!(ssaflags, i) - nexpr -= 1 - changemap[oldidx] = -1 - continue + if changemap === nothing + changemap = fill(0, nexpr) + end + changemap[i] = -1 # 3&4: mark for the bulk deletion else + ssavaluetypes[i] = Bottom # 4 body[i] = Const(expr) # annotate that this statement actually is dead end end - i += 1 - end - if run_optimizer - renumber_ir_elements!(body, changemap) end # finish marking used-undef variables @@ -737,7 +733,20 @@ function type_annotate!(sv::InferenceState, run_optimizer::Bool) slotflags[j] |= SLOT_USEDUNDEF | SLOT_STATICUNDEF end end - nothing + + # do the bulk deletion of unreached statements + if changemap !== nothing + inds = Int[i for (i,v) in enumerate(changemap) if v == -1] + deleteat!(body, inds) + deleteat!(ssavaluetypes, inds) + deleteat!(codelocs, inds) + deleteat!(stmt_info, inds) + deleteat!(ssaflags, inds) + renumber_ir_elements!(body, changemap) + return true + else + return false + end end # at the end, all items in b's cycle diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 235a52fac168a..e9be7db755d48 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -376,29 +376,18 @@ widenwrappedconditional(typ::LimitedAccuracy) = LimitedAccuracy(widenconditional ignorelimited(@nospecialize typ) = typ ignorelimited(typ::LimitedAccuracy) = typ.typ -function stupdate!(state::Nothing, changes::StateUpdate) - newst = copy(changes.state) - changeid = slot_id(changes.var) - newst[changeid] = changes.vtype - # remove any Conditional for this slot from the vtable - # (unless this change is came from the conditional) - if !changes.conditional - for i = 1:length(newst) - newtype = newst[i] - if isa(newtype, VarState) - newtypetyp = ignorelimited(newtype.typ) - if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid - newtypetyp = widenwrappedconditional(newtype.typ) - newst[i] = VarState(newtypetyp, newtype.undef) - end - end - end +# remove any Conditional for this slot from the vartable +function invalidate_conditional(vt::VarState, changeid::Int) + newtyp = ignorelimited(vt.typ) + if isa(newtyp, Conditional) && slot_id(newtyp.var) == changeid + newtyp = widenwrappedconditional(vt.typ) + return VarState(newtyp, vt.undef) end - return newst + return nothing end function stupdate!(state::VarTable, changes::StateUpdate) - newstate = nothing + changed = false changeid = slot_id(changes.var) for i = 1:length(state) if i == changeid @@ -406,57 +395,41 @@ function stupdate!(state::VarTable, changes::StateUpdate) else newtype = changes.state[i] end - oldtype = state[i] - # remove any Conditional for this slot from the vtable - # (unless this change is came from the conditional) - if !changes.conditional && isa(newtype, VarState) - newtypetyp = ignorelimited(newtype.typ) - if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid - newtypetyp = widenwrappedconditional(newtype.typ) - newtype = VarState(newtypetyp, newtype.undef) + if !changes.conditional + invalidated = invalidate_conditional(newtype, changeid) + if invalidated !== nothing + newtype = invalidated end end + oldtype = state[i] if schanged(newtype, oldtype) - newstate = state state[i] = smerge(oldtype, newtype) + changed = true end end - return newstate + return changed end function stupdate!(state::VarTable, changes::VarTable) - newstate = nothing + changed = false for i = 1:length(state) newtype = changes[i] oldtype = state[i] if schanged(newtype, oldtype) - newstate = state state[i] = smerge(oldtype, newtype) + changed = true end end - return newstate + return changed end -stupdate!(state::Nothing, changes::VarTable) = copy(changes) - -stupdate!(state::Nothing, changes::Nothing) = nothing - function stupdate1!(state::VarTable, change::StateUpdate) changeid = slot_id(change.var) - # remove any Conditional for this slot from the catch block vtable - # (unless this change is came from the conditional) if !change.conditional for i = 1:length(state) - oldtype = state[i] - if isa(oldtype, VarState) - oldtypetyp = ignorelimited(oldtype.typ) - if isa(oldtypetyp, Conditional) && slot_id(oldtypetyp.var) == changeid - oldtypetyp = widenconditional(oldtypetyp) - if oldtype.typ isa LimitedAccuracy - oldtypetyp = LimitedAccuracy(oldtypetyp, (oldtype.typ::LimitedAccuracy).causes) - end - state[i] = VarState(oldtypetyp, oldtype.undef) - end + invalidated = invalidate_conditional(state[i], changeid) + if invalidated !== nothing + state[i] = invalidated end end end @@ -469,3 +442,26 @@ function stupdate1!(state::VarTable, change::StateUpdate) end return false end + +function stoverwrite!(state::VarTable, newstate::VarTable) + for i = 1:length(state) + state[i] = newstate[i] + end + return state +end + +function stoverwrite1!(state::VarTable, change::StateUpdate) + changeid = slot_id(change.var) + if !change.conditional + for i = 1:length(state) + invalidated = invalidate_conditional(state[i], changeid) + if invalidated !== nothing + state[i] = invalidated + end + end + end + # and update the type of it + newtype = change.vtype + state[changeid] = newtype + return state +end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index cedca856a9561..c8d3be16a0d2f 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1940,19 +1940,22 @@ function foo25261() next = f25261(Core.getfield(next, 2)) end end -opt25261 = code_typed(foo25261, Tuple{}, optimize=false)[1].first.code -i = 1 -# Skip to after the branch -while !isa(opt25261[i], GotoIfNot); global i += 1; end -foundslot = false -for expr25261 in opt25261[i:end] - if expr25261 isa TypedSlot && expr25261.typ === Tuple{Int, Int} - # This should be the assignment to the SSAValue into the getfield - # call - make sure it's a TypedSlot - global foundslot = true +let opt25261 = code_typed(foo25261, Tuple{}, optimize=false)[1].first.code + i = 1 + # Skip to after the branch + while !isa(opt25261[i], GotoIfNot) + i += 1 + end + foundslot = false + for expr25261 in opt25261[i:end] + if expr25261 isa TypedSlot && expr25261.typ === Tuple{Int, Int} + # This should be the assignment to the SSAValue into the getfield + # call - make sure it's a TypedSlot + foundslot = true + end end + @test foundslot end -@test foundslot @testset "inter-procedural conditional constraint propagation" begin # simple cases @@ -4134,6 +4137,11 @@ end |> !Core.Compiler.is_concrete_eval_eligible entry_to_be_invalidated('a') end +# control flow backedge should taint `terminates` +@test Base.infer_effects((Int,)) do n + for i = 1:n; end +end |> !Core.Compiler.is_terminates + # Nothrow for assignment to globals global glob_assign_int::Int = 0 f_glob_assign_int() = global glob_assign_int += 1 @@ -4188,3 +4196,13 @@ let effects = Base.infer_effects(f_setfield_nothrow, ()) #@test Core.Compiler.is_effect_free(effects) @test Core.Compiler.is_nothrow(effects) end + +# check the inference convergence with an empty vartable: +# the inference state for the toplevel chunk below will have an empty vartable, +# and so we may fail to terminate (or optimize) it if we don't update vartables correctly +let # NOTE make sure this toplevel chunk doesn't contain any local binding + Base.Experimental.@force_compile + global xcond::Bool = false + while xcond end +end +@test !xcond