Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement a better statement selection logic #654

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ using Core: Builtin, IntrinsicFunction, Intrinsics, SimpleVector, svec
using Core.IR

using .CC: @nospecs, ⊑,
AbstractInterpreter, AbstractLattice, ArgInfo, Bottom, CFG, CachedMethodTable, CallMeta,
ConstCallInfo, InferenceParams, InferenceResult, InferenceState, InternalMethodTable,
InvokeCallInfo, MethodCallResult, MethodMatchInfo, MethodMatches, NOT_FOUND,
OptimizationState, OptimizationParams, OverlayMethodTable, StmtInfo, UnionSplitInfo,
UnionSplitMethodMatches, VarState, VarTable, WorldRange, WorldView,
argextype, argtype_by_index, argtypes_to_type, hasintersect, ignorelimited,
instanceof_tfunc, istopfunction, singleton_type, slot_id, specialize_method,
tmeet, tmerge, typeinf_lattice, widenconst, widenlattice
AbstractInterpreter, AbstractLattice, ArgInfo, BasicBlock, Bottom, CFG, CachedMethodTable,
CallMeta, ConstCallInfo, InferenceParams, InferenceResult, InferenceState,
InternalMethodTable, InvokeCallInfo, MethodCallResult, MethodMatchInfo, MethodMatches,
NOT_FOUND, OptimizationState, OptimizationParams, OverlayMethodTable, StmtInfo,
UnionSplitInfo, UnionSplitMethodMatches, VarState, VarTable, WorldRange, WorldView,
argextype, argtype_by_index, argtypes_to_type, compute_basic_blocks, construct_domtree,
construct_postdomtree, hasintersect, ignorelimited, instanceof_tfunc, istopfunction,
nearest_common_dominator, singleton_type, slot_id, specialize_method, tmeet, tmerge,
typeinf_lattice, widenconst, widenlattice

using Base: IdSet, get_world_counter

Expand Down
234 changes: 172 additions & 62 deletions src/toplevel/virtualprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1091,16 +1091,44 @@ end

# select statements that should be concretized, and actually interpreted rather than abstracted
function select_statements(mod::Module, src::CodeInfo)
    # Build the link/edge information once; both selection passes below consume it.
    cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`?
    edges = LoweredCodeUtils.CodeEdges(src, cl)

    # One flag per statement of `src.code`; `true` means "concretize this statement".
    concretize = falses(length(src.code))

    # Seed the selection exactly once, then close it over its dependencies.
    # (Previously the seeding pass was run twice and the first result was discarded.)
    select_direct_requirement!(concretize, src.code, edges)
    select_dependencies!(concretize, src, edges, cl)
    return concretize
end

# just for testing, and debugging
# Variant that seeds the selection from global names of `mod`: each name is turned into a
# `GlobalRef(mod, name)` and handed to `LoweredCodeUtils.add_requests!`, after which the
# dependencies of the requested statements are selected as well.
# Returns a `BitVector` with one flag per statement of `src.code`.
function select_statements(mod::Module, src::CodeInfo, names::Symbol...)
    links = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`?
    edges = LoweredCodeUtils.CodeEdges(src, links)
    requested = Set{GlobalRef}()
    for name in names
        push!(requested, GlobalRef(mod, name))
    end
    concretize = falses(length(src.code))
    LoweredCodeUtils.add_requests!(concretize, requested, edges, ())
    select_dependencies!(concretize, src, edges, links)
    return concretize
end
# Variant that seeds the selection from slots: every statement recorded in
# `cl.slotassigns` for one of the given slots is marked, and then the dependencies of
# those statements are selected as well.
# Returns a `BitVector` with one flag per statement of `src.code`.
function select_statements(mod::Module, src::CodeInfo, slots::SlotNumber...)
    links = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`?
    edges = LoweredCodeUtils.CodeEdges(src, links)
    concretize = falses(length(src.code))
    for slot in slots, assignidx in links.slotassigns[slot.id]
        concretize[assignidx] = true
    end
    select_dependencies!(concretize, src, edges, links)
    return concretize
end
# Variant that seeds the selection from explicit statement indices into `src.code`:
# the given statements are marked, and then their dependencies are selected as well.
# Returns a `BitVector` with one flag per statement of `src.code`.
function select_statements(mod::Module, src::CodeInfo, idxs::Int...)
    links = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`?
    edges = LoweredCodeUtils.CodeEdges(src, links)
    concretize = falses(length(src.code))
    for stmtidx in idxs
        concretize[stmtidx] = true
    end
    select_dependencies!(concretize, src, edges, links)
    return concretize
