diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index f07757eafc6e1..6224ed769e5e1 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -1379,9 +1379,7 @@ function inline_const_if_inlineable!(inst::Instruction) end function assemble_inline_todo!(ir::IRCode, state::InliningState) - # todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie) todo = Pair{Int, Any}[] - et = state.et for idx in 1:length(ir.stmts) simpleres = process_simple!(ir, idx, state, todo) @@ -1586,6 +1584,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, end end end + isa(val, Union{SSAValue, NewSSAValue}) && return val # avoid infinite loop urs = userefs(val) for op in urs op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 2f1359e4002ae..770e2aa294db2 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -381,6 +381,9 @@ struct UndefToken end; const UNDEF_TOKEN = UndefToken() isdefined(stmt, :val) || return OOB_TOKEN op == 1 || return OOB_TOKEN return stmt.val + elseif isa(stmt, Union{SSAValue, NewSSAValue}) + op == 1 || return OOB_TOKEN + return stmt elseif isa(stmt, UpsilonNode) isdefined(stmt, :val) || return OOB_TOKEN op == 1 || return OOB_TOKEN @@ -430,6 +433,9 @@ end elseif isa(stmt, ReturnNode) op == 1 || throw(BoundsError()) stmt = typeof(stmt)(v) + elseif isa(stmt, Union{SSAValue, NewSSAValue}) + op == 1 || throw(BoundsError()) + stmt = v elseif isa(stmt, UpsilonNode) op == 1 || throw(BoundsError()) stmt = typeof(stmt)(v) @@ -457,7 +463,7 @@ end function userefs(@nospecialize(x)) relevant = (isa(x, Expr) && is_relevant_expr(x)) || - isa(x, GotoIfNot) || isa(x, ReturnNode) || + isa(x, GotoIfNot) || isa(x, ReturnNode) || isa(x, SSAValue) || isa(x, NewSSAValue) || isa(x, PiNode) || isa(x, PhiNode) || isa(x, PhiCNode) || isa(x, UpsilonNode) return UseRefIterator(x, relevant) end @@ -480,50 +486,10 @@ end # This function is used from the show code, which may have a different # `push!`/`used` type since it's in Base. -function scan_ssa_use!(push!, used, @nospecialize(stmt)) - if isa(stmt, SSAValue) - push!(used, stmt.id) - end - for useref in userefs(stmt) - val = useref[] - if isa(val, SSAValue) - push!(used, val.id) - end - end -end +scan_ssa_use!(push!, used, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt) # Manually specialized copy of the above with push! === Compiler.push! -function scan_ssa_use!(used::IdSet, @nospecialize(stmt)) - if isa(stmt, SSAValue) - push!(used, stmt.id) - end - for useref in userefs(stmt) - val = useref[] - if isa(val, SSAValue) - push!(used, val.id) - end - end -end - -function ssamap(f, @nospecialize(stmt)) - urs = userefs(stmt) - for op in urs - val = op[] - if isa(val, SSAValue) - op[] = f(val) - end - end - return urs[] -end - -function foreachssa(f, @nospecialize(stmt)) - for op in userefs(stmt) - val = op[] - if isa(val, SSAValue) - f(val) - end - end -end +scan_ssa_use!(used::IdSet, @nospecialize(stmt)) = foreachssa(ssa -> push!(used, ssa.id), stmt) function insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after::Bool=false) node = add!(ir.new_nodes, pos, attach_after) @@ -751,20 +717,13 @@ end function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) needs_late_fixup = false - if isa(v, SSAValue) - compact.used_ssas[v.id] += 1 - elseif isa(v, NewSSAValue) - compact.new_new_used_ssas[v.id] += 1 - needs_late_fixup = true - else - for ops in userefs(v) - val = ops[] - if isa(val, SSAValue) - compact.used_ssas[val.id] += 1 - elseif isa(val, NewSSAValue) - compact.new_new_used_ssas[val.id] += 1 - needs_late_fixup = true - end + for ops in userefs(v) + val = ops[] + if isa(val, SSAValue) + compact.used_ssas[val.id] += 1 + elseif isa(val, NewSSAValue) + compact.new_new_used_ssas[val.id] += 1 + needs_late_fixup = true end end return needs_late_fixup @@ -931,6 +890,27 @@ function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int) return compact end +__set_check_ssa_counts(onoff::Bool) = __check_ssa_counts__[] = onoff +const __check_ssa_counts__ = fill(false) + +function _oracle_check(compact::IncrementalCompact) + observed_used_ssas = Core.Compiler.find_ssavalue_uses1(compact) + for i = 1:length(observed_used_ssas) + if observed_used_ssas[i] != compact.used_ssas[i] + return observed_used_ssas + end + end + return nothing +end + +function oracle_check(compact::IncrementalCompact) + maybe_oracle_used_ssas = _oracle_check(compact) + if maybe_oracle_used_ssas !== nothing + @eval Main (compact = $compact; oracle_used_ssas = $maybe_oracle_used_ssas) + error("Oracle check failed, inspect Main.compact and Main.oracle_used_ssas") + end +end + getindex(view::TypesView, idx::SSAValue) = getindex(view, idx.id) function getindex(view::TypesView, idx::Int) if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx @@ -1425,7 +1405,6 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}= # result_idx is not, incremented, but that's ok and expected compact.result[old_result_idx] = compact.ir.stmts[idx] result_idx = process_node!(compact, old_result_idx, compact.ir.stmts[idx], idx, idx, active_bb, true) - stmt_if_any = old_result_idx == result_idx ? nothing : compact.result[old_result_idx][:inst] compact.result_idx = result_idx if idx == last(bb.stmts) && !attach_after_stmt_after(compact, idx) finish_current_bb!(compact, active_bb, old_result_idx) @@ -1464,11 +1443,7 @@ function maybe_erase_unused!( callback(val) end if effect_free - if isa(stmt, SSAValue) - kill_ssa_value(stmt) - else - foreachssa(kill_ssa_value, stmt) - end + foreachssa(kill_ssa_value, stmt) inst[:inst] = nothing return true end @@ -1570,6 +1545,9 @@ end function complete(compact::IncrementalCompact) result_bbs = resize!(compact.result_bbs, compact.active_result_bb-1) cfg = CFG(result_bbs, Int[first(result_bbs[i].stmts) for i in 2:length(result_bbs)]) + if __check_ssa_counts__[] + oracle_check(compact) + end return IRCode(compact.ir, compact.result, cfg, compact.new_new_nodes) end diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index c2597363df282..3937141f0aa5e 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1151,15 +1151,6 @@ function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact end end -function count_uses(@nospecialize(stmt), uses::Vector{Int}) - for ur in userefs(stmt) - use = ur[] - if isa(use, SSAValue) - uses[use.id] += 1 - end - end -end - function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::SPCSet, phi::Int) worklist = Int[] push!(worklist, phi) diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index a5dd6a0fd8f29..1d6219448cf9c 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -72,9 +72,6 @@ function make_ssa!(ci::CodeInfo, code::Vector{Any}, idx, slot, @nospecialize(typ end function new_to_regular(@nospecialize(stmt), new_offset::Int) - if isa(stmt, NewSSAValue) - return SSAValue(stmt.id + new_offset) - end urs = userefs(stmt) for op in urs val = op[] diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index fe97b81c07e24..3a2dc6e00f7a3 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -228,6 +228,27 @@ end # SSAValues/Slots # ################### +function ssamap(f, @nospecialize(stmt)) + urs = userefs(stmt) + for op in urs + val = op[] + if isa(val, SSAValue) + op[] = f(val) + end + end + return urs[] +end + +function foreachssa(f, @nospecialize(stmt)) + urs = userefs(stmt) + for op in urs + val = op[] + if isa(val, SSAValue) + f(val) + end + end +end + function find_ssavalue_uses(body::Vector{Any}, nvals::Int) uses = BitSet[ BitSet() for i = 1:nvals ] for line in 1:length(body) @@ -333,6 +354,38 @@ end @inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id : isa(s, Argument) ? (s::Argument).n : (s::TypedSlot).id +###################### +# IncrementalCompact # +###################### + +# specifically meant to be used with body1 = compact.result and body2 = compact.new_new_nodes, with nvals == length(compact.used_ssas) +function find_ssavalue_uses1(compact) + body1, body2 = compact.result.inst, compact.new_new_nodes.stmts.inst + nvals = length(compact.used_ssas) + nbody1 = length(body1) + nbody2 = length(body2) + + uses = zeros(Int, nvals) + function increment_uses(ssa::SSAValue) + uses[ssa.id] += 1 + end + + for line in 1:(nbody1 + nbody2) + # index into the right body + if line <= nbody1 + isassigned(body1, line) || continue + e = body1[line] + else + line -= nbody1 + isassigned(body2, line) || continue + e = body2[line] + end + + foreachssa(increment_uses, e) + end + return uses +end + ########### # options # ########### diff --git a/test/compiler/ssair.jl b/test/compiler/ssair.jl index f1bd442e7f093..f74b5b80d3e35 100644 --- a/test/compiler/ssair.jl +++ b/test/compiler/ssair.jl @@ -3,7 +3,7 @@ using Base.Meta using Core.IR const Compiler = Core.Compiler -using .Compiler: CFG, BasicBlock +using .Compiler: CFG, BasicBlock, NewSSAValue make_bb(preds, succs) = BasicBlock(Compiler.StmtRange(0, 0), preds, succs) @@ -334,3 +334,66 @@ f_if_typecheck() = (if nothing; end; unsafe_load(Ptr{Int}(0))) stderr = IOBuffer() success(pipeline(Cmd(cmd); stdout=stdout, stderr=stderr)) && isempty(String(take!(stderr))) end + +let + function test_useref(stmt, v, op) + if isa(stmt, Expr) + @test stmt.args[op] === v + elseif isa(stmt, GotoIfNot) + @test stmt.cond === v + elseif isa(stmt, ReturnNode) || isa(stmt, UpsilonNode) + @test stmt.val === v + elseif isa(stmt, SSAValue) || isa(stmt, NewSSAValue) + @test stmt === v + elseif isa(stmt, PiNode) + @test stmt.val === v && stmt.typ === typeof(stmt) + elseif isa(stmt, PhiNode) || isa(stmt, PhiCNode) + @test stmt.values[op] === v + end + end + + function _test_userefs(@nospecialize stmt) + ex = Expr(:call, :+, Core.SSAValue(3), 1) + urs = Core.Compiler.userefs(stmt)::Core.Compiler.UseRefIterator + it = Core.Compiler.iterate(urs) + while it !== nothing + ur = getfield(it, 1)::Core.Compiler.UseRef + op = getfield(it, 2)::Int + v1 = Core.Compiler.getindex(ur) + # set to dummy expression and then back to itself to test `_useref_setindex!` + v2 = Core.Compiler.setindex!(ur, ex) + test_useref(v2, ex, op) + Core.Compiler.setindex!(ur, v1) + @test Core.Compiler.getindex(ur) === v1 + it = Core.Compiler.iterate(urs, op) + end + end + + function test_userefs(body) + for stmt in body + _test_userefs(stmt) + end + end + + # this isn't valid code, we just care about looking at a variety of IR nodes + body = Any[ + Expr(:enter, 11), + Expr(:call, :+, SSAValue(3), 1), + Expr(:throw_undef_if_not, :expected, false), + Expr(:leave, 1), + Expr(:(=), SSAValue(1), Expr(:call, :+, SSAValue(3), 1)), + UpsilonNode(), + UpsilonNode(SSAValue(2)), + PhiCNode(Any[SSAValue(5), SSAValue(7), SSAValue(9)]), + PhiCNode(Any[SSAValue(6)]), + PhiNode(Int32[8], Any[SSAValue(7)]), + PiNode(SSAValue(6), GotoNode), + GotoIfNot(SSAValue(3), 10), + GotoNode(5), + SSAValue(7), + NewSSAValue(9), + ReturnNode(SSAValue(11)), + ] + + test_userefs(body) +end