Skip to content

Commit

Permalink
inference: backward constraint propagation from call signatures
Browse files Browse the repository at this point in the history
This PR implements another (limited) backward analysis pass in abstract
interpretation; it exploits signatures of matching methods and refines
types of slots.

Here are a couple of examples where these changes improve inference accuracy:

> generic function example
```julia
addi(a::Integer, b::Integer) = a + b
Base.return_types((Any,Any,)) do a, b
    c = addi(a, b)
    return a, b, c # now the compiler understands `a::Integer`, `b::Integer`
end
```

> `typeassert` example
```julia
Base.return_types((Any,)) do a
    typeassert(a, Int)
    return a # now the compiler understands `a::Int`
end
```

This PR consists of two main parts: 1.) obtain refinement information
and back-propagate it, and 2.) apply state updates

As for 1., unlike conditional constraints, refinement information isn't
encoded within lattice elements; rather, it is stored in the
newly defined field `InferenceState.state_updates`, which is refreshed
on each program counter increment. For now, refinement information is
obtained from call signatures of generic functions and `typeassert`.

Finally, in order to apply multiple state updates, this PR extends
`StateUpdate` and `stupdate` so that they can hold and apply multiple
state updates.
  • Loading branch information
aviatesk committed May 23, 2021
1 parent a08a3ff commit 9254d3f
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 28 deletions.
103 changes: 92 additions & 11 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
any_const_result = false
const_results = Union{InferenceResult,Nothing}[]
multiple_matches = napplicable > 1
refine_targets = nothing # keeps refinement information on slot types obtained from call signature
if fargs !== nothing
refine_targets = Union{Nothing,Tuple{SlotNumber,Any}}[]
for x in fargs
push!(refine_targets, isa(x, SlotNumber) ? (x, Bottom) : nothing)
end
end

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
val = pure_eval_call(f, argtypes)
Expand Down Expand Up @@ -197,6 +204,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals[2][i] = tmerge(conditionals[2][i], elsetype)
end
end
if refine_targets !== nothing
for i in 1:length(refine_targets)
target = refine_targets[i]
if target !== nothing
slot, t = target
refine_targets[i] = (slot, tmerge(fieldtype(sig, i), t))
end
end
end
if bail_out_call(interp, rettype, sv)
break
end
Expand All @@ -209,6 +225,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
info = ConstCallInfo(info, const_results)
end

# refinement information from call signatures is valid only when obtained from all the
# matching signatures, and we should throw it away if we bailed out early
if seen ≠ napplicable
refine_targets = nothing
end

if rettype isa LimitedAccuracy
union!(sv.pclimitations, rettype.causes)
rettype = rettype.typ
Expand Down Expand Up @@ -263,6 +285,18 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
@assert !(rettype isa InterConditional) "invalid lattice element returned from inter-procedural context"

# if refinement information on slot types is available, apply it now
anyrefined = false
if rettype !== Bottom && refine_targets !== nothing
for target in refine_targets
if target !== nothing
slot, t = target
if t !== Bottom
anyrefined |= add_state_update!(slot, t, sv)
end
end
end
end
if call_result_unused(sv) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
Expand All @@ -273,7 +307,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
add_call_backedges!(interp, anyrefined, rettype, edges, fullmatch, mts, atype, sv)
if !isempty(sv.pclimitations) # remove self, if present
delete!(sv.pclimitations, sv)
for caller in sv.callers_in_cycle
Expand All @@ -285,13 +319,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

function add_call_backedges!(interp::AbstractInterpreter,
@nospecialize(rettype),
anyrefined::Bool, @nospecialize(rettype),
edges::Vector{MethodInstance},
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
sv::InferenceState)
if rettype === Any
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine
# (widen) this type
if !anyrefined && rettype === Any
# for `NativeInterpreter`, we don't add backedges when we've not used refinement
# information from call signature and a new method couldn't refine (widen) this type
return
end
for edge in edges
Expand Down Expand Up @@ -1000,6 +1034,11 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
if 1 <= idx <= length(cti)
rt = unwrapva(cti[idx])
end
elseif f === typeassert
# perform very limited back-propagation of invariants after this type assertion
if rt !== Bottom && isa(fargs, Vector{Any}) && (x2 = fargs[2]; isa(x2, SlotNumber))
add_state_update!(x2, rt, sv)
end
elseif (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any})
# perform very limited back-propagation of type information for `is` and `isa`
if f === isa
Expand Down Expand Up @@ -1356,7 +1395,7 @@ function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(
elseif isa(e, SSAValue)
return abstract_eval_ssavalue(e::SSAValue, sv.src)
elseif isa(e, SlotNumber) || isa(e, Argument)
return (vtypes[slot_id(e)]::VarState).typ
return get_varstate(vtypes, slot_id(e)).typ
elseif isa(e, GlobalRef)
return abstract_eval_global(e.mod, e.name)
end
Expand Down Expand Up @@ -1658,6 +1697,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
stmt = frame.src.code[pc]
changes = states[pc]::VarTable
t = nothing
empty!(frame.state_updates)

hd = isa(stmt, Expr) ? stmt.head : nothing

Expand Down Expand Up @@ -1778,12 +1818,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.src.ssavaluetypes[pc] = t
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
changes = StateUpdate(lhs, VarState(t, false), changes, false)
changes = StateUpdate([(lhs, VarState(t, false))], changes, false)
end
elseif hd === :method
fname = stmt.args[1]
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), changes, false)
changes = StateUpdate([(fname, VarState(Any, false))], changes, false)
end
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
# these do not generate code
Expand Down Expand Up @@ -1821,6 +1861,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

