Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: invalidate stale slot wrapper types in ssavaluetypes #55551

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 37 additions & 22 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ function propagate_conditional(rt::InterConditional, cond::Conditional)
new_elsetype = rt.elsetype === Const(true) ? cond.thentype : cond.elsetype
if rt.thentype == Bottom
@assert rt.elsetype != Bottom
return Conditional(cond.slot, Bottom, new_elsetype)
return Conditional(cond.slot, Bottom, new_elsetype, cond.from_ssa)
elseif rt.elsetype == Bottom
@assert rt.thentype != Bottom
return Conditional(cond.slot, new_thentype, Bottom)
return Conditional(cond.slot, new_thentype, Bottom, cond.from_ssa)
end
return Conditional(cond.slot, new_thentype, new_elsetype)
return Conditional(cond.slot, new_thentype, new_elsetype, cond.from_ssa)
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
Expand Down Expand Up @@ -512,7 +512,7 @@ function from_interconditional(𝕃ᵢ::AbstractLattice, @nospecialize(rt), sv::
if alias !== nothing
return form_mustalias_conditional(alias, thentype, elsetype)
end
return Conditional(slot, thentype, elsetype) # record a Conditional improvement to this slot
return Conditional(slot, thentype, elsetype, #=from_ssa=#sv.currpc) # record a Conditional improvement to this slot
end
return widenconditional(rt)
end
Expand Down Expand Up @@ -1430,7 +1430,7 @@ function matching_cache_argtypes(𝕃::AbstractLattice, mi::MethodInstance,
# TODO bail out here immediately rather than just propagating Bottom ?
given_argtypes[i] = Bottom
else
given_argtypes[i] = Conditional(slotid, thentype, elsetype)
given_argtypes[i] = Conditional(slotid, thentype, elsetype, #=from_ssa=#0)
end
continue
end
Expand Down Expand Up @@ -1929,7 +1929,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
if isa(a, SlotNumber)
cndt = isa_condition(a2, a3, InferenceParams(interp).max_union_splitting, rt)
if cndt !== nothing
return Conditional(a, cndt.thentype, cndt.elsetype)
return Conditional(a, cndt.thentype, cndt.elsetype, #=from_ssa=#sv.currpc)
end
end
if isa(a2, MustAlias)
Expand All @@ -1947,7 +1947,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
# !(x isa T) implies !(Type{a2} <: T)
# TODO: complete splitting, based on which portions of the Union a3 for which isa_tfunc returns Const(true) or Const(false) instead of Bool
elsetype = typesubtract(a3, Type{widenconst(a2)}, InferenceParams(interp).max_union_splitting)
return Conditional(b, a3, elsetype)
return Conditional(b, a3, elsetype, #=from_ssa=#sv.currpc)
end
end
elseif f === (===)
Expand All @@ -1959,15 +1959,15 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
if isa(aty, Const)
if isa(b, SlotNumber)
cndt = egal_condition(aty, bty, InferenceParams(interp).max_union_splitting, rt)
return Conditional(b, cndt.thentype, cndt.elsetype)
return Conditional(b, cndt.thentype, cndt.elsetype, #=from_ssa=#sv.currpc)
elseif isa(bty, MustAlias) && !isa(rt, Const) # skip refinement when the field is known precisely (just optimization)
cndt = egal_condition(aty, bty.fldtyp, InferenceParams(interp).max_union_splitting)
return form_mustalias_conditional(bty, cndt.thentype, cndt.elsetype)
end
elseif isa(bty, Const)
if isa(a, SlotNumber)
cndt = egal_condition(bty, aty, InferenceParams(interp).max_union_splitting, rt)
return Conditional(a, cndt.thentype, cndt.elsetype)
return Conditional(a, cndt.thentype, cndt.elsetype, #=from_ssa=#sv.currpc)
elseif isa(aty, MustAlias) && !isa(rt, Const) # skip refinement when the field is known precisely (just optimization)
cndt = egal_condition(bty, aty.fldtyp, InferenceParams(interp).max_union_splitting)
return form_mustalias_conditional(aty, cndt.thentype, cndt.elsetype)
Expand Down Expand Up @@ -1998,18 +1998,18 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
if isa(b, SlotNumber)
thentype = rt === Const(false) ? Bottom : widenslotwrapper(bty)
elsetype = rt === Const(true) ? Bottom : widenslotwrapper(bty)
return Conditional(b, thentype, elsetype)
return Conditional(b, thentype, elsetype, #=from_ssa=#sv.currpc)
elseif isa(a, SlotNumber)
thentype = rt === Const(false) ? Bottom : widenslotwrapper(aty)
elsetype = rt === Const(true) ? Bottom : widenslotwrapper(aty)
return Conditional(a, thentype, elsetype)
return Conditional(a, thentype, elsetype, #=from_ssa=#sv.currpc)
end
elseif f === Core.Compiler.not_int
aty = argtypes[2]
if isa(aty, Conditional)
thentype = rt === Const(false) ? Bottom : aty.elsetype
elsetype = rt === Const(true) ? Bottom : aty.thentype
return Conditional(aty.slot, thentype, elsetype)
return Conditional(aty.slot, thentype, elsetype, #=from_ssa=#sv.currpc)
end
elseif f === isdefined
a = ssa_def_slot(fargs[2], sv)
Expand All @@ -2032,7 +2032,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
elsetype = elsetype ⊔ ty
end
end
return Conditional(a, thentype, elsetype)
return Conditional(a, thentype, elsetype, #=from_ssa=#sv.currpc)
else
thentype = form_partially_defined_struct(argtype2, argtypes[3])
if thentype !== nothing
Expand All @@ -2042,7 +2042,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
elseif rt === Const(true)
elsetype = Bottom
end
return Conditional(a, thentype, elsetype)
return Conditional(a, thentype, elsetype, #=from_ssa=#sv.currpc)
end
end
end
Expand Down Expand Up @@ -2307,7 +2307,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Any, Any]), si, Tuple{typeof(f), Any, Any}, sv, max_methods)
rty = abstract_call_known(interp, (===), arginfo, si, sv, max_methods).rt
if isa(rty, Conditional)
return CallMeta(Conditional(rty.slot, rty.elsetype, rty.thentype), Bottom, EFFECTS_TOTAL, NoCallInfo()) # swap if-else
newrty = Conditional(rty.slot, rty.elsetype, rty.thentype, #=from_ssa=#sv.currpc)
return CallMeta(newrty, Bottom, EFFECTS_TOTAL, NoCallInfo()) # swap if-else
elseif isa(rty, Const)
return CallMeta(Const(rty.val === false), Bottom, EFFECTS_TOTAL, MethodResultPure())
end
Expand Down Expand Up @@ -2552,7 +2553,7 @@ struct RTEffects
end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState)
unused = call_result_unused(sv, sv.currpc)
unused = call_result_unused(sv, #=from_ssa=#sv.currpc)
if unused
add_curr_ssaflag!(sv, IR_FLAG_UNUSED)
end
Expand Down Expand Up @@ -2695,7 +2696,7 @@ function abstract_eval_new_opaque_closure(interp::AbstractInterpreter, e::Expr,
rt = widenconst(rt)
# Propagation of PartialOpaque disabled
end
if isa(rt, PartialOpaque) && isa(sv, InferenceState) && !call_result_unused(sv, sv.currpc)
if isa(rt, PartialOpaque) && isa(sv, InferenceState) && !call_result_unused(sv, #=from_ssa=#sv.currpc)
# Infer this now so that the specialization is available to
# optimization.
argtypes = most_general_argtypes(rt)
Expand Down Expand Up @@ -3221,6 +3222,7 @@ end
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
changes = StateUpdate(lhs, VarState(rt, false), false)
setassignment!(frame.slotassignments, slot_id(lhs), frame.currpc)
elseif isa(lhs, GlobalRef)
handle_global_assignment!(interp, frame, lhs, rt)
elseif !isa(lhs, SSAValue)
Expand Down Expand Up @@ -3390,7 +3392,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condslot, SlotNumber)
# if this non-`Conditional` object is a slot, we form and propagate
# the conditional constraint on it
condt = Conditional(condslot, Const(true), Const(false))
condt = Conditional(condslot, Const(true), Const(false), #=from_ssa=#frame.currpc)
end
condval = maybe_extract_const_bool(condt)
nothrow = (condval !== nothing) || ⊑(𝕃ᵢ, orig_condt, Bool)
Expand Down Expand Up @@ -3435,7 +3437,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

# We continue with the true branch, but process the false
# branch here.
if isa(condt, Conditional)
if isa(condt, Conditional) && is_valid_conditional(condt, currpc, frame)
else_change = conditional_change(𝕃ᵢ, currstate, condt, #=then_or_else=#false)
if else_change !== nothing
elsestate = copy(currstate)
Expand Down Expand Up @@ -3474,14 +3476,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
return caller.ssavaluetypes[caller_pc] !== Any
end
end
ssavaluetypes[frame.currpc] = Any
ssavaluetypes[currpc] = Any
@goto find_next_bb
elseif isa(stmt, EnterNode)
ssavaluetypes[currpc] = Any
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
if isdefined(stmt, :scope)
scopet = abstract_eval_value(interp, stmt.scope, currstate, frame)
handler = gethandler(frame, frame.currpc+1)::TryCatchFrame
handler = gethandler(frame, currpc+1)::TryCatchFrame
@assert handler.scopet !== nothing
if !⊑(𝕃ᵢ, scopet, handler.scopet)
handler.scopet = tmerge(𝕃ᵢ, scopet, handler.scopet)
Expand Down Expand Up @@ -3581,6 +3583,18 @@ function apply_refinement!(𝕃ᵢ::AbstractLattice, slot::SlotNumber, @nospecia
end
end

function is_valid_conditional(condt::Conditional, use_ssa::Int, sv::InferenceState)
domtree = get!(sv.lazydomtree)
dominates_ssa(sv.cfg, domtree, condt.from_ssa, use_ssa) ||
condt.from_ssa == use_ssa ||
return false
return all(getassignment(sv.slotassignments, condt.slot)) do aidx::Int
return (aidx == condt.from_ssa == 0 ||
dominates_ssa(sv.cfg, domtree, aidx, condt.from_ssa) ||
dominates_ssa(sv.cfg, domtree, use_ssa, aidx))
end
end

function conditional_change(𝕃ᵢ::AbstractLattice, currstate::VarTable, condt::Conditional, then_or_else::Bool)
vtype = currstate[condt.slot]
oldtyp = vtype.typ
Expand Down Expand Up @@ -3608,7 +3622,8 @@ function condition_object_change(currstate::VarTable, condt::Conditional,
vtype = currstate[slot_id(condslot)]
newcondt = Conditional(condt.slot,
then_or_else ? condt.thentype : Union{},
then_or_else ? Union{} : condt.elsetype)
then_or_else ? Union{} : condt.elsetype,
#=from_ssa=#condt.from_ssa)
return StateUpdate(condslot, VarState(newcondt, vtype.undef), false)
end

Expand Down
31 changes: 24 additions & 7 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,16 @@ function get!(x::LazyCFGReachability)
end

mutable struct LazyGenericDomtree{IsPostDom}
ir::IRCode
cfg::CFG
domtree::GenericDomTree{IsPostDom}
LazyGenericDomtree{IsPostDom}(ir::IRCode) where {IsPostDom} = new{IsPostDom}(ir)
LazyGenericDomtree{IsPostDom}(cfg::CFG) where {IsPostDom} = new{IsPostDom}(cfg)
end
LazyGenericDomtree{IsPostDom}(ir::IRCode) where {IsPostDom} = LazyGenericDomtree{IsPostDom}(ir.cfg)
function get!(x::LazyGenericDomtree{IsPostDom}) where {IsPostDom}
isdefined(x, :domtree) && return x.domtree
return @timeit "domtree 2" x.domtree = IsPostDom ?
construct_postdomtree(x.ir) :
construct_domtree(x.ir)
construct_postdomtree(x.cfg) :
construct_domtree(x.cfg)
end

const LazyDomtree = LazyGenericDomtree{false}
Expand Down Expand Up @@ -227,6 +228,16 @@ struct HandlerInfo
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

struct SlotAssignments
assignments::Vector{BitSet}
SlotAssignments(nslots::Int) = new(Vector{BitSet}(undef, nslots))
end
const ARGUMENT_ASSIGNMENT = BitSet(0)
getassignment(sa::SlotAssignments, sidx::Int) = isassigned(sa.assignments, sidx) ?
sa.assignments[sidx] : ARGUMENT_ASSIGNMENT
setassignment!(sa::SlotAssignments, sidx::Int, pc::Int) = push!(isassigned(sa.assignments, sidx) ?
sa.assignments[sidx] : (sa.assignments[sidx] = BitSet()), pc)

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -249,6 +260,8 @@ mutable struct InferenceState
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Vector{Any}}
stmt_info::Vector{CallInfo}
slotassignments::SlotAssignments
lazydomtree::LazyDomtree

#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
Expand Down Expand Up @@ -308,6 +321,8 @@ mutable struct InferenceState
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
slotassignments = SlotAssignments(nslots)
lazydomtree = LazyDomtree(cfg)
argtypes = result.argtypes

argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva)
Expand All @@ -316,7 +331,7 @@ mutable struct InferenceState
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
if argtyp === Bool && has_conditional(typeinf_lattice(interp))
argtyp = Conditional(i, Const(true), Const(false))
argtyp = Conditional(i, Const(true), Const(false), #=from_ssa=#0)
end
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
Expand Down Expand Up @@ -350,7 +365,8 @@ mutable struct InferenceState

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes,
stmt_edges, stmt_info, slotassignments, lazydomtree,
pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
Expand Down Expand Up @@ -734,7 +750,8 @@ end

_topmod(sv::InferenceState) = _topmod(frame_module(sv))

function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize(new), frame::InferenceState)
function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize(new),
frame::InferenceState)
ssavaluetypes = frame.ssavaluetypes
old = ssavaluetypes[ssa_id]
if old === NOT_FOUND || !is_lattice_equal(𝕃ᵢ, new, old)
Expand Down
9 changes: 9 additions & 0 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,15 @@ function dominates_ssa(compact::IncrementalCompact, domtree::DomTree, x::AnySSAV
return dominates(domtree, xb, yb)
end

function dominates_ssa(cfg::CFG, domtree::DomTree, x::Int, y::Int)
xb = block_for_inst(cfg, x)
yb = block_for_inst(cfg, y)
if xb == yb
return x < y
end
return dominates(domtree, xb, yb)
end

function _count_added_node!(compact::IncrementalCompact, @nospecialize(val))
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ end

function not_tfunc(𝕃::AbstractLattice, @nospecialize(b))
if isa(b, Conditional)
return Conditional(b.slot, b.elsetype, b.thentype)
return Conditional(b.slot, b.elsetype, b.thentype, b.from_ssa)
elseif isa(b, Const)
return Const(not_int(b.val))
end
Expand Down Expand Up @@ -350,14 +350,14 @@ end
if isa(x, Conditional)
y = widenconditional(y)
if isa(y, Const)
y.val === false && return Conditional(x.slot, x.elsetype, x.thentype)
y.val === false && return Conditional(x.slot, x.elsetype, x.thentype, x.from_ssa)
y.val === true && return x
return Const(false)
end
elseif isa(y, Conditional)
x = widenconditional(x)
if isa(x, Const)
x.val === false && return Conditional(y.slot, y.elsetype, y.thentype)
x.val === false && return Conditional(y.slot, y.elsetype, y.thentype, y.from_ssa)
x.val === true && return y
return Const(false)
end
Expand Down
Loading