From 851ff9edb20faab477f948ea99a6bca8bec20ff6 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Tue, 9 May 2023 06:49:53 +0000 Subject: [PATCH] irinterp: Consider cfg information from discovered errors If we infer a call to `Union{}`, we can terminate further abstract interpretation. However, this of course also means that we can make use of that information to refine the types of any phis that may have originated from the basic block containing the call that was refined to `Union{}`. --- base/compiler/ssair/irinterp.jl | 106 +++++++++++++++++++++----------- test/compiler/inference.jl | 33 ++++++++++ 2 files changed, 103 insertions(+), 36 deletions(-) diff --git a/base/compiler/ssair/irinterp.jl b/base/compiler/ssair/irinterp.jl index d58d18b188757..ad6077fa48859 100644 --- a/base/compiler/ssair/irinterp.jl +++ b/base/compiler/ssair/irinterp.jl @@ -58,6 +58,49 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRIn return RTEffects(rt, effects) end +function update_phi!(irsv::IRInterpretationState, from::Int, to::Int) + ir = irsv.ir + if length(ir.cfg.blocks[to].preds) == 0 + # Kill the entire block + for bidx = ir.cfg.blocks[to].stmts + ir.stmts[bidx][:inst] = nothing + ir.stmts[bidx][:type] = Bottom + ir.stmts[bidx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW + end + return + end + for sidx = ir.cfg.blocks[to].stmts + sinst = ir.stmts[sidx][:inst] + isa(sinst, Nothing) && continue # allowed between `PhiNode`s + isa(sinst, PhiNode) || break + for (eidx, edge) in enumerate(sinst.edges) + if edge == from + deleteat!(sinst.edges, eidx) + deleteat!(sinst.values, eidx) + push!(irsv.ssa_refined, sidx) + break + end + end + end +end +update_phi!(irsv::IRInterpretationState) = (from::Int, to::Int)->update_phi!(irsv, from, to) + +function kill_terminator_edges!(irsv::IRInterpretationState, term_idx::Int, bb::Int=block_for_inst(irsv.ir, term_idx)) + ir = irsv.ir + inst = ir[SSAValue(term_idx)][:inst] + if isa(inst, GotoIfNot) + kill_edge!(ir, bb, inst.dest, update_phi!(irsv)) + kill_edge!(ir, bb, bb+1, update_phi!(irsv)) + elseif isa(inst, GotoNode) + kill_edge!(ir, bb, inst.label, update_phi!(irsv)) + elseif isa(inst, ReturnNode) + # Nothing to do + else + @assert !isexpr(inst, :enter) + kill_edge!(ir, bb, bb+1, update_phi!(irsv)) + end +end + function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union{Int,Nothing}, @nospecialize(inst), @nospecialize(typ), irsv::IRInterpretationState, extra_reprocess::Union{Nothing,BitSet,BitSetBoundedMinPrioritySet}) @@ -66,30 +109,6 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union cond = inst.cond condval = maybe_extract_const_bool(argextype(cond, ir)) if condval isa Bool - function update_phi!(from::Int, to::Int) - if length(ir.cfg.blocks[to].preds) == 0 - # Kill the entire block - for bidx = ir.cfg.blocks[to].stmts - ir.stmts[bidx][:inst] = nothing - ir.stmts[bidx][:type] = Bottom - ir.stmts[bidx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW - end - return - end - for sidx = ir.cfg.blocks[to].stmts - sinst = ir.stmts[sidx][:inst] - isa(sinst, Nothing) && continue # allowed between `PhiNode`s - isa(sinst, PhiNode) || break - for (eidx, edge) in enumerate(sinst.edges) - if edge == from - deleteat!(sinst.edges, eidx) - deleteat!(sinst.values, eidx) - push!(irsv.ssa_refined, sidx) - break - end - end - end - end if isa(cond, SSAValue) kill_def_use!(irsv.tpdum, cond, idx) end @@ -100,10 +119,10 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union if condval ir.stmts[idx][:inst] = nothing ir.stmts[idx][:type] = Any - kill_edge!(ir, bb, inst.dest, update_phi!) + kill_edge!(ir, bb, inst.dest, update_phi!(irsv)) else ir.stmts[idx][:inst] = GotoNode(inst.dest) - kill_edge!(ir, bb, bb+1, update_phi!) + kill_edge!(ir, bb, bb+1, update_phi!(irsv)) end return true end @@ -123,9 +142,6 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union rt, nothrow = concrete_eval_invoke(interp, inst, inst.args[1]::MethodInstance, irsv) if nothrow ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW - if isa(rt, Const) && is_inlineable_constant(rt.val) - ir.stmts[idx][:inst] = quoted(rt.val) - end end elseif head === :throw_undef_if_not || # TODO: Terminate interpretation early if known false? head === :gc_preserve_begin || @@ -148,9 +164,17 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union else error("reprocess_instruction!: unhandled instruction found") end - if rt !== nothing && !⊑(typeinf_lattice(interp), typ, rt) - ir.stmts[idx][:type] = rt - return true + if rt !== nothing + if isa(rt, Const) + ir.stmts[idx][:type] = rt + if is_inlineable_constant(rt.val) + ir.stmts[idx][:inst] = quoted(rt.val) + end + return true + elseif !⊑(typeinf_lattice(interp), typ, rt) + ir.stmts[idx][:type] = rt + return true + end end return false end @@ -227,12 +251,22 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR any_refined = true delete!(ssa_refined, idx) end - if any_refined && reprocess_instruction!(interp, - idx, bb, inst, typ, irsv, extra_reprocess) - push!(ssa_refined, idx) + did_reprocess = false + if any_refined + did_reprocess = reprocess_instruction!(interp, + idx, bb, inst, typ, irsv, extra_reprocess) + if did_reprocess + push!(ssa_refined, idx) + inst = ir.stmts[idx][:inst] + typ = ir.stmts[idx][:type] + end + end + if idx == lstmt + process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) && @goto residual_scan + (isa(inst, GotoNode) || isa(inst, GotoIfNot) || isa(inst, ReturnNode) || isexpr(inst, :enter)) && continue end - idx == lstmt && process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) && @goto residual_scan if typ === Bottom && !isa(inst, PhiNode) + kill_terminator_edges!(irsv, lstmt, bb) break end end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 1b137d1d8f661..5987e10401bc8 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4871,3 +4871,36 @@ function nt_splat_partial(x::Int) Val{tuple(nt...)[2]}() end @test @inferred(nt_splat_partial(42)) == Val{2}() + +# Test that irinterp refines based on discovered errors +Base.@assume_effects :foldable Base.@constprop :aggressive function kill_error_edge(b1, b2, xs, x) + y = b1 ? "julia" : xs[] + if b2 + a = length(y) + else + a = sin(y) + end + a + x +end + +Base.@assume_effects :foldable Base.@constprop :aggressive function kill_error_edge(b1, b2, xs, ys, x) + y = b1 ? xs[] : ys[] + if b2 + a = length(y) + else + a = sin(y) + end + a + x +end + +let src = code_typed1((Bool,Base.RefValue{Any},Int,)) do b2, xs, x + kill_error_edge(true, b2, xs, x) + end + @test count(@nospecialize(x)->isa(x, Core.PhiNode), src.code) == 0 +end + +let src = code_typed1((Bool,Base.RefValue{String}, Base.RefValue{Any},Int,)) do b2, xs, ys, x + kill_error_edge(true, b2, xs, ys, x) + end + @test count(@nospecialize(x)->isa(x, Core.PhiNode), src.code) == 0 +end