pc´ > n && break # can't proceed with the fast-path fall-through
frame.handler_at[pc´] = frame.cur_hand
changes = collect_state_changes!(changes, frame)
newstate = stupdate!(states[pc´], changes)
if isa(stmt, GotoNode) && frame.pc´´ < pc´
# if we are processing a goto node anyways,
Expand All @@ -1846,21 +1887,61 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
nothing
end

# Record a type refinement `new` for `slot`, obtained while analyzing the statement at
# `frame.currpc`.
#
# The slot's current type `old` is read from the variable table of the current statement
# (`frame.stmt_types[frame.currpc]`). The refinement is pushed onto `frame.state_updates`
# only when `old ⊑ new` does not hold, i.e. when `new` is not simply an upper bound of what
# is already known about the slot.
#
# Returns `true` if an update was recorded, `false` otherwise.
function add_state_update!(slot::SlotNumber, @nospecialize(new), frame::InferenceState)
    state = frame.stmt_types[frame.currpc]::VarTable
    old = get_varstate(state, slot).typ
    if !(old ⊑ new) # new ⋤ old
        push!(frame.state_updates, (slot, new))
        return true
    end
    return false
end

# Drain `frame.state_updates` into `changes`, which already holds the update(s) made by the
# statement itself. Pending refinements are popped from the end of the queue; a refinement
# for a slot that the statement already updated is dropped so that the statement's own
# effect wins.
function collect_state_changes!(changes::StateUpdate, frame::InferenceState, conditional::Bool = changes.conditional)
    # ids of the slots updated by the statement, snapshotted before any refinement is added
    assigned = BitSet()
    for (slot, _) in changes.updates
        push!(assigned, slot_id(slot))
    end
    pending = frame.state_updates
    while !isempty(pending)
        slot, refined = pop!(pending)
        # effects of the statement (like assignment) should have the precedence
        if !(slot_id(slot) in assigned)
            changes = add_state_change!(changes, slot, refined, conditional)
        end
    end
    return changes
end

# Drain `frame.state_updates` on top of a plain variable table. Each pending refinement is
# popped from the end of the queue and folded into the accumulated result through
# `add_state_change!`; if the queue is empty, `changes` is returned untouched.
function collect_state_changes!(changes::VarTable, frame::InferenceState, conditional::Bool = false)
    pending = frame.state_updates
    acc = changes
    while true
        isempty(pending) && break
        slot, refined = pop!(pending)
        acc = add_state_change!(acc, slot, refined, conditional)
    end
    return acc
end

# Append a refinement of `var` to `typ` to an existing `StateUpdate`. The `undef` flag is
# carried over from `var`'s entry in the underlying table `changes.state`. The mutated
# `changes` is returned.
function add_state_change!(changes::StateUpdate, var::SlotNumber, @nospecialize(typ), conditional::Bool)
    @assert changes.conditional === conditional
    maybe_undef = get_varstate(changes.state, var).undef
    push!(changes.updates, (var, VarState(typ, maybe_undef)))
    return changes
end

# Wrap the variable table `changes` in a fresh `StateUpdate` holding a single refinement of
# `var` to `typ`. The `undef` flag is carried over from `var`'s current entry in the table.
function add_state_change!(changes::VarTable, var::SlotNumber, @nospecialize(typ), conditional::Bool)
    maybe_undef = get_varstate(changes, var).undef
    updates = Tuple{SlotNumber,VarState}[(var, VarState(typ, maybe_undef))]
    return StateUpdate(updates, changes, conditional)
end

function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNumber)
oldtyp = (changes[slot_id(var)]::VarState).typ
oldtyp = get_varstate(changes, var).typ
# approximate test for `typ ∩ oldtyp` being better than `oldtyp`
    # since we probably formed these types with `typesubtract`, the comparison is likely simple
    if ignorelimited(typ) ⊑ ignorelimited(oldtyp)