end

Expand Down Expand Up @@ -1173,66 +1201,41 @@ end

# The goal of this function is to request concretization of the minimal necessary control
# flow to evaluate statements whose concretization have already been requested.
# The basic approach is to check if there are any active successors for each basic block,
# and if there is an active successor and the terminator is not a fall-through, then request
# the concretization of that terminator. Additionally, for conditional terminators, a simple
# optimization using post-domination analysis is also performed.
function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree)
# The basic algorithm is based on what was proposed in [^Wei84]. If there is even one active
# block in the blocks reachable from a conditional branch up to its successors' nearest
# common post-dominator (referred to as 𝑰𝑵𝑭𝑳 in the paper), it is necessary to follow
# that conditional branch and execute the code. Otherwise, execution can be short-circuited
# from the conditional branch to the nearest common post-dominator.
#
# COMBAK: It is important to note that in Julia's intermediate code representation (`CodeInfo`),
# "short-circuiting" a specific code region is not a simple task. Simply ignoring the path
# to the post-dominator does not guarantee fall-through to the post-dominator. Therefore,
# a more careful implementation is required for this aspect.
#
# [^Wei84]: M. Weiser, "Program Slicing," IEEE Transactions on Software Engineering, 10, pages 352-357, July 1984.
function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, domtree, postdomtree)
local changed::Bool = false
function mark_concretize!(idx::Int)
if !concretize[idx]
concretize[idx] = true
changed |= concretize[idx] = true
return true
end
return false
end
nblocks = length(cfg.blocks)
for bbidx = 1:nblocks
bb = cfg.blocks[bbidx] # forward traversal
for bbidx = 1:length(cfg.blocks) # forward traversal
bb = cfg.blocks[bbidx]
nsuccs = length(bb.succs)
if nsuccs == 0
continue
elseif nsuccs == 1
terminator_idx = bb.stmts[end]
if src.code[terminator_idx] isa GotoNode
# If the destination of this `GotoNode` is not active, it's fine to ignore
# the control flow caused by this `GotoNode` and treat it as a fall-through.
# If the block that is fallen through to is active and has a dependency on
# this goto block, then the concretization of this goto block should already
# be requested (at some point of the higher concretization convergence cycle
# of `select_dependencies`), and thus, this `GotoNode` will be concretized.
if any(@view concretize[cfg.blocks[only(bb.succs)].stmts])
changed |= mark_concretize!(terminator_idx)
end
end
continue # leave a fall-through terminator unmarked: `GotoNode`s are marked later
elseif nsuccs == 2
terminator_idx = bb.stmts[end]
@assert is_conditional_terminator(src.code[terminator_idx]) "invalid IR"
succ1, succ2 = bb.succs
succ1_req = any(@view concretize[cfg.blocks[succ1].stmts])
succ2_req = any(@view concretize[cfg.blocks[succ2].stmts])
if succ1_req
if succ2_req
changed |= mark_concretize!(terminator_idx)
else
active_bb, inactive_bb = succ1, succ2
@goto asymmetric_case
end
elseif succ2_req
active_bb, inactive_bb = succ2, succ1
@label asymmetric_case
# We can ignore the control flow of this conditional terminator and treat
# it as a fall-through if only one of its successors is active and the
# active block post-dominates the inactive one, since the post-domination
# ensures that the active basic block will be reached regardless of the
# control flow.
if CC.postdominates(postdomtree, active_bb, inactive_bb)
# fall through this block
else
changed |= mark_concretize!(terminator_idx)
end
termidx = bb.stmts[end]
@assert is_conditional_terminator(src.code[termidx]) "invalid IR"
if is_conditional_block_active(concretize, bb, cfg, postdomtree)
mark_concretize!(termidx)
else
# both successors are inactive, just fall through this block
# fall-through to the post dominator block (by short-circuiting all statements between)
end
end
end
Expand All @@ -1242,6 +1245,46 @@ end
# Return `true` when `stmt` is a `GotoIfNot` or an "enter" statement. The latter is
# checked as an `EnterNode` when that type is defined, and as an `Expr(:enter, ...)`
# otherwise; the choice between the two checks is made once, via `@static`.
function is_conditional_terminator(@nospecialize stmt)
    stmt isa GotoIfNot && return true
    @static if @isdefined(EnterNode)
        return stmt isa EnterNode
    else
        return isexpr(stmt, :enter)
    end
