Skip to content

Commit

Permalink
inference: backward constraint propagation from call signatures (Juli…
Browse files Browse the repository at this point in the history
…aLang#55229)

This PR implements another (limited) backward analysis pass in abstract
interpretation; it exploits signatures of matching methods and refines
types of slots.

Here are a couple of examples where these changes improve inference
accuracy:

> generic function example
```julia
addint(a::Int, b::Int) = a + b
@test Base.infer_return_type((Any,Any,)) do a, b
    c = addint(a, b)
    return a, b, c # now the compiler understands `a::Int`, `b::Int`
end == Tuple{Int,Int,Int}
```

> `typeassert` example
```julia
@test Base.infer_return_type((Any,)) do a
    a::Int
    return a # now the compiler understands `a::Int`
end == Int
```

Unlike `Conditional`-constrained type propagation, this type refinement
information isn't encoded within any lattice element; instead, it is
propagated through the newly added field `frame.curr_stmt_change` of
`frame::InferenceState`.
For now this commit exploits refinement information available from call
signatures of generic functions and `typeassert`.

---

- closes JuliaLang#37866
- fixes JuliaLang#38274
- closes JuliaLang#55115
  • Loading branch information
aviatesk authored and lazarusA committed Aug 17, 2024
1 parent ac3033c commit 339b416
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 41 deletions.
166 changes: 127 additions & 39 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# A single slot-type refinement: pairs a `slot` with the type `typ` it may be
# narrowed to. `typ` is kept as an `Any`-typed field and the constructor does
# not specialize on it.
struct SlotRefinement
    slot::SlotNumber
    typ::Any
    function SlotRefinement(slot::SlotNumber, @nospecialize(typ))
        return new(slot, typ)
    end
end

# See if the inference result of the current statement's result value might affect
# the final answer for the method (aside from optimization potential and exceptions).
# To do that, we need to check both for slot assignment and SSA usage.
Expand Down Expand Up @@ -39,6 +45,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
multiple_matches = napplicable > 1
fargs = arginfo.fargs
all_effects = EFFECTS_TOTAL
slotrefinements = nothing # keeps refinement information on slot types obtained from call signature

for i in 1:napplicable
match = applicable[i]::MethodMatch
Expand Down Expand Up @@ -177,11 +184,16 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# there is unanalyzed candidate, widen type and effects to the top
rettype = excttype = Any
all_effects = Effects()
elseif isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
excttype = tmerge(𝕃ₚ, excttype, MethodError)
else
if (matches isa MethodMatches ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches)))
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
excttype = tmerge(𝕃ₚ, excttype, MethodError)
end
if sv isa InferenceState && fargs !== nothing
slotrefinements = collect_slot_refinements(𝕃ᵢ, applicable, argtypes, fargs, sv)
end
end

rettype = from_interprocedural!(interp, rettype, sv, arginfo, conditionals)
Expand Down Expand Up @@ -213,7 +225,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
any_slot_refined = slotrefinements !== nothing
add_call_backedges!(interp, rettype, all_effects, any_slot_refined, edges, matches, atype, sv)
if isa(sv, InferenceState)
# TODO (#48913) implement a proper recursion handling for irinterp:
# This works just because currently the `:terminate` condition guarantees that
Expand All @@ -227,7 +240,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end

return CallMeta(rettype, excttype, all_effects, info)
return CallMeta(rettype, excttype, all_effects, info, slotrefinements)
end

