Skip to content

Commit

Permalink
inference: track reaching defs for slots
Browse files Browse the repository at this point in the history
This change effectively computes the SSA / ϕ-nodes for program slots as
part of type-inference, using the "path-convergence criterion" for SSA.

This allows us to conveniently reason about slot identity (in typical
SSA fashion) without having to quadratically expand all of our SSA type
state over the CFG.
  • Loading branch information
topolarity committed Aug 27, 2024
1 parent 78b0b74 commit 5ffbc97
Show file tree
Hide file tree
Showing 13 changed files with 226 additions and 150 deletions.
219 changes: 130 additions & 89 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,13 @@ mutable struct InferenceState
nargtypes = length(argtypes)
for i = 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
# 0 = function entry (think carefully)
if argtyp === Bool && has_conditional(typeinf_lattice(interp))
argtyp = Conditional(i, Const(true), Const(false))
argtyp = Conditional(i, #= ssadef =# 0, Const(true), Const(false))
end
slottypes[i] = argtyp
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
# 0 = function entry (think carefully)
bb_vartable1[i] = VarState(argtyp, #= ssadef =# 0, i > nargtypes)
end
src.ssavaluetypes = ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]

Expand Down Expand Up @@ -712,7 +714,7 @@ function sptypes_from_meth_instance(mi::MethodInstance)
ty = Const(v)
undef = false
end
sptypes[i] = VarState(ty, undef)
sptypes[i] = VarState(ty, typemin(Int), undef)
end
return sptypes
end
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ function OptimizationState(mi::MethodInstance, src::CodeInfo, interp::AbstractIn
bb_vartables = Union{VarTable,Nothing}[]
for block = 1:length(cfg.blocks)
push!(bb_vartables, VarState[
VarState(slottypes[slot], src.slotflags[slot] & SLOT_USEDUNDEF != 0)
VarState(slottypes[slot], typemin(Int), src.slotflags[slot] & SLOT_USEDUNDEF != 0)
for slot = 1:nslots
])
end
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ function abstract_eval_phi_stmt(interp::AbstractInterpreter, phi::PhiNode, ::Int
return abstract_eval_phi(interp, phi, nothing, irsv)
end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, vtypes::Union{VarTable,Nothing}, irsv::IRInterpretationState)
si = StmtInfo(true) # TODO better job here?
call = abstract_call(interp, arginfo, si, irsv)
call = abstract_call(interp, arginfo, si, vtypes, irsv)
irsv.ir.stmts[irsv.curridx][:info] = call.info
return call
end
Expand Down
13 changes: 7 additions & 6 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ end

function not_tfunc(𝕃::AbstractLattice, @nospecialize(b))
if isa(b, Conditional)
return Conditional(b.slot, b.elsetype, b.thentype)
return Conditional(b.slot, b.ssadef, b.elsetype, b.thentype)
elseif isa(b, Const)
return Const(not_int(b.val))
end
Expand Down Expand Up @@ -350,14 +350,14 @@ end
if isa(x, Conditional)
y = widenconditional(y)
if isa(y, Const)
y.val === false && return Conditional(x.slot, x.elsetype, x.thentype)
y.val === false && return Conditional(x.slot, x.ssadef, x.elsetype, x.thentype)
y.val === true && return x
return Const(false)
end
elseif isa(y, Conditional)
x = widenconditional(x)
if isa(x, Const)
x.val === false && return Conditional(y.slot, y.elsetype, y.thentype)
x.val === false && return Conditional(y.slot, y.ssadef, y.elsetype, y.thentype)
x.val === true && return y
return Const(false)
end
Expand Down Expand Up @@ -1415,7 +1415,7 @@ end
# as well as compute the info for the method matches
op = unwrapva(argtypes[op_argi])
v = unwrapva(argtypes[v_argi])
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true), sv, #=max_methods=#1)
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), StmtInfo(true), vtypes, sv, #=max_methods=#1)
TF2 = tmeet(callinfo.rt, widenconst(TF))
if TF2 === Bottom
RT = Bottom
Expand Down Expand Up @@ -2931,10 +2931,11 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
if isa(sv, InferenceState)
old_restrict = sv.restrict_abstract_call_sites
sv.restrict_abstract_call_sites = false
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, #=max_methods=#-1)
# TODO: vtypes?
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, nothing, sv, #=max_methods=#-1)
sv.restrict_abstract_call_sites = old_restrict
else
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, #=max_methods=#-1)
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, nothing, sv, #=max_methods=#-1)
end
info = verbose_stmt_info(interp) ? MethodResultPure(ReturnTypeCallInfo(call.info)) : MethodResultPure()
rt = widenslotwrapper(call.rt)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState)
for slot in 1:nslots
vt = varstate[slot]
widened_type = widenslotwrapper(ignorelimited(vt.typ))
varstate[slot] = VarState(widened_type, vt.undef)
varstate[slot] = VarState(widened_type, vt.ssadef, vt.undef)
end
end
end
Expand Down
73 changes: 50 additions & 23 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ the type of `SlotNumber(cnd.slot)` will be limited by `cnd.thentype`
and in the false branch, it will be limited by `cnd.elsetype`.
Example:
```julia
let cond = isa(x::Union{Int, Float}, Int)::Conditional(x, Int, Float)
let cond = isa(x::Union{Int, Float}, Int)::Conditional(x, _, Int, Float)
if cond
# May assume x is `Int` now
else
Expand All @@ -71,21 +71,22 @@ end
"""
struct Conditional
slot::Int
ssadef::Int
thentype
elsetype
# `isdefined` indicates this `Conditional` is from `@isdefined slot`, implying that
# the `undef` information of `slot` can be improved in the then branch.
# Since this is only beneficial for local inference, it is not translated into `InterConditional`.
isdefined::Bool
function Conditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype);
function Conditional(slot::Int, ssadef::Int, @nospecialize(thentype), @nospecialize(elsetype);
isdefined::Bool=false)
assert_nested_slotwrapper(thentype)
assert_nested_slotwrapper(elsetype)
return new(slot, thentype, elsetype, isdefined)
return new(slot, ssadef, thentype, elsetype, isdefined)
end
end
Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype); isdefined::Bool=false) =
Conditional(slot_id(var), thentype, elsetype; isdefined)
Conditional(var::SlotNumber, ssadef::Int, @nospecialize(thentype), @nospecialize(elsetype); isdefined::Bool=false) =
Conditional(slot_id(var), ssadef, thentype, elsetype; isdefined)

import Core: InterConditional
"""
Expand All @@ -105,8 +106,10 @@ InterConditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetyp
InterConditional(slot_id(var), thentype, elsetype)

const AnyConditional = Union{Conditional,InterConditional}
Conditional(cnd::InterConditional) = Conditional(cnd.slot, cnd.thentype, cnd.elsetype)
InterConditional(cnd::Conditional) = InterConditional(cnd.slot, cnd.thentype, cnd.elsetype)
function InterConditional(cnd::Conditional)
@assert cnd.ssadef == 0
InterConditional(cnd.slot, cnd.thentype, cnd.elsetype)
end

"""
alias::MustAlias
Expand Down Expand Up @@ -184,8 +187,20 @@ end
struct StateUpdate
var::SlotNumber
vtype::VarState
conditional::Bool
StateUpdate(var::SlotNumber, vtype::VarState, conditional::Bool=false) = new(var, vtype, conditional)
end

"""
Similar to `StateUpdate`, except with the additional guarantee that object identity
is preserved by the update (i.e. `x (before) === x (after)`).
"""
struct StateRefinement
slot::Int
# XXX: This should be an intersection of the old type with the new
# (i.e. newtyp ⊑ oldtyp)
newtyp
undef::Bool

StateRefinement(slot::Int, @nospecialize(newtyp), undef::Bool) = new(slot, newtyp, undef)
end

"""
Expand Down Expand Up @@ -328,6 +343,7 @@ end
return false
end

is_same_conditionals(a::Conditional, b::Conditional) = a.slot == b.slot && a.ssadef == b.ssadef
is_same_conditionals(a::C, b::C) where C<:AnyConditional = a.slot == b.slot

@nospecializeinfer is_lattice_bool(lattice::AbstractLattice, @nospecialize(typ)) = typ !== Bottom && (lattice, typ, Bool)
Expand Down Expand Up @@ -387,7 +403,7 @@ end
elsefields === nothing || (elsefields[i] = elsetype)
end
end
return Conditional(slot,
return Conditional(slot, typemin(Int), # TODO
thenfields === nothing ? Bottom : PartialStruct(vartyp.typ, thenfields),
elsefields === nothing ? Bottom : PartialStruct(vartyp.typ, elsefields))
else
Expand All @@ -404,7 +420,7 @@ end
elsefields === nothing || push!(elsefields, t)
end
end
return Conditional(slot,
return Conditional(slot, typemin(Int),
thenfields === nothing ? Bottom : PartialStruct(vartyp_widened, thenfields),
elsefields === nothing ? Bottom : PartialStruct(vartyp_widened, elsefields))
end
Expand Down Expand Up @@ -745,34 +761,39 @@ widenconst(::LimitedAccuracy) = error("unhandled LimitedAccuracy")
# state management #
####################

function smerge(lattice::AbstractLattice, sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState})
function smerge(lattice::AbstractLattice, sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState}, join_pc::Int)
sa === sb && return sa
sa === NOT_FOUND && return sb
sb === NOT_FOUND && return sa
return VarState(tmerge(lattice, sa.typ, sb.typ), sa.undef | sb.undef)
return VarState(tmerge(lattice, sa.typ, sb.typ), sa.ssadef == sb.ssadef ? sa.ssadef : join_pc, sa.undef | sb.undef)
end

@nospecializeinfer @inline schanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o)) =
(n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !(n.undef <= o.undef && (lattice, n.typ, o.typ))))
@nospecializeinfer @inline schanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o), join_pc::Int) =
(n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !(n.undef <= o.undef && (n.ssadef == o.ssadef || o.ssadef == join_pc) && (lattice, n.typ, o.typ))))

# remove any lattice elements that wrap the reassigned slot object from the vartable
function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional::Bool)
function invalidate_slotwrapper(vt::VarState, changeid::Int)
newtyp = ignorelimited(vt.typ)
if (!ignore_conditional && isa(newtyp, Conditional) && newtyp.slot == changeid) ||
(isa(newtyp, MustAlias) && newtyp.slot == changeid)
if ((isa(newtyp, Conditional) && newtyp.slot == changeid) ||
(isa(newtyp, MustAlias) && newtyp.slot == changeid))
newtyp = @noinline widenwrappedslotwrapper(vt.typ)
return VarState(newtyp, vt.undef)
return VarState(newtyp, vt.ssadef, vt.undef)
end
return nothing
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable, join_pc::Int)
changed = false
for i = 1:length(state)
newtype = changes[i]
oldtype = state[i]
if schanged(lattice, newtype, oldtype)
state[i] = smerge(lattice, oldtype, newtype)
# In addition to computing the type, the merge here computes the "reaching definition"
# for a slot. The provided `join_pc` is a "virtual" PC, which corresponds to the ϕ-block
# that would exist at the beginning of the BasicBlock.
#
# This effectively applies the "path-convergence criterion" for SSA construction.
if schanged(lattice, newtype, oldtype, join_pc)
state[i] = smerge(lattice, oldtype, newtype, join_pc)
changed = true
end
end
Expand All @@ -789,7 +810,7 @@ end
function stoverwrite1!(state::VarTable, change::StateUpdate)
changeid = slot_id(change.var)
for i = 1:length(state)
invalidated = invalidate_slotwrapper(state[i], changeid, change.conditional)
invalidated = invalidate_slotwrapper(state[i], changeid)
if invalidated !== nothing
state[i] = invalidated
end
Expand All @@ -799,3 +820,9 @@ function stoverwrite1!(state::VarTable, change::StateUpdate)
state[changeid] = newtype
return state
end

function strefine1!(state::VarTable, refinement::StateRefinement)
(; newtyp, undef, slot) = refinement
state[slot] = VarState(newtyp, state[slot].ssadef, undef)
return state
end
10 changes: 5 additions & 5 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -494,24 +494,24 @@ end
# type-lattice for Conditional wrapper (NOTE never be merged with InterConditional)
if isa(typea, Conditional) && isa(typeb, Const)
if typeb.val === true
typeb = Conditional(typea.slot, Any, Union{})
typeb = Conditional(typea.slot, typea.ssadef, Any, Union{})
elseif typeb.val === false
typeb = Conditional(typea.slot, Union{}, Any)
typeb = Conditional(typea.slot, typea.ssadef, Union{}, Any)
end
end
if isa(typeb, Conditional) && isa(typea, Const)
if typea.val === true
typea = Conditional(typeb.slot, Any, Union{})
typea = Conditional(typeb.slot, typeb.ssadef, Any, Union{})
elseif typea.val === false
typea = Conditional(typeb.slot, Union{}, Any)
typea = Conditional(typeb.slot, typeb.ssadef, Union{}, Any)
end
end
if isa(typea, Conditional) && isa(typeb, Conditional)
if is_same_conditionals(typea, typeb)
thentype = tmerge(widenlattice(lattice), typea.thentype, typeb.thentype)
elsetype = tmerge(widenlattice(lattice), typea.elsetype, typeb.elsetype)
if thentype !== elsetype
return Conditional(typea.slot, thentype, elsetype)
return Conditional(typea.slot, typea.ssadef, thentype, elsetype)
end
end
val = maybe_extract_const_bool(typea)
Expand Down
8 changes: 7 additions & 1 deletion base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,20 @@ MethodInfo(src::CodeInfo) = MethodInfo(
A special wrapper that represents a local variable of a method being analyzed.
This does not participate in the native type system nor the inference lattice, and it thus
should be always unwrapped to `v.typ` when performing any type or lattice operations on it.
`v.undef` represents undefined-ness of this static parameter. If `true`, it means that the
variable _may_ be undefined at runtime, otherwise it is guaranteed to be defined.
If `v.typ === Bottom` it means that the variable is strictly undefined.
`v.ssadef` represents the "reaching definition" for the variable. If negative, this refers
to a "virtual ϕ-block" preceding the given index. If a slot has the same `ssadef` at two
different points of execution, the slot contents are guaranteed to share identity (`x₀ === x₁`).
"""
struct VarState
typ
ssadef::Int
undef::Bool
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
VarState(@nospecialize(typ), ssadef::Int, undef::Bool) = new(typ, ssadef, undef)
end

struct AnalysisResults
Expand Down
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2246,7 +2246,7 @@ function print_statement_costs(io::IO, @nospecialize(tt::Type);
else
empty!(cst)
resize!(cst, length(code.code))
sptypes = Core.Compiler.VarState[Core.Compiler.VarState(sp, false) for sp in match.sparams]
sptypes = Core.Compiler.VarState[Core.Compiler.VarState(sp, #= ssadef =# typemin(Int), false) for sp in match.sparams]
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, params)
nd = ndigits(maxcost)
irshow_config = IRShow.IRShowConfig() do io, linestart, idx
Expand Down
6 changes: 3 additions & 3 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,10 @@ CC.nsplit_impl(info::NoinlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoinlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoinlineCallInfo, idx::Int) = CC.getresult(info.info, idx)

function CC.abstract_call(interp::NoinlineInterpreter,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.InferenceState, max_methods::Int)
function CC.abstract_call(interp::NoinlineInterpreter, arginfo::CC.ArgInfo, si::CC.StmtInfo,
vtypes::Union{VarTable,Nothing}, sv::CC.InferenceState, max_methods::Int)
ret = @invoke CC.abstract_call(interp::CC.AbstractInterpreter,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.InferenceState, max_methods::Int)
arginfo::CC.ArgInfo, si::CC.StmtInfo, vtypes::Union{VarTable,Nothing}, sv::CC.InferenceState, max_methods::Int)
if sv.mod in noinline_modules(interp)
return CC.CallMeta(ret.rt, ret.exct, ret.effects, NoinlineCallInfo(ret.info))
end
Expand Down
Loading

0 comments on commit 5ffbc97

Please sign in to comment.