end

# Decide whether the conditional terminator of `bb` has to be followed.
# The answer is `true` when at least one block handed over by `visit_𝑰𝑵𝑭𝑳_blocks`,
# other than the post-dominator block itself, contains a statement that is already
# marked in `concretize`; otherwise it is `false`.
function is_conditional_block_active(concretize::BitVector, bb::BasicBlock, cfg::CFG, postdomtree)
    return visit_𝑰𝑵𝑭𝑳_blocks(bb, cfg, postdomtree) do postdominator::Int, 𝑰𝑵𝑭𝑳::BitSet
        return any(𝑰𝑵𝑭𝑳) do blk
            # the post-dominator block is never counted as part of the region
            blk != postdominator && any(@view concretize[cfg.blocks[blk].stmts])
        end
    end
end

# Call `func(postdominator, inflblks)` for the conditional block `bb` and return its result.
# - `postdominator`: nearest common dominator of the two successors of `bb` in `postdomtree`
# - `inflblks`: union of the block sets that `reachable_blocks` computes from each
#   successor towards `postdominator` (a fresh `BitSet`, so `func` may mutate it)
function visit_𝑰𝑵𝑭𝑳_blocks(func, bb::BasicBlock, cfg::CFG, postdomtree)
    succ1, succ2 = bb.succs
    postdominator = nearest_common_dominator(postdomtree, succ1, succ2)
    from_succ1 = reachable_blocks(cfg, succ1, postdominator)
    from_succ2 = reachable_blocks(cfg, succ2, postdominator)
    return func(postdominator, union(from_succ1, from_succ2))
end

# Collect the basic blocks that can be reached from `from_bb` by following successor
# edges of `cfg`, without walking past `to_bb`.
# The returned `BitSet` always contains `from_bb`; when `to_bb` differs from `from_bb`
# it contains `to_bb` too, since `to_bb` is pre-registered as visited so that the
# traversal never expands its successors.
function reachable_blocks(cfg::CFG, from_bb::Int, to_bb::Int)
    seen = BitSet(from_bb)
    from_bb == to_bb && return seen
    push!(seen, to_bb) # acts as a barrier for the traversal below
    pending = Int[from_bb]
    while !isempty(pending)
        current = pop!(pending)
        for succ in cfg.blocks[current].succs
            succ in seen && continue
            push!(seen, succ)
            push!(pending, succ)
        end
    end
    return seen
end

function add_required_inplace!(concretize::BitVector, src::CodeInfo, edges, cl)
changed = false
for i = 1:length(src.code)
Expand Down Expand Up @@ -1272,31 +1315,98 @@ function is_arg_requested(@nospecialize(arg), concretize, edges, cl)
return false
end