struct FailedMethodMatch
Expand Down Expand Up @@ -492,8 +505,37 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
end
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), all_effects::Effects,
edges::Vector{MethodInstance}, matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
# Collect slot-type refinements implied by the signatures of the matching methods.
#
# For every call argument in `fargs` that is a plain `SlotNumber`, the corresponding
# parameter types of all `applicable` matches are joined. If that joined signature
# type is strictly more specific than the argument type currently known for the
# call, it is recorded as the refined type of that slot.
#
# Returns `nothing` when no slot can be refined; otherwise a `Vector{Any}` of length
# `length(sv.slottypes)`, indexed by slot id, whose entries are either `nothing`
# (no refinement) or the refined type.
function collect_slot_refinements(𝕃ᵢ::AbstractLattice, applicable::Vector{Any},
    argtypes::Vector{Any}, fargs::Vector{Any}, sv::InferenceState)
    ⊏, ⊔ = strictpartialorder(𝕃ᵢ), join(𝕃ᵢ)
    slotrefinements = nothing
    # NOTE(review): assumes `argtypes` and every `match.spec_types` have at least
    # `length(fargs)` entries -- confirm against callers.
    for i = 1:length(fargs)
        fargᵢ = fargs[i]
        if fargᵢ isa SlotNumber
            fidx = slot_id(fargᵢ)
            argt = widenslotwrapper(argtypes[i])
            if isvarargtype(argt)
                argt = unwrapva(argt)
            end
            # join the i-th parameter type over all matching methods
            sigt = Bottom
            for j = 1:length(applicable)
                match = applicable[j]::MethodMatch
                sigt = sigt ⊔ fieldtype(match.spec_types, i)
            end
            if sigt ⊏ argt # i.e. signature type is strictly more specific than the type of the argument slot
                if slotrefinements === nothing
                    # allocated lazily so the common "nothing to refine" case stays allocation-free
                    slotrefinements = fill!(Vector{Any}(undef, length(sv.slottypes)), nothing)
                end
                slotrefinements[fidx] = sigt
            end
        end
    end
    return slotrefinements
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype),
all_effects::Effects, any_slot_refined::Bool, edges::Vector{MethodInstance},
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
sv::AbsIntState)
# don't bother to add backedges when both type and effects information are already
# maximized to the top since a new method couldn't refine or widen them anyway
Expand All @@ -503,7 +545,9 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
if !isoverlayed(method_table(interp))
all_effects = Effects(all_effects; nonoverlayed=ALWAYS_FALSE)
end
all_effects === Effects() && return nothing
if all_effects === Effects() && !any_slot_refined
return nothing
end
end
for edge in edges
add_backedge!(sv, edge)
Expand Down Expand Up @@ -1794,7 +1838,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
@nospecialize f
la = length(argtypes)
𝕃ᵢ = typeinf_lattice(interp)
= (𝕃ᵢ)
, , , = partialorder(𝕃ᵢ), strictpartialorder(𝕃ᵢ), join(𝕃ᵢ), meet(𝕃ᵢ)
if has_conditional(𝕃ᵢ, sv) && f === Core.ifelse && fargs isa Vector{Any} && la == 4
cnd = argtypes[2]
if isa(cnd, Conditional)
Expand All @@ -1809,12 +1853,12 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
a = ssa_def_slot(fargs[3], sv)
b = ssa_def_slot(fargs[4], sv)
if isa(a, SlotNumber) && cnd.slot == slot_id(a)
tx = (cnd.thentype tx ? cnd.thentype : tmeet(𝕃ᵢ, tx, widenconst(cnd.thentype)))
tx = (cnd.thentype tx ? cnd.thentype : tx widenconst(cnd.thentype))
end
if isa(b, SlotNumber) && cnd.slot == slot_id(b)
ty = (cnd.elsetype ty ? cnd.elsetype : tmeet(𝕃ᵢ, ty, widenconst(cnd.elsetype)))
ty = (cnd.elsetype ty ? cnd.elsetype : ty widenconst(cnd.elsetype))
end
return tmerge(𝕃ᵢ, tx, ty)
return tx ty
end
end
end
Expand Down Expand Up @@ -1939,13 +1983,13 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = tmerge(thentype, ty)
thentype = thentype ty
else
elsetype = tmerge(elsetype, ty)
elsetype = elsetype ty
end
else
thentype = tmerge(thentype, ty)
elsetype = tmerge(elsetype, ty)
thentype = thentype ty
elsetype = elsetype ty
end
end
return Conditional(a, thentype, elsetype)
Expand All @@ -1970,8 +2014,8 @@ function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{An
elseif na == 3
a2 = argtypes[2]
a3 = argtypes[3]
= (typeinf_lattice(interp))
nothrow = a2 TypeVar && (a3 Type || a3 TypeVar)
= partialorder(typeinf_lattice(interp))
nothrow = a2 TypeVar && (a3 Type || a3 TypeVar)
else
return CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo())
end
Expand Down Expand Up @@ -2003,7 +2047,8 @@ function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{An
return CallMeta(ret, Any, Effects(EFFECTS_TOTAL; nothrow), call.info)
end

function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo, sv::AbsIntState)
function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState)
argtypes = arginfo.argtypes
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo())
Expand Down Expand Up @@ -2034,6 +2079,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
res = nothing
sig = match.spec_types
argtypes′ = invoke_rewrite(argtypes)
fargs = arginfo.fargs
fargs′ = fargs === nothing ? nothing : invoke_rewrite(fargs)
arginfo = ArgInfo(fargs′, argtypes′)
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
Expand Down Expand Up @@ -2122,7 +2168,17 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
exct = builtin_exct(𝕃ᵢ, f, argtypes, rt)
end
pushfirst!(argtypes, ft)
return CallMeta(rt, exct, effects, NoCallInfo())
refinements = nothing
if sv isa InferenceState && f === typeassert
# perform very limited back-propagation of invariants after this type assertion
if rt !== Bottom && isa(fargs, Vector{Any})
farg2 = fargs[2]
if farg2 isa SlotNumber
refinements = SlotRefinement(farg2, rt)
end
end
end
return CallMeta(rt, exct, effects, NoCallInfo(), refinements)
elseif isa(f, Core.OpaqueClosure)
# calling an OpaqueClosure about which we have no information returns no information
return CallMeta(typeof(f).parameters[2], Any, Effects(), NoCallInfo())
Expand Down Expand Up @@ -2412,10 +2468,14 @@ function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::
end