# typ is better unlimited, but we may still need to compute the tmeet with the limit "causes" since we ignored those in the comparison
oldtyp isa LimitedAccuracy && (typ = tmerge(typ, LimitedAccuracy(Bottom, oldtyp.causes)))
return StateUpdate(var, VarState(typ, false), changes, true)
return StateUpdate([(var, VarState(typ, false))], changes, true)
end
return changes
end

function bool_rt_to_conditional(@nospecialize(rt), slottypes::Vector{Any}, state::VarTable, slot_id::Int)
old = slottypes[slot_id]
new = widenconditional((state[slot_id]::VarState).typ) # avoid nested conditional
new = widenconditional(get_varstate(state, slot_id).typ)
    if new ⊑ old && !(old ⊑ new)
if isa(rt, Const)
val = rt.val
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mutable struct InferenceState
stmt_types::Vector{Union{Nothing, Vector{Any}}} # ::Vector{Union{Nothing, VarTable}}
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
stmt_info::Vector{Any}
state_updates::Vector{Tuple{SlotNumber,Any}} # additional state update obtained at currpc
# return type
bestguess #::Type
# current active instruction pointers
Expand Down Expand Up @@ -108,7 +109,7 @@ mutable struct InferenceState
sp, slottypes, inmodule, 0,
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
nargs, s_types, s_edges, stmt_info, Tuple{SlotNumber,Any}[],
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
ssavalue_uses, throw_blocks,
Expand Down
44 changes: 28 additions & 16 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ end

const VarTable = Array{Any,1}

# Look up the `VarState` stored for a slot in a variable table. The slot may be given as a
# `SlotNumber` (resolved through `slot_id`) or directly as its integer index; the entry is
# type-asserted to be a `VarState`.
function get_varstate(state::VarTable, slot::SlotNumber)
    return get_varstate(state, slot_id(slot))
end
function get_varstate(state::VarTable, slot::Int)
    return state[slot]::VarState
end

struct StateUpdate
var::SlotNumber
vtype::VarState
updates::Vector{Tuple{SlotNumber,VarState}}
state::VarTable
conditional::Bool
end
Expand Down Expand Up @@ -320,32 +322,40 @@ ignorelimited(@nospecialize typ) = typ
ignorelimited(typ::LimitedAccuracy) = typ.typ

function stupdate!(state::Nothing, changes::StateUpdate)
newst = copy(changes.state)
changeid = slot_id(changes.var)
newst[changeid] = changes.vtype
newstate = copy(changes.state)
changeids = Int[]
for (var, vtype) in changes.updates
changeid = slot_id(var)
newstate[changeid] = vtype
push!(changeids, changeid)
end
# remove any Conditional for this slot from the vtable
    # (unless this change came from the conditional)
if !changes.conditional
for i = 1:length(newst)
newtype = newst[i]
for i = 1:length(newstate)
newtype = newstate[i]
if isa(newtype, VarState)
newtypetyp = ignorelimited(newtype.typ)
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) in changeids
newtypetyp = widenwrappedconditional(newtype.typ)
newst[i] = VarState(newtypetyp, newtype.undef)
newstate[i] = VarState(newtypetyp, newtype.undef)
end
end
end
end
return newst
return newstate
end

function stupdate!(state::VarTable, changes::StateUpdate)
changeids = Int[]
for (var, _) in changes.updates
push!(changeids, slot_id(var))
end
newstate = nothing
changeid = slot_id(changes.var)
for i = 1:length(state)
if i == changeid
newtype = changes.vtype
j = findfirst(==(i), changeids)
if j !== nothing
newtype = changes.updates[j][2]
else
newtype = changes.state[i]
end
Expand All @@ -354,7 +364,7 @@ function stupdate!(state::VarTable, changes::StateUpdate)
        # (unless this change came from the conditional)
if !changes.conditional && isa(newtype, VarState)
newtypetyp = ignorelimited(newtype.typ)
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) in changeids
newtypetyp = widenwrappedconditional(newtype.typ)
newtype = VarState(newtypetyp, newtype.undef)
end
Expand Down Expand Up @@ -385,7 +395,9 @@ stupdate!(state::Nothing, changes::VarTable) = copy(changes)
stupdate!(state::Nothing, changes::Nothing) = nothing

function stupdate1!(state::VarTable, change::StateUpdate)
changeid = slot_id(change.var)
@assert length(change.updates) == 1
var, vtype = change.updates[1]
changeid = slot_id(var)
# remove any Conditional for this slot from the catch block vtable
    # (unless this change came from the conditional)
if !change.conditional
Expand All @@ -404,7 +416,7 @@ function stupdate1!(state::VarTable, change::StateUpdate)
end
end
# and update the type of it
newtype = change.vtype
newtype = vtype
oldtype = state[changeid]
if schanged(newtype, oldtype)
state[changeid] = smerge(oldtype, newtype)
Expand Down
Loading

0 comments on commit 9254d3f

Please sign in to comment.