# The purpose of this function is to find other statements that affect the execution of the
# statements chosen by `select_direct_requirement!`. In other words, it extracts the
# minimal amount of code necessary to realize the required concretization.
# This technique is generally referred to as "program slicing," and specifically as
# "static program slicing". The basic algorithm implemented here is an extension of the one
# proposed in https://dl.acm.org/doi/10.5555/800078.802557, which is especially tuned for
# Julia's intermediate code representation.
function select_dependencies!(concretize::BitVector, src::CodeInfo, edges, cl)
    typedefs = LoweredCodeUtils.find_typedefs(src)
    # The control flow graph and its (post-)dominator trees depend only on `src.code`,
    # so they are computed once, outside of the convergence loop.
    cfg = compute_basic_blocks(src.code)
    domtree = construct_domtree(cfg.blocks)
    postdomtree = construct_postdomtree(cfg.blocks)

    # Iterate until none of the passes below marks a new statement.
    while true
        changed = false

        # Discover struct/method definitions at the beginning,
        # and propagate the definition requirements by tracking SSA predecessors.
        # (TODO maybe hoist this out of the loop?)
        changed |= LoweredCodeUtils.add_typedefs!(concretize, src, edges, typedefs, ())
        changed |= add_ssa_preds!(concretize, src, edges, ())

        # Mark some common inplace operations like `push!(x, ...)` and `setindex!(x, ...)`
        # when `x` has been marked already: otherwise we may end up using it with invalid state.
        # However, note that this is an incomplete approach, and note that the slice created
        # by this routine will not be sound because of this. This is because
        # `add_required_inplace!` only requires certain special-cased function calls and
        # does not take into account the possibility that arguments may be mutated in
        # arbitrary function calls. Ideally, function calls should be required while
        # considering the effects of these statements, or by some sort of an
        # inter-procedural program slicing.
        changed |= add_required_inplace!(concretize, src, edges, cl)
        changed |= add_ssa_preds!(concretize, src, edges, ())

        # Mark necessary control flows,
        # and propagate the definition requirements by tracking SSA predecessors.
        changed |= add_control_flow!(concretize, src, cfg, domtree, postdomtree)
        changed |= add_ssa_preds!(concretize, src, edges, ())

        changed || break
    end

    # now mark the active goto nodes
    add_active_gotos!(concretize, src, cfg, postdomtree)

    return nothing
end

# Mark the `GotoNode` terminator of every single-successor basic block that is not
# reported as dead by `compute_dead_blocks` for the current state of `concretize`.
# Returns `true` only if at least one previously unmarked `GotoNode` was newly marked,
# and `false` otherwise.
function add_active_gotos!(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree)
    dead_blocks = compute_dead_blocks(concretize, src, cfg, postdomtree)
    changed = false
    for bbidx in eachindex(cfg.blocks)
        bbidx in dead_blocks && continue
        bb = cfg.blocks[bbidx]
        length(bb.succs) == 1 || continue
        termidx = bb.stmts[end]
        if src.code[termidx] isa GotoNode
            # Record a change only when the flag actually flips; or-ing in the value of
            # the assignment itself would report `true` for already-marked gotos as well.
            changed |= !concretize[termidx]
            concretize[termidx] = true
        end
    end
    return changed
end

# find dead blocks using the same approach as `add_control_flow!`, for the converged `concretize`
# Return the set of basic block indices considered dead for the given `concretize`.
# For every block with two successors, the blocks handed over by `visit_𝑰𝑵𝑭𝑳_blocks`
# are inspected: if none of them, apart from the post-dominator block, contains a marked
# statement, all of them except the post-dominator block are added to the result.
function compute_dead_blocks(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree)
    dead_blocks = BitSet()
    for bb in cfg.blocks
        length(bb.succs) == 2 || continue
        termidx = bb.stmts[end]
        @assert is_conditional_terminator(src.code[termidx]) "invalid IR"
        visit_𝑰𝑵𝑭𝑳_blocks(bb, cfg, postdomtree) do postdominator::Int, 𝑰𝑵𝑭𝑳::BitSet
            has_active_blk = any(𝑰𝑵𝑭𝑳) do blk
                # the post-dominator block itself never counts as part of the region
                blk != postdominator && any(@view concretize[cfg.blocks[blk].stmts])
            end
            if !has_active_blk
                union!(dead_blocks, delete!(𝑰𝑵𝑭𝑳, postdominator))
            end
        end
    end
    return dead_blocks
end

function JuliaInterpreter.step_expr!(interp::ConcreteInterpreter, frame::Frame, @nospecialize(node), istoplevel::Bool)
Expand Down
Loading
Loading