From 759684883346ab0e7d2ff15676df3517d3dc30e1 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:35:07 +0900 Subject: [PATCH] implement a better statement selection logic (#654) Specifically, this commit aims to review the implementation of `add_control_flow!` and improves its accuracy. Ideally, it should pass JET's existing test cases as well as the newly added ones, including the test cases from JuliaDebug/LoweredCodeUtils.jl#99. The goal is to share the same high-precision CFG selection logic between LoweredCodeUtils and JET. The new algorithm is based on what was proposed in [^Wei84]. If there is even one active block in the blocks reachable from a conditional branch up to its successors' nearest common post-dominator (referred to as **INFL** in the paper), it is necessary to follow that conditional branch and execute the code. Otherwise, execution can be short-circuited from the conditional branch to the nearest common post-dominator. COMBAK: It is important to note that in Julia's IR (`CodeInfo`), "short-circuiting" a specific code region is not a simple task. Simply ignoring the path to the post-dominator does not guarantee fall-through to the post-dominator. Therefore, a more careful implementation is required for this aspect. [Wei84]: M. Weiser, "Program Slicing," IEEE Transactions on Software Engineering, 10, pages 352-357, July 1984. 
--- src/JET.jl | 17 +- src/toplevel/virtualprocess.jl | 234 ++++++++++++++++++++------- test/toplevel/test_virtualprocess.jl | 120 ++++++++++++-- 3 files changed, 292 insertions(+), 79 deletions(-) diff --git a/src/JET.jl b/src/JET.jl index 7677f7ef0..cf0d11da1 100644 --- a/src/JET.jl +++ b/src/JET.jl @@ -36,14 +36,15 @@ using Core: Builtin, IntrinsicFunction, Intrinsics, SimpleVector, svec using Core.IR using .CC: @nospecs, ⊑, - AbstractInterpreter, AbstractLattice, ArgInfo, Bottom, CFG, CachedMethodTable, CallMeta, - ConstCallInfo, InferenceParams, InferenceResult, InferenceState, InternalMethodTable, - InvokeCallInfo, MethodCallResult, MethodMatchInfo, MethodMatches, NOT_FOUND, - OptimizationState, OptimizationParams, OverlayMethodTable, StmtInfo, UnionSplitInfo, - UnionSplitMethodMatches, VarState, VarTable, WorldRange, WorldView, - argextype, argtype_by_index, argtypes_to_type, hasintersect, ignorelimited, - instanceof_tfunc, istopfunction, singleton_type, slot_id, specialize_method, - tmeet, tmerge, typeinf_lattice, widenconst, widenlattice + AbstractInterpreter, AbstractLattice, ArgInfo, BasicBlock, Bottom, CFG, CachedMethodTable, + CallMeta, ConstCallInfo, InferenceParams, InferenceResult, InferenceState, + InternalMethodTable, InvokeCallInfo, MethodCallResult, MethodMatchInfo, MethodMatches, + NOT_FOUND, OptimizationState, OptimizationParams, OverlayMethodTable, StmtInfo, + UnionSplitInfo, UnionSplitMethodMatches, VarState, VarTable, WorldRange, WorldView, + argextype, argtype_by_index, argtypes_to_type, compute_basic_blocks, construct_domtree, + construct_postdomtree, hasintersect, ignorelimited, instanceof_tfunc, istopfunction, + nearest_common_dominator, singleton_type, slot_id, specialize_method, tmeet, tmerge, + typeinf_lattice, widenconst, widenlattice using Base: IdSet, get_world_counter diff --git a/src/toplevel/virtualprocess.jl b/src/toplevel/virtualprocess.jl index 6def602c2..b49e9dee2 100644 --- a/src/toplevel/virtualprocess.jl +++ 
b/src/toplevel/virtualprocess.jl @@ -1091,16 +1091,44 @@ end # select statements that should be concretized, and actually interpreted rather than abstracted function select_statements(mod::Module, src::CodeInfo) - stmts = src.code cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`? edges = LoweredCodeUtils.CodeEdges(src, cl) - - concretize = falses(length(stmts)) - - select_direct_requirement!(concretize, stmts, edges) - + concretize = falses(length(src.code)) + select_direct_requirement!(concretize, src.code, edges) select_dependencies!(concretize, src, edges, cl) + return concretize +end +# just for testing, and debugging +function select_statements(mod::Module, src::CodeInfo, names::Symbol...) + cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`? + edges = LoweredCodeUtils.CodeEdges(src, cl) + concretize = falses(length(src.code)) + objs = Set{GlobalRef}(GlobalRef(mod, name) for name in names) + LoweredCodeUtils.add_requests!(concretize, objs, edges, ()) + select_dependencies!(concretize, src, edges, cl) + return concretize +end +function select_statements(mod::Module, src::CodeInfo, slots::SlotNumber...) + cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`? + edges = LoweredCodeUtils.CodeEdges(src, cl) + concretize = falses(length(src.code)) + for slot in slots + for d in cl.slotassigns[slot.id] + concretize[d] = true + end + end + select_dependencies!(concretize, src, edges, cl) + return concretize +end +function select_statements(mod::Module, src::CodeInfo, idxs::Int...) + cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`? 
+ edges = LoweredCodeUtils.CodeEdges(src, cl) + concretize = falses(length(src.code)) + for idx = idxs + concretize[idx] |= true + end + select_dependencies!(concretize, src, edges, cl) return concretize end @@ -1173,66 +1201,41 @@ end # The goal of this function is to request concretization of the minimal necessary control # flow to evaluate statements whose concretization have already been requested. -# The basic approach is to check if there are any active successors for each basic block, -# and if there is an active successor and the terminator is not a fall-through, then request -# the concretization of that terminator. Additionally, for conditional terminators, a simple -# optimization using post-domination analysis is also performed. -function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree) +# The basic algorithm is based on what was proposed in [^Wei84]. If there is even one active +# block in the blocks reachable from a conditional branch up to its successors' nearest +# common post-dominator (referred to as 𝑰𝑵𝑭𝑳 in the paper), it is necessary to follow +# that conditional branch and execute the code. Otherwise, execution can be short-circuited +# from the conditional branch to the nearest common post-dominator. +# +# COMBAK: It is important to note that in Julia's intermediate code representation (`CodeInfo`), +# "short-circuiting" a specific code region is not a simple task. Simply ignoring the path +# to the post-dominator does not guarantee fall-through to the post-dominator. Therefore, +# a more careful implementation is required for this aspect. +# +# [^Wei84]: M. Weiser, "Program Slicing," IEEE Transactions on Software Engineering, 10, pages 352-357, July 1984. 
+function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, domtree, postdomtree) local changed::Bool = false function mark_concretize!(idx::Int) if !concretize[idx] - concretize[idx] = true + changed |= concretize[idx] = true return true end return false end - nblocks = length(cfg.blocks) - for bbidx = 1:nblocks - bb = cfg.blocks[bbidx] # forward traversal + for bbidx = 1:length(cfg.blocks) # forward traversal + bb = cfg.blocks[bbidx] nsuccs = length(bb.succs) if nsuccs == 0 continue elseif nsuccs == 1 - terminator_idx = bb.stmts[end] - if src.code[terminator_idx] isa GotoNode - # If the destination of this `GotoNode` is not active, it's fine to ignore - # the control flow caused by this `GotoNode` and treat it as a fall-through. - # If the block that is fallen through to is active and has a dependency on - # this goto block, then the concretization of this goto block should already - # be requested (at some point of the higher concretization convergence cycle - # of `select_dependencies`), and thus, this `GotoNode` will be concretized. 
- if any(@view concretize[cfg.blocks[only(bb.succs)].stmts]) - changed |= mark_concretize!(terminator_idx) - end - end + continue # leave a fall-through terminator unmarked: `GotoNode`s are marked later elseif nsuccs == 2 - terminator_idx = bb.stmts[end] - @assert is_conditional_terminator(src.code[terminator_idx]) "invalid IR" - succ1, succ2 = bb.succs - succ1_req = any(@view concretize[cfg.blocks[succ1].stmts]) - succ2_req = any(@view concretize[cfg.blocks[succ2].stmts]) - if succ1_req - if succ2_req - changed |= mark_concretize!(terminator_idx) - else - active_bb, inactive_bb = succ1, succ2 - @goto asymmetric_case - end - elseif succ2_req - active_bb, inactive_bb = succ2, succ1 - @label asymmetric_case - # We can ignore the control flow of this conditional terminator and treat - # it as a fall-through if only one of its successors is active and the - # active block post-dominates the inactive one, since the post-domination - # ensures that the active basic block will be reached regardless of the - # control flow. - if CC.postdominates(postdomtree, active_bb, inactive_bb) - # fall through this block - else - changed |= mark_concretize!(terminator_idx) - end + termidx = bb.stmts[end] + @assert is_conditional_terminator(src.code[termidx]) "invalid IR" + if is_conditional_block_active(concretize, bb, cfg, postdomtree) + mark_concretize!(termidx) else - # both successors are inactive, just fall through this block + # fall-through to the post dominator block (by short-circuiting all statements between) end end end @@ -1242,6 +1245,46 @@ end is_conditional_terminator(@nospecialize stmt) = stmt isa GotoIfNot || (@static @isdefined(EnterNode) ? 
stmt isa EnterNode : isexpr(stmt, :enter)) +function is_conditional_block_active(concretize::BitVector, bb::BasicBlock, cfg::CFG, postdomtree) + return visit_𝑰𝑵𝑭𝑳_blocks(bb, cfg, postdomtree) do postdominator::Int, 𝑰𝑵𝑭𝑳::BitSet + for blk in 𝑰𝑵𝑭𝑳 + if blk == postdominator + continue # skip the post-dominator block and continue to a next infl block + end + if any(@view concretize[cfg.blocks[blk].stmts]) + return true + end + end + return false + end +end + +function visit_𝑰𝑵𝑭𝑳_blocks(func, bb::BasicBlock, cfg::CFG, postdomtree) + succ1, succ2 = bb.succs + postdominator = nearest_common_dominator(postdomtree, succ1, succ2) + inflblks = reachable_blocks(cfg, succ1, postdominator) ∪ reachable_blocks(cfg, succ2, postdominator) + return func(postdominator, inflblks) +end + +function reachable_blocks(cfg::CFG, from_bb::Int, to_bb::Int) + worklist = Int[from_bb] + visited = BitSet(from_bb) + if to_bb == from_bb + return visited + end + push!(visited, to_bb) + function visit!(bb::Int) + if bb ∉ visited + push!(visited, bb) + push!(worklist, bb) + end + end + while !isempty(worklist) + foreach(visit!, cfg.blocks[pop!(worklist)].succs) + end + return visited +end + function add_required_inplace!(concretize::BitVector, src::CodeInfo, edges, cl) changed = false for i = 1:length(src.code) @@ -1272,31 +1315,98 @@ function is_arg_requested(@nospecialize(arg), concretize, edges, cl) return false end +# The purpose of this function is to find other statements that affect the execution of the +# statements chosen by `select_direct_requirement!`. In other words, it extracts the +# minimal amount of code necessary to realize the required concretization. +# This technique is generally referred to as "program slicing," and specifically as +# "static program slicing". The basic algorithm implemented here is an extension of the one +# proposed in https://dl.acm.org/doi/10.5555/800078.802557, which is especially tuned for +# Julia's intermediate code representation. 
function select_dependencies!(concretize::BitVector, src::CodeInfo, edges, cl) typedefs = LoweredCodeUtils.find_typedefs(src) - cfg = CC.compute_basic_blocks(src.code) - postdomtree = CC.construct_postdomtree(cfg.blocks) + cfg = compute_basic_blocks(src.code) + domtree = construct_domtree(cfg.blocks) + postdomtree = construct_postdomtree(cfg.blocks) while true changed = false - # discover struct/method definitions at the beginning, - # and propagate the definition requirements by tracking SSA precedessors + # Discover struct/method definitions at the beginning, + # and propagate the definition requirements by tracking SSA predecessors. + # (TODO maybe hoist this out of the loop?) changed |= LoweredCodeUtils.add_typedefs!(concretize, src, edges, typedefs, ()) changed |= add_ssa_preds!(concretize, src, edges, ()) - # mark some common inplace operations like `push!(x, ...)` and `setindex!(x, ...)` - # when `x` has been marked already: otherwise we may end up using it with invalid state + # Mark some common inplace operations like `push!(x, ...)` and `setindex!(x, ...)` + # when `x` has been marked already: otherwise we may end up using it with invalid state. + # However, note that this is an incomplete approach, and note that the slice created + # by this routine will not be sound because of this. This is because + # `add_required_inplace!` only requires certain special-cased function calls and + # does not take into account the possibility that arguments may be mutated in + # arbitrary function calls. 
Ideally, function calls should be required while + # considering the effects of these statements, or by some sort of an + # inter-procedural program slicing changed |= add_required_inplace!(concretize, src, edges, cl) changed |= add_ssa_preds!(concretize, src, edges, ()) - # mark necessary control flows, - # and propagate the definition requirements by tracking SSA precedessors - changed |= add_control_flow!(concretize, src, cfg, postdomtree) + # Mark necessary control flows. + changed |= add_control_flow!(concretize, src, cfg, domtree, postdomtree) changed |= add_ssa_preds!(concretize, src, edges, ()) changed || break end + + # now mark the active goto nodes + add_active_gotos!(concretize, src, cfg, postdomtree) + + nothing +end + +function add_active_gotos!(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree) + dead_blocks = compute_dead_blocks(concretize, src, cfg, postdomtree) + changed = false + for bbidx = 1:length(cfg.blocks) + if bbidx ∉ dead_blocks + bb = cfg.blocks[bbidx] + nsuccs = length(bb.succs) + if nsuccs == 1 + termidx = bb.stmts[end] + if src.code[termidx] isa GotoNode + changed |= concretize[termidx] = true + end + end + end + end + return changed +end + +# find dead blocks using the same approach as `add_control_flow!`, for the converged `concretize` +function compute_dead_blocks(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree) + dead_blocks = BitSet() + for bbidx = 1:length(cfg.blocks) + bb = cfg.blocks[bbidx] + nsuccs = length(bb.succs) + if nsuccs == 2 + termidx = bb.stmts[end] + @assert is_conditional_terminator(src.code[termidx]) "invalid IR" + visit_𝑰𝑵𝑭𝑳_blocks(bb, cfg, postdomtree) do postdominator::Int, 𝑰𝑵𝑭𝑳::BitSet + is_active_inflblks = false + for blk in 𝑰𝑵𝑭𝑳 + if blk == postdominator + continue # skip the post-dominator block and continue to a next infl block + end + if any(@view concretize[cfg.blocks[blk].stmts]) + is_active_inflblks |= true + break + end + end + if !is_active_inflblks + union!(dead_blocks, 
delete!(𝑰𝑵𝑭𝑳, postdominator)) + end + end + end + end + return dead_blocks end function JuliaInterpreter.step_expr!(interp::ConcreteInterpreter, frame::Frame, @nospecialize(node), istoplevel::Bool) diff --git a/test/toplevel/test_virtualprocess.jl b/test/toplevel/test_virtualprocess.jl index 5e7a0144e..dabea6a87 100644 --- a/test/toplevel/test_virtualprocess.jl +++ b/test/toplevel/test_virtualprocess.jl @@ -1713,17 +1713,17 @@ end # this particular example is adapted from https://en.wikipedia.org/wiki/Program_slicing let src = @src let N = 10 - sum = 0 + s = 0 product = 1 # should NOT be selected w = 7 for i in 1:N - sum += i + w + s += i + w product *= i # should NOT be selected end - @eval global getsum() = $sum # concretization is forced write(product) # should NOT be selected end - slice = JET.select_statements(@__MODULE__, src) + slotid = findfirst(n::Symbol->n===:s, src.slotnames)::Int + slice = JET.select_statements(@__MODULE__, src, Core.SlotNumber(slotid)) found_N = found_sum = found_product = found_w = found_write = false for (i, stmt) in enumerate(src.code) @@ -1733,7 +1733,7 @@ end if src.slotnames[lhs.id] === :w found_w = true @test slice[i] - elseif src.slotnames[lhs.id] === :sum + elseif src.slotnames[lhs.id] === :s found_sum = true @test slice[i] elseif src.slotnames[lhs.id] === :N @@ -1748,7 +1748,7 @@ end found_write = true @test !slice[i] elseif (JET.isexpr(stmt, :call) && (arg1 = stmt.args[1]; arg1 isa Core.SSAValue) && - src.code[arg1.id] === :write) + src.code[arg1.id] === :write) found_write = true @test !slice[i] end @@ -1759,14 +1759,14 @@ end redirect_stdout(io) do vmod, res = @analyze_toplevel2 let N = 10 - sum = 0 + s = 0 product = 1 # should NOT be selected w = 7 for i in 1:N - sum += i + w + s += i + w product *= i # should NOT be selected end - @eval global getsum() = $sum # concretization is forced + @eval global getsum() = $s # concretization is forced println("This should not be printed: ", product) # should NOT be selected end 
@test isempty(res.res.toplevel_error_reports) @@ -1778,6 +1778,108 @@ end @test isempty(s) end + # A more complex test case (xref: https://github.com/JuliaDebug/LoweredCodeUtils.jl/pull/99#issuecomment-2236373067) + # This test case might seem simple at first glance, but note that `x2` and `a2` are + # defined at the top level (because of the `begin` at the top). + # Since global variable type declarations have been allowed since v1.10, + # a conditional branch that includes `Core.get_binding_type` is generated for + # these simple global variable assignments. + # Specifically, the code is lowered into something like this: + # 1 1: conditional branching based on `x2`'s binding type + # │╲ + # │ ╲ + # │ ╲ 2: goto block for the case when no conversion is required for the value of `x2` + # 2 3 3: fall-through block for the case when a conversion is required for the value of `x2` + # │ ╱ + # │ ╱ + # │╱ + # 4 4: assignment to `x2`, **and** + # │╲ conditional branching based on `a2`'s binding type + # │ ╲ + # │ ╲ 5: goto block for the case when no conversion is required for the value of `a2` + # 5 6 6: fall-through block for the case when a conversion is required for the value of `a2` + # │ ╱ + # │ ╱ + # │╱ + # 7 7: assignment to `a2` + # What's important to note here is that since there's an assignment to `a2`, + # concretization of the blocks 4-6 is necessary. However, at the same time we also want + # to skip concretizing the blocks 1-3. 
+ let src = @src begin + x2 = 5 + a2 = 1 + end + slice = JET.select_statements(@__MODULE__, src, :a2) + + found_a2 = found_a2_get_binding_type = found_x2 = found_x2_get_binding_type = false + for (i, stmt) in enumerate(src.code) + if JET.isexpr(stmt, :(=)) + lhs, rhs = stmt.args + if lhs isa GlobalRef + lhs = lhs.name + end + if lhs === :a2 + found_a2 = true + @test slice[i] + elseif lhs === :x2 + found_x2 = true + @test !slice[i] # this is easy to meet + end + elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :a2)) + found_a2_get_binding_type = true + @test slice[i] + elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :x2)) + found_x2_get_binding_type = true + @test !slice[i] # this is difficult to meet + end + end + @test found_a2; @test found_a2_get_binding_type; @test found_x2; @test found_x2_get_binding_type + end + let src = @src begin + cond = true + if cond + x = 1 + y = 1 + else + x = 2 + y = 2 + end + end + slice = JET.select_statements(@__MODULE__, src, :x) + + found_cond = found_cond_get_binding_type = false + found_x = found_x_get_binding_type = found_y = found_y_get_binding_type = 0 + for (i, stmt) in enumerate(src.code) + if JET.isexpr(stmt, :(=)) + lhs, rhs = stmt.args + if lhs isa GlobalRef + lhs = lhs.name + end + if lhs === :cond + found_cond = true + @test slice[i] + elseif lhs === :x + found_x += 1 + @test slice[i] + elseif lhs === :y + found_y += 1 + @test !slice[i] + end + elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :cond)) + found_cond_get_binding_type = true + @test slice[i] + elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :x)) + found_x_get_binding_type += 1 + @test slice[i] + elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :y)) + found_y_get_binding_type += 1 + @test !slice[i] + end + end + @test found_cond; @test found_cond_get_binding_type + @test found_x == found_x_get_binding_type == found_y == found_y_get_binding_type == 2 + end + @testset 
"captured variables" begin let (vmod, res) = @analyze_toplevel2 begin begin