Skip to content

Commit

Permalink
wip: keep track of SSA where Conditional is formed and invalidate i…
Browse files Browse the repository at this point in the history
…t on use
  • Loading branch information
aviatesk committed Aug 23, 2024
1 parent c8ba3be commit 2224949
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 68 deletions.
78 changes: 38 additions & 40 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ function propagate_conditional(rt::InterConditional, cond::Conditional)
new_elsetype = rt.elsetype === Const(true) ? cond.thentype : cond.elsetype
if rt.thentype == Bottom
@assert rt.elsetype != Bottom
return Conditional(cond.slot, Bottom, new_elsetype)
return Conditional(cond.slot, Bottom, new_elsetype, cond.from_ssa)
elseif rt.elsetype == Bottom
@assert rt.thentype != Bottom
return Conditional(cond.slot, new_thentype, Bottom)
return Conditional(cond.slot, new_thentype, Bottom, cond.from_ssa)
end
return Conditional(cond.slot, new_thentype, new_elsetype)
return Conditional(cond.slot, new_thentype, new_elsetype, cond.from_ssa)
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
Expand Down Expand Up @@ -512,7 +512,7 @@ function from_interconditional(𝕃ᵢ::AbstractLattice, @nospecialize(rt), sv::
if alias !== nothing
return form_mustalias_conditional(alias, thentype, elsetype)
end
return Conditional(slot, thentype, elsetype) # record a Conditional improvement to this slot
return Conditional(slot, thentype, elsetype, #=from_ssa=#sv.currpc) # record a Conditional improvement to this slot
end
return widenconditional(rt)
end
Expand Down Expand Up @@ -1430,7 +1430,7 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
# TODO bail out here immediately rather than just propagating Bottom ?
given_argtypes[i] = Bottom
else
given_argtypes[i] = Conditional(slotid, thentype, elsetype)
given_argtypes[i] = Conditional(slotid, thentype, elsetype, #=from_ssa=#0)
end
continue
end
Expand Down Expand Up @@ -1929,7 +1929,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
if isa(a, SlotNumber)
cndt = isa_condition(a2, a3, InferenceParams(interp).max_union_splitting, rt)
if cndt !== nothing
return Conditional(a, cndt.thentype, cndt.elsetype)
return Conditional(a, cndt.thentype, cndt.elsetype, #=from_ssa=#sv.currpc)
end
end
if isa(a2, MustAlias)
Expand All @@ -1947,7 +1947,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
# !(x isa T) implies !(Type{a2} <: T)
# TODO: complete splitting, based on which portions of the Union a3 for which isa_tfunc returns Const(true) or Const(false) instead of Bool
elsetype = typesubtract(a3, Type{widenconst(a2)}, InferenceParams(interp).max_union_splitting)
return Conditional(b, a3, elsetype)
return Conditional(b, a3, elsetype, #=from_ssa=#sv.currpc)
end
end
elseif f === (===)
Expand All @@ -1959,15 +1959,15 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
if isa(aty, Const)
if isa(b, SlotNumber)
cndt = egal_condition(aty, bty, InferenceParams(interp).max_union_splitting, rt)
return Conditional(b, cndt.thentype, cndt.elsetype)
return Conditional(b, cndt.thentype, cndt.elsetype, #=from_ssa=#sv.currpc)
elseif isa(bty, MustAlias) && !isa(rt, Const) # skip refinement when the field is known precisely (just optimization)
cndt = egal_condition(aty, bty.fldtyp, InferenceParams(interp).max_union_splitting)
return form_mustalias_conditional(bty, cndt.thentype, cndt.elsetype)
end
elseif isa(bty, Const)
if isa(a, SlotNumber)
cndt = egal_condition(bty, aty, InferenceParams(interp).max_union_splitting, rt)
return Conditional(a, cndt.thentype, cndt.elsetype)
return Conditional(a, cndt.thentype, cndt.elsetype, #=from_ssa=#sv.currpc)
elseif isa(aty, MustAlias) && !isa(rt, Const) # skip refinement when the field is known precisely (just optimization)
cndt = egal_condition(bty, aty.fldtyp, InferenceParams(interp).max_union_splitting)
return form_mustalias_conditional(aty, cndt.thentype, cndt.elsetype)
Expand Down Expand Up @@ -1998,18 +1998,18 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
if isa(b, SlotNumber)
thentype = rt === Const(false) ? Bottom : widenslotwrapper(bty)
elsetype = rt === Const(true) ? Bottom : widenslotwrapper(bty)
return Conditional(b, thentype, elsetype)
return Conditional(b, thentype, elsetype, #=from_ssa=#sv.currpc)
elseif isa(a, SlotNumber)
thentype = rt === Const(false) ? Bottom : widenslotwrapper(aty)
elsetype = rt === Const(true) ? Bottom : widenslotwrapper(aty)
return Conditional(a, thentype, elsetype)
return Conditional(a, thentype, elsetype, #=from_ssa=#sv.currpc)
end
elseif f === Core.Compiler.not_int
aty = argtypes[2]
if isa(aty, Conditional)
thentype = rt === Const(false) ? Bottom : aty.elsetype
elsetype = rt === Const(true) ? Bottom : aty.thentype
return Conditional(aty.slot, thentype, elsetype)
return Conditional(aty.slot, thentype, elsetype, #=from_ssa=#sv.currpc)
end
elseif f === isdefined
a = ssa_def_slot(fargs[2], sv)
Expand All @@ -2032,7 +2032,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
elsetype = elsetype ty
end
end
return Conditional(a, thentype, elsetype)
return Conditional(a, thentype, elsetype, #=from_ssa=#sv.currpc)
else
thentype = form_partially_defined_struct(argtype2, argtypes[3])
if thentype !== nothing
Expand All @@ -2042,7 +2042,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
elseif rt === Const(true)
elsetype = Bottom
end
return Conditional(a, thentype, elsetype)
return Conditional(a, thentype, elsetype, #=from_ssa=#sv.currpc)
end
end
end
Expand Down Expand Up @@ -2307,7 +2307,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Any, Any]), si, Tuple{typeof(f), Any, Any}, sv, max_methods)
rty = abstract_call_known(interp, (===), arginfo, si, sv, max_methods).rt
if isa(rty, Conditional)
return CallMeta(Conditional(rty.slot, rty.elsetype, rty.thentype), Bottom, EFFECTS_TOTAL, NoCallInfo()) # swap if-else
newrty = Conditional(rty.slot, rty.elsetype, rty.thentype, #=from_ssa=#sv.currpc)
return CallMeta(newrty, Bottom, EFFECTS_TOTAL, NoCallInfo()) # swap if-else
elseif isa(rty, Const)
return CallMeta(Const(rty.val === false), Bottom, EFFECTS_TOTAL, MethodResultPure())
end
Expand Down Expand Up @@ -2552,7 +2553,7 @@ struct RTEffects
end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState)
unused = call_result_unused(sv, sv.currpc)
unused = call_result_unused(sv, #=from_ssa=#sv.currpc)
if unused
add_curr_ssaflag!(sv, IR_FLAG_UNUSED)
end
Expand Down Expand Up @@ -2695,7 +2696,7 @@ function abstract_eval_new_opaque_closure(interp::AbstractInterpreter, e::Expr,
rt = widenconst(rt)
# Propagation of PartialOpaque disabled
end
if isa(rt, PartialOpaque) && isa(sv, InferenceState) && !call_result_unused(sv, sv.currpc)
if isa(rt, PartialOpaque) && isa(sv, InferenceState) && !call_result_unused(sv, #=from_ssa=#sv.currpc)
# Infer this now so that the specialization is available to
# optimization.
argtypes = most_general_argtypes(rt)
Expand Down Expand Up @@ -3221,6 +3222,7 @@ end
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
changes = StateUpdate(lhs, VarState(rt, false), false)
setassignment!(frame.slotassignments, slot_id(lhs), frame.currpc)
elseif isa(lhs, GlobalRef)
handle_global_assignment!(interp, frame, lhs, rt)
elseif !isa(lhs, SSAValue)
Expand Down Expand Up @@ -3357,7 +3359,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

states = frame.bb_vartables
currstate = copy(states[currbb]::VarTable)
slotwrapperssas = BitSet()
while currbb <= nbbs
delete!(W, currbb)
bbstart = first(bbs[currbb].stmts)
Expand Down Expand Up @@ -3391,7 +3392,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condslot, SlotNumber)
# if this non-`Conditional` object is a slot, we form and propagate
# the conditional constraint on it
condt = Conditional(condslot, Const(true), Const(false))
condt = Conditional(condslot, Const(true), Const(false), #=from_ssa=#frame.currpc)
end
condval = maybe_extract_const_bool(condt)
nothrow = (condval !== nothing) || (𝕃ᵢ, orig_condt, Bool)
Expand Down Expand Up @@ -3436,7 +3437,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

# We continue with the true branch, but process the false
# branch here.
if isa(condt, Conditional)
if isa(condt, Conditional) && is_valid_conditional(condt, currpc, frame)
else_change = conditional_change(𝕃ᵢ, currstate, condt, #=then_or_else=#false)
if else_change !== nothing
elsestate = copy(currstate)
Expand Down Expand Up @@ -3475,14 +3476,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
return caller.ssavaluetypes[caller_pc] !== Any
end
end
ssavaluetypes[frame.currpc] = Any
ssavaluetypes[currpc] = Any
@goto find_next_bb
elseif isa(stmt, EnterNode)
ssavaluetypes[currpc] = Any
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
if isdefined(stmt, :scope)
scopet = abstract_eval_value(interp, stmt.scope, currstate, frame)
handler = gethandler(frame, frame.currpc+1)::TryCatchFrame
handler = gethandler(frame, currpc+1)::TryCatchFrame
@assert handler.scopet !== nothing
if !(𝕃ᵢ, scopet, handler.scopet)
handler.scopet = tmerge(𝕃ᵢ, scopet, handler.scopet)
Expand Down Expand Up @@ -3521,9 +3522,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
if changes !== nothing
stoverwrite1!(currstate, changes)
# widen any slot wrapper types that should be invalidated by this change
# just like what's done for `currstate`
invalidate_ssa_slotwrapper!(ssavaluetypes, slotwrapperssas, slot_id(changes.var))
end
if refinements isa SlotRefinement
apply_refinement!(𝕃ᵢ, refinements.slot, refinements.typ, currstate, changes)
Expand All @@ -3538,7 +3536,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
ssavaluetypes[currpc] = Any
continue
end
record_ssa_assign!(𝕃ᵢ, currpc, rt, frame, slotwrapperssas)
record_ssa_assign!(𝕃ᵢ, currpc, rt, frame)
end # for currpc in bbstart:bbend

# Case 1: Fallthrough termination
Expand Down Expand Up @@ -3585,6 +3583,18 @@ function apply_refinement!(𝕃ᵢ::AbstractLattice, slot::SlotNumber, @nospecia
end
end

function is_valid_conditional(condt::Conditional, use_ssa::Int, sv::InferenceState)
domtree = get!(sv.lazydomtree)
dominates_ssa(sv.cfg, domtree, condt.from_ssa, use_ssa) ||
condt.from_ssa == use_ssa ||
return false
return all(getassignment(sv.slotassignments, condt.slot)) do aidx::Int
return (aidx == condt.from_ssa == 0 ||
dominates_ssa(sv.cfg, domtree, aidx, condt.from_ssa) ||
dominates_ssa(sv.cfg, domtree, use_ssa, aidx))
end
end

function conditional_change(𝕃ᵢ::AbstractLattice, currstate::VarTable, condt::Conditional, then_or_else::Bool)
vtype = currstate[condt.slot]
oldtyp = vtype.typ
Expand Down Expand Up @@ -3612,23 +3622,11 @@ function condition_object_change(currstate::VarTable, condt::Conditional,
vtype = currstate[slot_id(condslot)]
newcondt = Conditional(condt.slot,
then_or_else ? condt.thentype : Union{},
then_or_else ? Union{} : condt.elsetype)
then_or_else ? Union{} : condt.elsetype,
#=from_ssa=#condt.from_ssa)
return StateUpdate(condslot, VarState(newcondt, vtype.undef), false)
end

# remove any lattice elements that wrap the reassigned slot object within `ssavaluetypes`
function invalidate_ssa_slotwrapper!(ssavaluetypes::Vector{Any}, slotwrapperssas::BitSet, changeid::Int)
for idx = slotwrapperssas
invalidate_ssa_slotwrapper!(ssavaluetypes, idx, changeid)
end
end
function invalidate_ssa_slotwrapper!(ssavaluetypes::Vector{Any}, idx::Int, changeid::Int)
typ = ssavaluetypes[idx]
if should_invalidate(typ, changeid)
ssavaluetypes[idx] = @noinline widenwrappedslotwrapper(typ)
end
end

# make as much progress on `frame` as possible (by handling cycles)
function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
typeinf_local(interp, frame)
Expand Down
34 changes: 23 additions & 11 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,16 @@ function get!(x::LazyCFGReachability)
end

mutable struct LazyGenericDomtree{IsPostDom}
ir::IRCode
cfg::CFG
domtree::GenericDomTree{IsPostDom}
LazyGenericDomtree{IsPostDom}(ir::IRCode) where {IsPostDom} = new{IsPostDom}(ir)
LazyGenericDomtree{IsPostDom}(cfg::CFG) where {IsPostDom} = new{IsPostDom}(cfg)
end
LazyGenericDomtree{IsPostDom}(ir::IRCode) where {IsPostDom} = LazyGenericDomtree{IsPostDom}(ir.cfg)
function get!(x::LazyGenericDomtree{IsPostDom}) where {IsPostDom}
isdefined(x, :domtree) && return x.domtree
return @timeit "domtree 2" x.domtree = IsPostDom ?
construct_postdomtree(x.ir) :
construct_domtree(x.ir)
construct_postdomtree(x.cfg) :
construct_domtree(x.cfg)
end

const LazyDomtree = LazyGenericDomtree{false}
Expand Down Expand Up @@ -227,6 +228,16 @@ struct HandlerInfo
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

struct SlotAssignments
assignments::Vector{BitSet}
SlotAssignments(nslots::Int) = new(Vector{BitSet}(undef, nslots))
end
const ARGUMENT_ASSIGNMENT = BitSet(0)
getassignment(sa::SlotAssignments, sidx::Int) = isassigned(sa.assignments, sidx) ?
sa.assignments[sidx] : ARGUMENT_ASSIGNMENT
setassignment!(sa::SlotAssignments, sidx::Int, pc::Int) = push!(isassigned(sa.assignments, sidx) ?
sa.assignments[sidx] : (sa.assignments[sidx] = BitSet()), pc)

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -249,6 +260,8 @@ mutable struct InferenceState
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Vector{Any}}
stmt_info::Vector{CallInfo}
slotassignments::SlotAssignments
lazydomtree::LazyDomtree

#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
Expand Down Expand Up @@ -308,6 +321,8 @@ mutable struct InferenceState
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
slotassignments = SlotAssignments(nslots)
lazydomtree = LazyDomtree(cfg)
argtypes = result.argtypes

argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva)
Expand All @@ -316,7 +331,7 @@ mutable struct InferenceState
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
if argtyp === Bool && has_conditional(typeinf_lattice(interp))
argtyp = Conditional(i, Const(true), Const(false))
argtyp = Conditional(i, Const(true), Const(false), #=from_ssa=#0)
end
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
Expand Down Expand Up @@ -350,7 +365,8 @@ mutable struct InferenceState

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes,
stmt_edges, stmt_info, slotassignments, lazydomtree,
pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
Expand Down Expand Up @@ -735,15 +751,11 @@ end
_topmod(sv::InferenceState) = _topmod(frame_module(sv))

function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize(new),
frame::InferenceState, slotwrapperssas::BitSet)
frame::InferenceState)
ssavaluetypes = frame.ssavaluetypes
old = ssavaluetypes[ssa_id]
if old === NOT_FOUND || !is_lattice_equal(𝕃ᵢ, new, old)
ssavaluetypes[ssa_id] = new
wnew = ignorelimited(new)
if new isa Conditional || new isa MustAlias
push!(slotwrapperssas, ssa_id)
end
W = frame.ip
for r in frame.ssavalue_uses[ssa_id]
if was_reached(frame, r)
Expand Down
9 changes: 9 additions & 0 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,15 @@ function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAV
return dominates(domtree, xb, yb)
end

function dominates_ssa(cfg::CFG, domtree::DomTree, x::Int, y::Int)
xb = block_for_inst(cfg, x)
yb = block_for_inst(cfg, y)
if xb == yb
return x < y
end
return dominates(domtree, xb, yb)
end

function _count_added_node!(compact::IncrementalCompact, @nospecialize(val))
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ end

function not_tfunc(𝕃::AbstractLattice, @nospecialize(b))
if isa(b, Conditional)
return Conditional(b.slot, b.elsetype, b.thentype)
return Conditional(b.slot, b.elsetype, b.thentype, b.from_ssa)
elseif isa(b, Const)
return Const(not_int(b.val))
end
Expand Down Expand Up @@ -350,14 +350,14 @@ end
if isa(x, Conditional)
y = widenconditional(y)
if isa(y, Const)
y.val === false && return Conditional(x.slot, x.elsetype, x.thentype)
y.val === false && return Conditional(x.slot, x.elsetype, x.thentype, x.from_ssa)
y.val === true && return x
return Const(false)
end
elseif isa(y, Conditional)
x = widenconditional(x)
if isa(x, Const)
x.val === false && return Conditional(y.slot, y.elsetype, y.thentype)
x.val === false && return Conditional(y.slot, y.elsetype, y.thentype, y.from_ssa)
x.val === true && return y
return Const(false)
end
Expand Down
Loading

0 comments on commit 2224949

Please sign in to comment.