# Result of abstractly evaluating a statement or expression:
# - `rt`: the inferred return type,
# - `exct`: the inferred exception type,
# - `effects`: the inferred effects,
# - `refinements`: optional slot-type refinements produced by the evaluation
#   (`nothing` when there are none).
struct RTEffects
    rt::Any
    exct::Any
    effects::Effects
    refinements # ::Union{Nothing,SlotRefinement,Vector{Any}}
    # `refinements` defaults to `nothing`, so existing three-argument callers keep working.
    function RTEffects(rt, exct, effects::Effects, refinements=nothing)
        @nospecialize rt exct refinements
        return new(rt, exct, effects, refinements)
    end
end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState)
Expand All @@ -2437,8 +2497,8 @@ function abstract_eval_call(interp::AbstractInterpreter, e::Expr, vtypes::Union{
return RTEffects(Bottom, Any, Effects())
end
arginfo = ArgInfo(ea, argtypes)
(; rt, exct, effects) = abstract_call(interp, arginfo, sv)
return RTEffects(rt, exct, effects)
(; rt, exct, effects, refinements) = abstract_call(interp, arginfo, sv)
return RTEffects(rt, exct, effects, refinements)
end

function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing},
Expand Down Expand Up @@ -2780,9 +2840,9 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
rt = old_rt === NOT_FOUND ? rt : tmerge(typeinf_lattice(interp), old_rt, rt)
return RTEffects(rt, Union{}, EFFECTS_TOTAL)
end
(; rt, exct, effects) = abstract_eval_special_value(interp, e, vtypes, sv)
(; rt, exct, effects, refinements) = abstract_eval_special_value(interp, e, vtypes, sv)
else
(; rt, exct, effects) = abstract_eval_statement_expr(interp, e, vtypes, sv)
(; rt, exct, effects, refinements) = abstract_eval_statement_expr(interp, e, vtypes, sv)
if effects.noub === NOUB_IF_NOINBOUNDS
if has_curr_ssaflag(sv, IR_FLAG_INBOUNDS)
effects = Effects(effects; noub=ALWAYS_FALSE)
Expand Down Expand Up @@ -2812,7 +2872,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
set_curr_ssaflag!(sv, flags_for_effects(effects), IR_FLAGS_EFFECTS)
merge_effects!(interp, sv, effects)

return RTEffects(rt, exct, effects)
return RTEffects(rt, exct, effects, refinements)
end

function override_effects(effects::Effects, override::EffectsOverride)
Expand Down Expand Up @@ -3009,13 +3069,13 @@ end
fields = copy(rt.fields)
local anyrefine = false
𝕃 = typeinf_lattice(info.interp)
= strictpartialorder(𝕃)
for i in 1:length(fields)
a = fields[i]
a = isvarargtype(a) ? a : widenreturn_noslotwrapper(𝕃, a, info)
if !anyrefine
# TODO: consider adding && const_prop_profitable(a) here?
anyrefine = has_extended_info(a) ||
(𝕃, a, fieldtype(rt.typ, i))
anyrefine = has_extended_info(a) || a fieldtype(rt.typ, i)
end
fields[i] = a
end
Expand Down Expand Up @@ -3061,7 +3121,12 @@ struct BasicStmtChange
rt::Any # extended lattice element or `nothing` - `nothing` if this statement may not be used as an SSA Value
exct::Any
# TODO effects::Effects
BasicStmtChange(changes::Union{Nothing,StateUpdate}, @nospecialize(rt), @nospecialize(exct)) = new(changes, rt, exct)
refinements # ::Union{Nothing,SlotRefinement,Vector{Any}}
function BasicStmtChange(changes::Union{Nothing,StateUpdate}, rt::Any, exct::Any,
refinements=nothing)
@nospecialize rt exct refinements
return new(changes, rt, exct, refinements)
end
end

@inline function abstract_eval_basic_statement(interp::AbstractInterpreter,
Expand All @@ -3076,9 +3141,9 @@ end
changes = nothing
hd = stmt.head
if hd === :(=)
(; rt, exct) = abstract_eval_statement(interp, stmt.args[2], pc_vartable, frame)
(; rt, exct, refinements) = abstract_eval_statement(interp, stmt.args[2], pc_vartable, frame)
if rt === Bottom
return BasicStmtChange(nothing, Bottom, exct)
return BasicStmtChange(nothing, Bottom, exct, refinements)
end
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
Expand All @@ -3088,7 +3153,7 @@ end
elseif !isa(lhs, SSAValue)
merge_effects!(interp, frame, EFFECTS_UNKNOWN)
end
return BasicStmtChange(changes, rt, exct)
return BasicStmtChange(changes, rt, exct, refinements)
elseif hd === :method
fname = stmt.args[1]
if isa(fname, SlotNumber)
Expand All @@ -3100,8 +3165,8 @@ end
is_meta_expr(stmt)))
return BasicStmtChange(nothing, Nothing, Bottom)
else
(; rt, exct) = abstract_eval_statement(interp, stmt, pc_vartable, frame)
return BasicStmtChange(nothing, rt, exct)
(; rt, exct, refinements) = abstract_eval_statement(interp, stmt, pc_vartable, frame)
return BasicStmtChange(nothing, rt, exct, refinements)
end
end

Expand Down Expand Up @@ -3363,7 +3428,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# Fall through terminator - treat as regular stmt
end
# Process non control-flow statements
(; changes, rt, exct) = abstract_eval_basic_statement(interp,
(; changes, rt, exct, refinements) = abstract_eval_basic_statement(interp,
stmt, currstate, frame)
if !has_curr_ssaflag(frame, IR_FLAG_NOTHROW)
if exct !== Union{}
Expand All @@ -3384,6 +3449,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if changes !== nothing
stoverwrite1!(currstate, changes)
end
if refinements isa SlotRefinement
apply_refinement!(𝕃ᵢ, refinements.slot, refinements.typ, currstate, changes)
elseif refinements isa Vector{Any}
for i = 1:length(refinements)
newtyp = refinements[i]
newtyp === nothing && continue
apply_refinement!(𝕃ᵢ, SlotNumber(i), newtyp, currstate, changes)
end
end
if rt === nothing
ssavaluetypes[currpc] = Any
continue
Expand Down Expand Up @@ -3422,6 +3496,20 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
nothing
end

# Narrow the type recorded for `slot` in `currstate` to `newtyp`, but only when
# `newtyp` is strictly more specific than the currently recorded type.
#
# `currchanges` is the state update produced by the current statement itself (if any);
# when it already targets `slot`, the refinement is skipped. The slot's `undef` flag
# is carried over unchanged.
function apply_refinement!(𝕃ᵢ::AbstractLattice, slot::SlotNumber, @nospecialize(newtyp),
    currstate::VarTable, currchanges::Union{Nothing,StateUpdate})
    if currchanges !== nothing && currchanges.var == slot
        return # type propagation from statement (like assignment) should have the precedence
    end
    vtype = currstate[slot_id(slot)]
    oldtyp = vtype.typ
    ⊏ = strictpartialorder(𝕃ᵢ)
    if newtyp ⊏ oldtyp
        stmtupdate = StateUpdate(slot, VarState(newtyp, vtype.undef), false)
        stoverwrite1!(currstate, stmtupdate)
    end
end

function conditional_change(𝕃ᵢ::AbstractLattice, currstate::VarTable, condt::Conditional, then_or_else::Bool)
vtype = currstate[condt.slot]
oldtyp = vtype.typ
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,13 @@ has_extended_unionsplit(::JLTypeLattice) = false
# Curried forms of the lattice operations: each takes only the lattice `𝕃` and returns a
# two-argument closure forwarding to the corresponding three-argument method.
⊑(𝕃::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊑(𝕃, a, b)
⊏(𝕃::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊏(𝕃, a, b)
⋤(𝕃::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⋤(𝕃, a, b)
tmerge(𝕃::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> tmerge(𝕃, a, b)
tmeet(𝕃::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> tmeet(𝕃, a, b)
# Named aliases for the curried operators above, for binding them to local infix operators.
partialorder(𝕃::AbstractLattice) = ⊑(𝕃)
strictpartialorder(𝕃::AbstractLattice) = ⊏(𝕃)
strictneqpartialorder(𝕃::AbstractLattice) = ⋤(𝕃)
join(𝕃::AbstractLattice) = tmerge(𝕃)
meet(𝕃::AbstractLattice) = tmeet(𝕃)

# Fallbacks for external packages using these methods
const fallback_lattice = InferenceLattice(BaseInferenceLattice.instance)
Expand Down
6 changes: 6 additions & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ struct CallMeta
exct::Any
effects::Effects
info::CallInfo
refinements # ::Union{Nothing,SlotRefinement,Vector{Any}}
function CallMeta(rt::Any, exct::Any, effects::Effects, info::CallInfo,
refinements=nothing)
@nospecialize rt exct info
return new(rt, exct, effects, info, refinements)
end
end

struct NoCallInfo <: CallInfo end
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ ignorelimited(typ::LimitedAccuracy) = typ.typ
# =============

@nospecializeinfer function (lattice::InferenceLattice, @nospecialize(a), @nospecialize(b))
r = (widenlattice(lattice), ignorelimited(a), ignorelimited(b))
r || return false
(widenlattice(lattice), ignorelimited(a), ignorelimited(b)) || return false

isa(b, LimitedAccuracy) || return true

# We've found that ignorelimited(a) ⊑ ignorelimited(b).
Expand Down
Loading

0 comments on commit 339b416

Please sign in to comment.